From cf81c2b05edd54f9b779d221ad0705f8faabc2e9 Mon Sep 17 00:00:00 2001
From: Vatsal Shah
Date: Mon, 22 Dec 2025 16:01:04 +0530
Subject: [PATCH] feat/ add local test

---
 environment_entrypoint.py        |  96 ++++++++++++++++++-----
 validator/conda.py               |  24 ++++--
 validator/config.py              |  16 ++++
 validator/modules/rl/__init__.py |  26 +++++--
 validator/validation_runner.py   | 129 ++++++++++++++++++++++++++++---
 5 files changed, 249 insertions(+), 42 deletions(-)

diff --git a/environment_entrypoint.py b/environment_entrypoint.py
index 7c0ee48..aa37353 100644
--- a/environment_entrypoint.py
+++ b/environment_entrypoint.py
@@ -11,33 +11,91 @@
 @click.option(
     "--task_ids",
     type=str,
-    required=True,
-    help="The ids of the task, separated by comma",
+    required=False,
+    help="The ids of the task, separated by comma (not required for local test)",
 )
-@click.option('--flock-api-key', envvar='FLOCK_API_KEY', required=True, help='Flock API key')
-@click.option('--hf-token', envvar='HF_TOKEN', required=True, help='HuggingFace token')
+@click.option('--flock-api-key', envvar='FLOCK_API_KEY', required=False, help='Flock API key (not required for local test)')
+@click.option('--hf-token', envvar='HF_TOKEN', required=False, help='HuggingFace token')
 @click.option('--time-sleep', envvar='TIME_SLEEP', default=60 * 3, type=int, show_default=True, help='Time to sleep between retries (seconds)')
 @click.option('--assignment-lookup-interval', envvar='ASSIGNMENT_LOOKUP_INTERVAL', default=60 * 3, type=int, show_default=True, help='Assignment lookup interval (seconds)')
 @click.option("--debug", is_flag=True)
-def main(module: str, task_ids: str, flock_api_key: str, hf_token: str, time_sleep: int, assignment_lookup_interval: int, debug: bool):
+# Local test options
+@click.option('--local-test', is_flag=True, help='Run a local validation test without submitting to FedLedger')
+@click.option('--hf-repo', type=str, help='HuggingFace repository ID for local test (e.g., username/model-name)')
+@click.option('--revision', type=str, default='main', help='Git revision/commit ID for local test (default: main)')
+@click.option('--validation-set', type=str, help='Path to local validation set file for local test')
+@click.option('--model-filename', type=str, default='model.onnx', help='Model filename in repo for local test (default: model.onnx)')
+@click.option('--max-params', type=int, default=100000000, help='Maximum model parameters for local test (default: 100M)')
+def main(
+    module: str,
+    task_ids: str,
+    flock_api_key: str,
+    hf_token: str,
+    time_sleep: int,
+    assignment_lookup_interval: int,
+    debug: bool,
+    local_test: bool,
+    hf_repo: str,
+    revision: str,
+    validation_set: str,
+    model_filename: str,
+    max_params: int,
+):
     """
     CLI entrypoint for running the validation process.
     Delegates core logic to ValidationRunner.
+
+    For local testing, use --local-test flag with --hf-repo and --validation-set options.
+    Example: python environment_entrypoint.py rl --local-test --hf-repo username/model --validation-set data.npz
     """
-    runner = ValidationRunner(
-        module=module,
-        task_ids=task_ids.split(","),
-        flock_api_key=flock_api_key,
-        hf_token=hf_token,
-        time_sleep=time_sleep,
-        assignment_lookup_interval=assignment_lookup_interval,
-        debug=debug,
-    )
-    try:
-        runner.run()
-    except KeyboardInterrupt:
-        click.echo("\nValidation interrupted by user.")
-        sys.exit(0)
+    if local_test:
+        # Local test mode - no API required
+        if not hf_repo:
+            click.echo("Error: --hf-repo is required for local test mode", err=True)
+            sys.exit(1)
+
+        runner = ValidationRunner(
+            module=module,
+            local_test=True,
+            debug=debug,
+        )
+        try:
+            runner.run_local_test(
+                hf_repo=hf_repo,
+                revision=revision,
+                validation_set_path=validation_set,
+                model_filename=model_filename,
+                max_params=max_params,
+            )
+        except KeyboardInterrupt:
+            click.echo("\nLocal test interrupted by user.")
+            sys.exit(0)
+        except Exception as e:
+            click.echo(f"\nLocal test failed: {e}", err=True)
+            sys.exit(1)
+    else:
+        # Normal validation mode - requires API
+        if not task_ids:
+            click.echo("Error: --task_ids is required for normal validation mode", err=True)
+            sys.exit(1)
+        if not flock_api_key:
+            click.echo("Error: --flock-api-key is required for normal validation mode", err=True)
+            sys.exit(1)
+
+        runner = ValidationRunner(
+            module=module,
+            task_ids=task_ids.split(","),
+            flock_api_key=flock_api_key,
+            hf_token=hf_token,
+            time_sleep=time_sleep,
+            assignment_lookup_interval=assignment_lookup_interval,
+            debug=debug,
+        )
+        try:
+            runner.run()
+        except KeyboardInterrupt:
+            click.echo("\nValidation interrupted by user.")
+            sys.exit(0)
 
 if __name__ == "__main__":
     main()
diff --git a/validator/conda.py b/validator/conda.py
index 4bba2df..89b7963 100644
--- a/validator/conda.py
+++ b/validator/conda.py
@@ -47,22 +47,36 @@ def install_in_env(env_name: str, packages: List[str]):
     """Install additional packages into an existing conda environment."""
     run_command(["conda", "install", "-y", "-n", env_name] + packages)
 
-def run_in_env(env_name: str, command: List[str], env_vars: Optional[dict] = None):
+def run_in_env(env_name: str, command: List[str], env_vars: Optional[dict] = None, no_capture_output: bool = False):
     """Run a command inside a conda environment."""
     logger.info(f"Running command {command} in environment {env_name}")
-    cmd = ["conda", "run", "-n", env_name] + command
-    run_command(cmd, env=env_vars)
+    cmd = ["conda", "run"]
+    if no_capture_output:
+        cmd.append("--no-capture-output")
+    cmd.extend(["-n", env_name] + command)
+
+    try:
+        run_command(cmd, env=env_vars)
+    except subprocess.CalledProcessError as e:
+        # If --no-capture-output flag caused the error (micromamba compatibility), retry without it
+        if no_capture_output and e.returncode == 2:
+            logger.warning("--no-capture-output not supported, retrying without it")
+            cmd_without_flag = ["conda", "run", "-n", env_name] + command
+            run_command(cmd_without_flag, env=env_vars)
+        else:
+            raise
 
 def ensure_env_and_run(
     env_name: str,
     env_yml: Path,
     requirements_txt: Path,
     command: List[str],
-    env_vars: Optional[dict] = None
+    env_vars: Optional[dict] = None,
+    no_capture_output: bool = False
 ):
     """Ensure the environment exists, create if needed, then run the command."""
     if not env_exists(env_name):
         create_env(env_name, env_yml, requirements_txt)
     else:
         update_env(env_name, env_yml, requirements_txt)
-    run_in_env(env_name, command, env_vars)
\ No newline at end of file
+    run_in_env(env_name, command, env_vars, no_capture_output=no_capture_output)
\ No newline at end of file
diff --git a/validator/config.py b/validator/config.py
index d888b29..0a6e1d1 100644
--- a/validator/config.py
+++ b/validator/config.py
@@ -29,3 +29,19 @@ def load_config_for_task(
 
     # 3. Use Pydantic model defaults for missing values
     return config_model(**config_data)
+
+
+def load_config_from_file(
+    task_type: str,
+    config_model: Type[BaseConfig],
+    config_dir: str = "configs"
+) -> BaseConfig:
+    """
+    Loads config for a given task_type from file (for local testing).
+    """
+    config_data: dict[str, Any] = {}
+    type_config_path = Path(config_dir) / f"{task_type}.json"
+    if type_config_path.exists():
+        with open(type_config_path, "r") as f:
+            config_data.update(json.load(f))
+    return config_model(**config_data)
\ No newline at end of file
diff --git a/validator/modules/rl/__init__.py b/validator/modules/rl/__init__.py
index 890dc30..7f4d0b9 100644
--- a/validator/modules/rl/__init__.py
+++ b/validator/modules/rl/__init__.py
@@ -102,14 +102,24 @@ def validate(self, data: RLInputData, **kwargs) -> RLMetrics:
             return RLMetrics(average_reward=LOWEST_POSSIBLE_REWARD)
 
         # Download and load test data (.npz file containing X_test and Info_test)
-        print(f"Downloading test data from {data.validation_set_url}")
-        response = requests.get(data.validation_set_url, timeout=10)
-        response.raise_for_status()
-
-        # Load the .npz file and extract X_test and Info_test
-        with np.load(BytesIO(response.content)) as test_data:
-            test_X = test_data['X']
-            test_Info = test_data['Info']
+        # Handle both local files (file://) and remote URLs
+        validation_url = data.validation_set_url
+        if validation_url.startswith("file://"):
+            # Local file path
+            local_path = validation_url[7:]  # Remove 'file://' prefix
+            print(f"Loading test data from local file: {local_path}")
+            with np.load(local_path) as test_data:
+                test_X = test_data['X']
+                test_Info = test_data['Info']
+        else:
+            # Remote URL
+            print(f"Downloading test data from {validation_url}")
+            response = requests.get(validation_url, timeout=10)
+            response.raise_for_status()
+            # Load the .npz file and extract X_test and Info_test
+            with np.load(BytesIO(response.content)) as test_data:
+                test_X = test_data['X']
+                test_Info = test_data['Info']
 
         print(f"Loaded test data: X_test {test_X.shape}, Info_test {test_Info.shape}")
 
diff --git a/validator/validation_runner.py b/validator/validation_runner.py
index 56dc64b..9f43a03 100644
--- a/validator/validation_runner.py
+++ b/validator/validation_runner.py
@@ -1,10 +1,11 @@
 import importlib
 import time
 import sys
+import json
 from loguru import logger
 from .exceptions import RecoverableException
 from .api import FedLedger
-from .config import load_config_for_task
+from .config import load_config_for_task, load_config_from_file
 from .modules.base import BaseValidationModule, BaseConfig, BaseInputData, BaseMetrics
 
 class ValidationRunner:
@@ -15,33 +16,39 @@ class ValidationRunner:
     def __init__(
         self,
         module: str,
-        task_ids: list[str],
-        flock_api_key: str,
-        hf_token: str,
+        task_ids: list[str] = None,
+        flock_api_key: str = None,
+        hf_token: str = None,
         time_sleep: int = 180,
         assignment_lookup_interval: int = 180,
         debug: bool = False,
+        local_test: bool = False,
     ):
         """
         Initialize the ValidationRunner.
 
         Args:
             module: The name of the validation module to use.
-            task_ids: List of task IDs to validate.
-            flock_api_key: API key for Flock.
-            hf_token: HuggingFace token (passed for compatibility, not used here).
+            task_ids: List of task IDs to validate
+            flock_api_key: API key for Flock
+            hf_token: HuggingFace token
             time_sleep: Time to sleep between retries (seconds).
             assignment_lookup_interval: Assignment lookup interval (seconds).
             debug: Enable debug mode (currently unused).
+            local_test: If True, skip API initialization for local testing.
         """
         self.module = module
-        self.task_ids = task_ids
+        self.task_ids = task_ids or []
         self.flock_api_key = flock_api_key
         self.hf_token = hf_token
         self.time_sleep = time_sleep
         self.assignment_lookup_interval = assignment_lookup_interval
         self.debug = debug
-        self.api = FedLedger(flock_api_key)
-        self._setup_modules()
+        self.local_test = local_test
+        if not local_test:
+            self.api = FedLedger(flock_api_key)
+            self._setup_modules()
+        else:
+            self._setup_local_module()
 
     def _setup_modules(self):
         """Dynamically import and initialize validation modules for each task."""
@@ -60,6 +67,108 @@ def _setup_modules(self):
             self.module_config_to_module.setdefault(config, module_cls(config=config))
             self.task_id_to_module[task_id] = self.module_config_to_module[config]
 
+    def _setup_local_module(self):
+        """Setup module for local testing without API."""
+        module_mod = importlib.import_module(f"validator.modules.{self.module}")
+        self.module_cls: type[BaseValidationModule] = module_mod.MODULE
+        self.local_config = load_config_from_file(self.module, self.module_cls.config_schema)
+        self.local_module = self.module_cls(config=self.local_config)
+
+    def run_local_test(
+        self,
+        hf_repo: str,
+        revision: str = "main",
+        validation_set_path: str = None,
+        model_filename: str = "model.onnx",
+        max_params: int = 100000000,
+    ):
+        """
+        Run a local validation test without submitting to FedLedger.
+
+        Args:
+            hf_repo: HuggingFace repository ID (e.g., 'username/model-name')
+            revision: Git revision/commit ID (default: 'main')
+            validation_set_path: Path to local validation set file (e.g., validation_set.npz)
+            model_filename: Model filename in the repo (default: 'model.onnx')
+            max_params: Maximum allowed model parameters (default: 100M)
+        """
+        logger.info("=" * 60)
+        logger.info("LOCAL VALIDATION TEST")
+        logger.info("(Runs EXACT same code as real validator)")
+        logger.info("=" * 60)
+        logger.info("")
+        logger.info("CONFIG (from configs/{}.json):".format(self.module))
+        logger.info(json.dumps(self.local_config.model_dump(), indent=2))
+        logger.info("")
+        logger.info("SUBMISSION PARAMETERS (what model submitter provides):")
+        logger.info(f"  HuggingFace Repo: {hf_repo}")
+        logger.info(f"  Revision: {revision}")
+        logger.info(f"  Model Filename: {model_filename}")
+        logger.info("")
+        logger.info("TASK PARAMETERS (must match task configuration for accurate results):")
+        logger.info(f"  Validation Set: {validation_set_path}")
+        logger.info(f"  Max Parameters: {max_params}")
+        logger.info("=" * 60)
+
+        # Build input data based on module type
+        if self.module == "rl":
+            # For RL module, we need to handle the validation_set_url
+            # If a local path is provided, we'll serve it or use file:// protocol
+            if validation_set_path:
+                import os
+                if not os.path.isabs(validation_set_path):
+                    validation_set_path = os.path.abspath(validation_set_path)
+                validation_set_url = f"file://{validation_set_path}"
+            else:
+                raise ValueError("validation_set_path is required for RL module")
+
+            input_data = self.local_module.input_data_schema(
+                hg_repo_id=hf_repo,
+                model_filename=model_filename,
+                revision=revision,
+                validation_set_url=validation_set_url,
+                max_params=max_params,
+            )
+        else:
+            # Generic fallback - try to construct input data
+            input_data = self.local_module.input_data_schema(
+                hg_repo_id=hf_repo,
+                revision=revision,
+            )
+
+        logger.info("Input data constructed:")
+        logger.info(json.dumps(input_data.model_dump(), indent=2))
+        logger.info("-" * 60)
+
+        try:
+            logger.info("Starting validation...")
+            metrics = self.local_module.validate(input_data)
+
+            logger.info("")
+            logger.info("=" * 60)
+            logger.info("VALIDATION COMPLETE - SUCCESS")
+            logger.info("=" * 60)
+            logger.info("")
+            logger.info("RESULT (exactly what would be submitted to FedLedger):")
+            logger.info(json.dumps(metrics.model_dump(), indent=2))
+            logger.info("")
+            logger.info("NOTE: If using the same validation_set and max_params as the")
+            logger.info("real task, this score will match what validators compute.")
+            logger.info("=" * 60)
+
+            return metrics
+
+        except Exception as e:
+            logger.error(f"Validation failed with error: {e}")
+            logger.error("")
+            logger.error("This model would FAIL validation on the real task.")
+            logger.error("Common causes:")
+            logger.error("  - Model file not found in HuggingFace repo")
+            logger.error("  - Model exceeds max_params limit")
+            logger.error("  - Model input/output shape mismatch")
+            logger.error("  - Invalid ONNX model format")
+            raise
+
     def perform_validation(self, assignment_id: str, task_id: str,input_data: BaseInputData) -> BaseMetrics | None:
         """
         Perform validation for a given assignment and input data.