96 changes: 77 additions & 19 deletions environment_entrypoint.py
@@ -11,33 +11,91 @@
@click.option(
"--task_ids",
type=str,
required=True,
help="The ids of the task, separated by comma",
required=False,
help="The ids of the task, separated by comma (not required for local test)",
)
@click.option('--flock-api-key', envvar='FLOCK_API_KEY', required=True, help='Flock API key')
@click.option('--hf-token', envvar='HF_TOKEN', required=True, help='HuggingFace token')
@click.option('--flock-api-key', envvar='FLOCK_API_KEY', required=False, help='Flock API key (not required for local test)')
@click.option('--hf-token', envvar='HF_TOKEN', required=False, help='HuggingFace token')
@click.option('--time-sleep', envvar='TIME_SLEEP', default=60 * 3, type=int, show_default=True, help='Time to sleep between retries (seconds)')
@click.option('--assignment-lookup-interval', envvar='ASSIGNMENT_LOOKUP_INTERVAL', default=60 * 3, type=int, show_default=True, help='Assignment lookup interval (seconds)')
@click.option("--debug", is_flag=True)
def main(module: str, task_ids: str, flock_api_key: str, hf_token: str, time_sleep: int, assignment_lookup_interval: int, debug: bool):
# Local test options
@click.option('--local-test', is_flag=True, help='Run a local validation test without submitting to FedLedger')
@click.option('--hf-repo', type=str, help='HuggingFace repository ID for local test (e.g., username/model-name)')
@click.option('--revision', type=str, default='main', help='Git revision/commit ID for local test (default: main)')
@click.option('--validation-set', type=str, help='Path to local validation set file for local test')
@click.option('--model-filename', type=str, default='model.onnx', help='Model filename in repo for local test (default: model.onnx)')
@click.option('--max-params', type=int, default=100000000, help='Maximum model parameters for local test (default: 100M)')
def main(
module: str,
task_ids: str,
flock_api_key: str,
hf_token: str,
time_sleep: int,
assignment_lookup_interval: int,
debug: bool,
local_test: bool,
hf_repo: str,
revision: str,
validation_set: str,
model_filename: str,
max_params: int,
):
"""
CLI entrypoint for running the validation process.
Delegates core logic to ValidationRunner.

For local testing, use --local-test flag with --hf-repo and --validation-set options.
Example: python environment_entrypoint.py rl --local-test --hf-repo username/model --validation-set data.npz
"""
runner = ValidationRunner(
module=module,
task_ids=task_ids.split(","),
flock_api_key=flock_api_key,
hf_token=hf_token,
time_sleep=time_sleep,
assignment_lookup_interval=assignment_lookup_interval,
debug=debug,
)
try:
runner.run()
except KeyboardInterrupt:
click.echo("\nValidation interrupted by user.")
sys.exit(0)
if local_test:
# Local test mode - no API required
if not hf_repo:
click.echo("Error: --hf-repo is required for local test mode", err=True)
sys.exit(1)

runner = ValidationRunner(
module=module,
local_test=True,
debug=debug,
)
try:
runner.run_local_test(
hf_repo=hf_repo,
revision=revision,
validation_set_path=validation_set,
model_filename=model_filename,
max_params=max_params,
)
except KeyboardInterrupt:
click.echo("\nLocal test interrupted by user.")
sys.exit(0)
except Exception as e:
click.echo(f"\nLocal test failed: {e}", err=True)
sys.exit(1)
else:
# Normal validation mode - requires API
if not task_ids:
click.echo("Error: --task_ids is required for normal validation mode", err=True)
sys.exit(1)
if not flock_api_key:
click.echo("Error: --flock-api-key is required for normal validation mode", err=True)
sys.exit(1)

runner = ValidationRunner(
module=module,
task_ids=task_ids.split(","),
flock_api_key=flock_api_key,
hf_token=hf_token,
time_sleep=time_sleep,
assignment_lookup_interval=assignment_lookup_interval,
debug=debug,
)
try:
runner.run()
except KeyboardInterrupt:
click.echo("\nValidation interrupted by user.")
sys.exit(0)

if __name__ == "__main__":
main()
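As a quick sanity check of the new local-test branch, the command could also be exercised through Click's test runner instead of a shell. This is a minimal sketch, not part of the PR: it assumes `module` is the positional argument shown in the docstring example, and the repo ID and dataset path are placeholders.

from click.testing import CliRunner
from environment_entrypoint import main

runner = CliRunner()
# Equivalent to: python environment_entrypoint.py rl --local-test --hf-repo ... --validation-set ...
result = runner.invoke(main, [
    "rl",
    "--local-test",
    "--hf-repo", "username/model-name",
    "--validation-set", "validation_set.npz",
])
print(result.exit_code)
print(result.output)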
24 changes: 19 additions & 5 deletions validator/conda.py
@@ -47,22 +47,36 @@ def install_in_env(env_name: str, packages: List[str]):
"""Install additional packages into an existing conda environment."""
run_command(["conda", "install", "-y", "-n", env_name] + packages)

def run_in_env(env_name: str, command: List[str], env_vars: Optional[dict] = None):
def run_in_env(env_name: str, command: List[str], env_vars: Optional[dict] = None, no_capture_output: bool = False):
"""Run a command inside a conda environment."""
logger.info(f"Running command {command} in environment {env_name}")
cmd = ["conda", "run", "-n", env_name] + command
run_command(cmd, env=env_vars)
cmd = ["conda", "run"]
if no_capture_output:
cmd.append("--no-capture-output")
cmd.extend(["-n", env_name] + command)

try:
run_command(cmd, env=env_vars)
except subprocess.CalledProcessError as e:
# If --no-capture-output flag caused the error (micromamba compatibility), retry without it
if no_capture_output and e.returncode == 2:
logger.warning("--no-capture-output not supported, retrying without it")
cmd_without_flag = ["conda", "run", "-n", env_name] + command
run_command(cmd_without_flag, env=env_vars)
else:
raise

def ensure_env_and_run(
env_name: str,
env_yml: Path,
requirements_txt: Path,
command: List[str],
env_vars: Optional[dict] = None
env_vars: Optional[dict] = None,
no_capture_output: bool = False
):
"""Ensure the environment exists, create if needed, then run the command."""
if not env_exists(env_name):
create_env(env_name, env_yml, requirements_txt)
else:
update_env(env_name, env_yml, requirements_txt)
run_in_env(env_name, command, env_vars)
run_in_env(env_name, command, env_vars, no_capture_output=no_capture_output)
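A hypothetical call site for the new flag — the environment name, file paths, command, and env vars below are illustrative, not taken from this PR:

from pathlib import Path
from validator.conda import ensure_env_and_run

# Stream child-process output straight to the console instead of letting
# `conda run` buffer it; run_in_env retries without the flag if the local
# conda/micromamba build rejects it.
ensure_env_and_run(
    env_name="rl-validator",
    env_yml=Path("envs/rl.yml"),
    requirements_txt=Path("envs/rl-requirements.txt"),
    command=["python", "-u", "run_validation.py"],
    env_vars={"HF_TOKEN": "hf_..."},
    no_capture_output=True,
)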
16 changes: 16 additions & 0 deletions validator/config.py
@@ -29,3 +29,19 @@ def load_config_for_task(

# 3. Use Pydantic model defaults for missing values
return config_model(**config_data)


def load_config_from_file(
task_type: str,
config_model: Type[BaseConfig],
config_dir: str = "configs"
) -> BaseConfig:
"""
Load the config for a given task_type from a local JSON file (for local testing).
"""
config_data: dict[str, Any] = {}
type_config_path = Path(config_dir) / f"{task_type}.json"
if type_config_path.exists():
with open(type_config_path, "r") as f:
config_data.update(json.load(f))
return config_model(**config_data)
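A sketch of the intended call site, mirroring the local-test path in validation_runner.py; it assumes the rl module package exposes MODULE with a config_schema attribute and that configs/rl.json may or may not exist (defaults apply otherwise):

from validator.config import load_config_from_file
from validator.modules import rl

# Reads configs/rl.json if present; otherwise the Pydantic defaults on
# the config model are used.
config = load_config_from_file("rl", rl.MODULE.config_schema)
print(config.model_dump())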
26 changes: 18 additions & 8 deletions validator/modules/rl/__init__.py
@@ -102,14 +102,24 @@ def validate(self, data: RLInputData, **kwargs) -> RLMetrics:
return RLMetrics(average_reward=LOWEST_POSSIBLE_REWARD)

# Download and load test data (.npz file containing X_test and Info_test)
print(f"Downloading test data from {data.validation_set_url}")
response = requests.get(data.validation_set_url, timeout=10)
response.raise_for_status()

# Load the .npz file and extract X_test and Info_test
with np.load(BytesIO(response.content)) as test_data:
test_X = test_data['X']
test_Info = test_data['Info']
# Handle both local files (file://) and remote URLs
validation_url = data.validation_set_url
if validation_url.startswith("file://"):
# Local file path
local_path = validation_url[7:] # Remove 'file://' prefix
print(f"Loading test data from local file: {local_path}")
with np.load(local_path) as test_data:
test_X = test_data['X']
test_Info = test_data['Info']
else:
# Remote URL
print(f"Downloading test data from {validation_url}")
response = requests.get(validation_url, timeout=10)
response.raise_for_status()
# Load the .npz file and extract X_test and Info_test
with np.load(BytesIO(response.content)) as test_data:
test_X = test_data['X']
test_Info = test_data['Info']

print(f"Loaded test data: X_test {test_X.shape}, Info_test {test_Info.shape}")

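A minimal sketch of input that routes into the new file:// branch. The field names mirror what run_local_test constructs in validation_runner.py below, and it assumes RLInputData is importable from the rl module:

import os
from validator.modules.rl import RLInputData

local_path = os.path.abspath("validation_set.npz")
data = RLInputData(
    hg_repo_id="username/model-name",
    model_filename="model.onnx",
    revision="main",
    validation_set_url=f"file://{local_path}",  # local path -> file:// branch above
    max_params=100_000_000,
)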
129 changes: 119 additions & 10 deletions validator/validation_runner.py
@@ -1,10 +1,11 @@
import importlib
import time
import sys
import json
from loguru import logger
from .exceptions import RecoverableException
from .api import FedLedger
from .config import load_config_for_task
from .config import load_config_for_task, load_config_from_file
from .modules.base import BaseValidationModule, BaseConfig, BaseInputData, BaseMetrics

class ValidationRunner:
@@ -15,33 +15,39 @@ class ValidationRunner:
def __init__(
self,
module: str,
task_ids: list[str],
flock_api_key: str,
hf_token: str,
task_ids: list[str] = None,
flock_api_key: str = None,
hf_token: str = None,
time_sleep: int = 180,
assignment_lookup_interval: int = 180,
debug: bool = False,
local_test: bool = False,
):
"""
Initialize the ValidationRunner.
Args:
module: The name of the validation module to use.
task_ids: List of task IDs to validate.
flock_api_key: API key for Flock.
hf_token: HuggingFace token (passed for compatibility, not used here).
task_ids: List of task IDs to validate (not needed when local_test is True).
flock_api_key: API key for Flock (not needed when local_test is True).
hf_token: HuggingFace token.
time_sleep: Time to sleep between retries (seconds).
assignment_lookup_interval: Assignment lookup interval (seconds).
debug: Enable debug mode (currently unused).
local_test: If True, skip API initialization for local testing.
"""
self.module = module
self.task_ids = task_ids
self.task_ids = task_ids or []
self.flock_api_key = flock_api_key
self.hf_token = hf_token
self.time_sleep = time_sleep
self.assignment_lookup_interval = assignment_lookup_interval
self.debug = debug
self.api = FedLedger(flock_api_key)
self._setup_modules()
self.local_test = local_test
if not local_test:
self.api = FedLedger(flock_api_key)
self._setup_modules()
else:
self._setup_local_module()

def _setup_modules(self):
"""Dynamically import and initialize validation modules for each task."""
@@ -60,6 +67,108 @@ def _setup_modules(self):
self.module_config_to_module.setdefault(config, module_cls(config=config))
self.task_id_to_module[task_id] = self.module_config_to_module[config]

def _setup_local_module(self):
"""Setup module for local testing without API."""
module_mod = importlib.import_module(f"validator.modules.{self.module}")
self.module_cls: type[BaseValidationModule] = module_mod.MODULE
self.local_config = load_config_from_file(self.module, self.module_cls.config_schema)
self.local_module = self.module_cls(config=self.local_config)

def run_local_test(
self,
hf_repo: str,
revision: str = "main",
validation_set_path: str = None,
model_filename: str = "model.onnx",
max_params: int = 100000000,
):
"""
Run a local validation test without submitting to FedLedger.

Args:
hf_repo: HuggingFace repository ID (e.g., 'username/model-name')
revision: Git revision/commit ID (default: 'main')
validation_set_path: Path to local validation set file (e.g., validation_set.npz)
model_filename: Model filename in the repo (default: 'model.onnx')
max_params: Maximum allowed model parameters (default: 100M)
"""
logger.info("=" * 60)
logger.info("LOCAL VALIDATION TEST")
logger.info("(Runs EXACT same code as real validator)")
logger.info("=" * 60)
logger.info("")
logger.info("CONFIG (from configs/{}.json):".format(self.module))
logger.info(json.dumps(self.local_config.model_dump(), indent=2))
logger.info("")
logger.info("SUBMISSION PARAMETERS (what model submitter provides):")
logger.info(f" HuggingFace Repo: {hf_repo}")
logger.info(f" Revision: {revision}")
logger.info(f" Model Filename: {model_filename}")
logger.info("")
logger.info("TASK PARAMETERS (must match task configuration for accurate results):")
logger.info(f" Validation Set: {validation_set_path}")
logger.info(f" Max Parameters: {max_params}")
logger.info("=" * 60)

# Build input data based on module type
if self.module == "rl":
# The RL module needs a validation_set_url; if a local path is given,
# convert it to a file:// URL so the module can load it directly.
if validation_set_path:
import os
if not os.path.isabs(validation_set_path):
validation_set_path = os.path.abspath(validation_set_path)
validation_set_url = f"file://{validation_set_path}"
else:
raise ValueError("validation_set_path is required for RL module")

input_data = self.local_module.input_data_schema(
hg_repo_id=hf_repo,
model_filename=model_filename,
revision=revision,
validation_set_url=validation_set_url,
max_params=max_params,
)
else:
# Generic fallback - try to construct input data
input_data = self.local_module.input_data_schema(
hg_repo_id=hf_repo,
revision=revision,
)

logger.info("Input data constructed:")
logger.info(json.dumps(input_data.model_dump(), indent=2))
logger.info("-" * 60)

try:
logger.info("Starting validation...")
metrics = self.local_module.validate(input_data)

logger.info("")
logger.info("=" * 60)
logger.info("VALIDATION COMPLETE - SUCCESS")
logger.info("=" * 60)
logger.info("")
logger.info("RESULT (exactly what would be submitted to FedLedger):")
logger.info(json.dumps(metrics.model_dump(), indent=2))
logger.info("")
logger.info("NOTE: If using the same validation_set and max_params as the")
logger.info("real task, this score will match what validators compute.")
logger.info("=" * 60)

return metrics

except Exception as e:
logger.error(f"Validation failed with error: {e}")
logger.error("")
logger.error("This model would FAIL validation on the real task.")
logger.error("Common causes:")
logger.error(" - Model file not found in HuggingFace repo")
logger.error(" - Model exceeds max_params limit")
logger.error(" - Model input/output shape mismatch")
logger.error(" - Invalid ONNX model format")
raise

def perform_validation(self, assignment_id: str, task_id: str, input_data: BaseInputData) -> BaseMetrics | None:
"""
Perform validation for a given assignment and input data.
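For completeness, a hedged sketch of driving the new local-test path directly from Python rather than through the CLI; the repo ID and dataset path are placeholders:

from validator.validation_runner import ValidationRunner

runner = ValidationRunner(module="rl", local_test=True, debug=True)
metrics = runner.run_local_test(
    hf_repo="username/model-name",
    revision="main",
    validation_set_path="validation_set.npz",
    model_filename="model.onnx",
    max_params=100_000_000,
)
print(metrics.model_dump())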