diff --git a/primus/core/config/primus_config.py b/primus/core/config/primus_config.py index 0477e91b3..888f9c997 100644 --- a/primus/core/config/primus_config.py +++ b/primus/core/config/primus_config.py @@ -18,13 +18,60 @@ from __future__ import annotations +from copy import copy, deepcopy from pathlib import Path from types import SimpleNamespace from typing import Any from primus.core.launcher.parser import PrimusParser from primus.core.utils import constant_vars -from primus.core.utils.yaml_utils import dict_to_nested_namespace +from primus.core.utils.yaml_utils import ( + dict_to_nested_namespace, + nested_namespace_to_dict, +) + + +def _to_plain_dict(value: Any) -> dict: + if value is None: + return {} + if isinstance(value, SimpleNamespace): + return nested_namespace_to_dict(value) + return dict(value) + + +def _normalize_module_for_runtime(module_cfg: SimpleNamespace, module_name: str) -> SimpleNamespace: + """ + Normalize a legacy module namespace into the runtime shape: + - Keep a small set of reserved top-level attributes: + * name, framework, config, model + * (and any existing 'params' if present) + - Move all other public attributes into a `params` dict: + module_cfg.params[key] = + """ + normalized = deepcopy(module_cfg) + # Ensure each module has a stable `.name` attribute. + if not getattr(normalized, "name", None): + setattr(normalized, "name", module_name) + + reserved_keys = {"name", "framework", "config", "model", "params"} + # Start from any existing params dict/namespace if provided. + existing_params = getattr(normalized, "params", {}) + params = _to_plain_dict(existing_params) + + for key, value in list(vars(normalized).items()): + # Skip reserved and private attributes. + if key in reserved_keys or key.startswith("_"): + continue + # Move this attribute into params and remove it from the namespace. + params[key] = value + delattr(normalized, key) + + # Convert params dict to nested namespace for attribute-style access. + normalized.params = dict_to_nested_namespace(params) + # Duplicate `model` under `params.model` for downstream compatibility. + if hasattr(normalized, "model"): + normalized.params.model = normalized.model + return normalized def load_primus_config(config_path: Path, cli_args: Any | None = None) -> SimpleNamespace: @@ -43,15 +90,12 @@ def load_primus_config(config_path: Path, cli_args: Any | None = None) -> Simple # PrimusParser.parse() only requires a `.config` attribute. if cli_args is None: args_for_parser = SimpleNamespace(config=config_path_str) - else: + elif getattr(cli_args, "config", None) != config_path_str: # Avoid mutating the original args object. - if getattr(cli_args, "config", None) != config_path_str: - from copy import copy - - args_for_parser = copy(cli_args) - setattr(args_for_parser, "config", config_path_str) - else: - args_for_parser = cli_args + args_for_parser = copy(cli_args) + setattr(args_for_parser, "config", config_path_str) + else: + args_for_parser = cli_args # Use legacy PrimusParser to build a PrimusConfig instance. legacy_cfg = PrimusParser().parse(args_for_parser) @@ -77,47 +121,10 @@ def load_primus_config(config_path: Path, cli_args: Any | None = None) -> Simple cfg.platform = platform_config # Build modules list from legacy PrimusConfig.module_keys/get_module_config. - modules: list[SimpleNamespace] = [] - for module_name in getattr(legacy_cfg, "module_keys", []): - module_cfg = legacy_cfg.get_module_config(module_name) - - # Ensure each module has a stable `.name` attribute. - if not getattr(module_cfg, "name", None): - setattr(module_cfg, "name", module_name) - - # Normalize module configuration layout for the new core runtime: - # - # - Keep a small set of reserved top-level attributes: - # * name, framework, config, model - # * (and any existing 'params' if present) - # - Move all other public attributes into a `params` dict: - # module_cfg.params[key] = - # - # This matches the expectation of downstream components which - # treat `module_cfg.params` as the flat parameter dictionary. - reserved_keys = {"name", "framework", "config", "model", "params"} - - # Start from any existing params dict if present. - params = dict(getattr(module_cfg, "params", {})) - - for key, value in list(vars(module_cfg).items()): - # Skip reserved and private attributes. - if key in reserved_keys or key.startswith("_"): - continue - - # Move this attribute into params and remove it from the namespace. - params[key] = value - delattr(module_cfg, key) - - # Convert params dict to a SimpleNamespace tree for attribute-style access - module_cfg.params = dict_to_nested_namespace(params) - # Duplicate `model` under `params.model` because some downstream components - # expect the model configuration to be available via `params.model`. - module_cfg.params.model = module_cfg.model - - modules.append(module_cfg) - - cfg.modules = modules + cfg.modules = [ + _normalize_module_for_runtime(legacy_cfg.get_module_config(module_name), module_name) + for module_name in getattr(legacy_cfg, "module_keys", []) + ] return cfg diff --git a/primus/core/launcher/parser.py b/primus/core/launcher/parser.py index 73337f118..2a0d0f218 100644 --- a/primus/core/launcher/parser.py +++ b/primus/core/launcher/parser.py @@ -9,21 +9,30 @@ from primus.core.config.preset_loader import PresetLoader from primus.core.launcher.config import PrimusConfig from primus.core.utils import constant_vars, yaml_utils +from primus.core.utils.arg_utils import parse_cli_overrides -def add_pretrain_parser(parser: argparse.ArgumentParser): +def _add_common_train_args( + parser: argparse.ArgumentParser, include_data_path: bool +) -> argparse.ArgumentParser: + """ + Register shared train CLI arguments. + """ parser.add_argument( "--config", + "--exp", + dest="config", type=str, required=True, help="Path to experiment YAML config file (alias: --exp)", ) - parser.add_argument( - "--data_path", - type=str, - default="./data", - help="Path to data directory [default: ./data]", - ) + if include_data_path: + parser.add_argument( + "--data_path", + type=str, + default="./data", + help="Path to data directory [default: ./data]", + ) parser.add_argument( "--backend_path", nargs="?", @@ -41,6 +50,10 @@ def add_pretrain_parser(parser: argparse.ArgumentParser): return parser +def add_pretrain_parser(parser: argparse.ArgumentParser): + return _add_common_train_args(parser, include_data_path=True) + + def add_posttrain_parser(parser: argparse.ArgumentParser): """ Post-training (SFT / alignment) workflow parser. @@ -53,29 +66,7 @@ def add_posttrain_parser(parser: argparse.ArgumentParser): def _parse_args(extra_args_provider=None, ignore_unknown_args=False) -> Tuple[argparse.Namespace, List[str]]: parser = argparse.ArgumentParser(description="Primus Arguments", allow_abbrev=False) - - parser.add_argument( - "--config", - "--exp", - dest="exp", - type=str, - required=True, - help="Path to experiment YAML config file (alias: --exp)", - ) - parser.add_argument( - "--backend_path", - nargs="?", - default=None, - help=( - "Optional backend import path for Megatron or TorchTitan. " - "If provided, it will be appended to PYTHONPATH dynamically." - ), - ) - parser.add_argument( - "--export_config", - type=str, - help="Optional path to export the final merged config to a file.", - ) + parser = _add_common_train_args(parser, include_data_path=False) # Custom arguments. if extra_args_provider is not None: @@ -86,69 +77,18 @@ def _parse_args(extra_args_provider=None, ignore_unknown_args=False) -> Tuple[ar def _parse_kv_overrides(args: list[str]) -> dict: """ - Parse CLI arguments of the form: - --key=value - --key value - --flag (boolean True) - into a nested dictionary structure. + Backward-compatible wrapper around the shared CLI override parser. - Supports nested keys using dot notation, e.g., --a.b.c=1. + Keep this symbol for existing callers/tests in legacy paths. """ - overrides = {} - i = 0 - while i < len(args): - arg = args[i] - # Ignore non-option arguments (not starting with "--") - if not arg.startswith("--"): - i += 1 - continue - - # Strip the "--" prefix - key = arg[2:] - - if "=" in key: - # Format: --key=value - key, val = key.split("=", 1) - elif i + 1 < len(args) and not args[i + 1].startswith("--"): - # Format: --key value - val = args[i + 1] - i += 1 - else: - # Format: --flag (boolean True) - val = True - - # Normalize common lowercase booleans before eval, e.g. "true"/"false". - if isinstance(val, str): - lower_val = val.lower() - if lower_val == "true": - val = True - elif lower_val == "false": - val = False - else: - # Try to evaluate the value to correct type (int, float, etc.) - try: - val = eval(val, {}, {}) - except Exception: - pass # Leave as string if evaluation fails - - # Handle nested keys, e.g., modules.pre_trainer.lr - d = overrides - keys = key.split(".") - for k in keys[:-1]: - d = d.setdefault(k, {}) - d[keys[-1]] = val - - i += 1 - - return overrides + return parse_cli_overrides(args, type_mode="legacy") def _deep_merge_namespace(ns, override_dict): - for k, v in override_dict.items(): - if hasattr(ns, k) and isinstance(getattr(ns, k), SimpleNamespace) and isinstance(v, dict): - _deep_merge_namespace(getattr(ns, k), v) - else: - setattr(ns, k, v) + """ + Merge overrides into SimpleNamespace via unified yaml merge path. + """ + yaml_utils.deep_merge_namespace(ns, override_dict) def _check_keys_exist(ns: SimpleNamespace, overrides: dict, prefix=""): @@ -193,7 +133,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False): config_parser = PrimusParser() primus_config = config_parser.parse(args) - overrides = _parse_kv_overrides(unknown_args) + overrides = parse_cli_overrides(unknown_args, type_mode="legacy") pre_trainer_cfg = primus_config.get_module_config("pre_trainer") _check_keys_exist(pre_trainer_cfg, overrides) _deep_merge_namespace(pre_trainer_cfg, overrides) @@ -201,7 +141,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False): return primus_config -def load_primus_config(args: argparse.Namespace, overrides: List[str]) -> Tuple[Any, Dict[str, Any]]: +def _load_legacy_primus_config(args: argparse.Namespace, overrides: List[str]) -> Tuple[Any, Dict[str, Any]]: """ Build the Primus configuration with optional command-line overrides. @@ -218,7 +158,7 @@ def load_primus_config(args: argparse.Namespace, overrides: List[str]) -> Tuple[ primus_config = config_parser.parse(args) # 2 Parse overrides from flat list to dict/namespace - override_ns = _parse_kv_overrides(overrides) + override_ns = parse_cli_overrides(overrides, type_mode="legacy") # 3 Apply overrides to pre_trainer module config pre_trainer_cfg = primus_config.get_module_config("pre_trainer") @@ -237,6 +177,15 @@ def load_primus_config(args: argparse.Namespace, overrides: List[str]) -> Tuple[ return primus_config, unknown_overrides +def load_primus_config(args: argparse.Namespace, overrides: List[str]) -> Tuple[Any, Dict[str, Any]]: + """ + Legacy compatibility API. + + Prefer `primus.core.config.primus_config.load_primus_config` in new code. + """ + return _load_legacy_primus_config(args, overrides) + + class PrimusParser(object): def __init__(self): pass diff --git a/primus/core/utils/arg_utils.py b/primus/core/utils/arg_utils.py index 90e41c207..03e7fb74f 100644 --- a/primus/core/utils/arg_utils.py +++ b/primus/core/utils/arg_utils.py @@ -9,7 +9,46 @@ """ -def parse_cli_overrides(overrides: list) -> dict: +def _coerce_cli_value_modern(raw_value): + """Convert common CLI literals to bool/int/float/string.""" + value = raw_value + try: + if value.lower() in ("true", "false"): + return value.lower() == "true" + if "." in value: + try: + return float(value) + except ValueError: + pass + try: + return int(value) + except ValueError: + return value + except AttributeError: + return value + + +def _coerce_cli_value_legacy(raw_value): + """Convert CLI literals using legacy `_parse_kv_overrides` behavior.""" + value = raw_value + if not isinstance(value, str): + return value + + lower_val = value.lower() + if lower_val == "true": + return True + if lower_val == "false": + return False + + # Keep compatibility with legacy `_parse_kv_overrides`, which used eval + # for non-boolean values (e.g., None, lists, dicts, quoted strings). + try: + return eval(value, {}, {}) + except Exception: + return value + + +def parse_cli_overrides(overrides: list, type_mode: str = "modern") -> dict: """ Parse CLI override arguments. @@ -23,6 +62,9 @@ def parse_cli_overrides(overrides: list) -> dict: overrides: List of raw CLI override tokens, e.g.: ["lr=0.001", "batch_size=32"] ["--train_iters", "10"] + type_mode: Type inference strategy. + - "modern": bool/int/float/string (original parse_cli_overrides behavior) + - "legacy": bool + eval fallback (old _parse_kv_overrides behavior) Returns: Dictionary with parsed key-value pairs @@ -38,21 +80,28 @@ def parse_cli_overrides(overrides: list) -> dict: {"use_cache": True, "verbose": False} Type Inference Rules: - - Boolean: "true"/"false" (case-insensitive) -> bool - - Integer: digits or negative digits -> int - - Float: contains decimal point -> float - - String: everything else remains as string + - modern: bool/int/float/string + - legacy: bool + eval fallback Nested Keys: - Dot notation creates nested dictionaries - "model.layers=24" becomes {"model": {"layers": 24}} - Multiple nested keys merge into the same parent dict """ - # First normalise tokens to "key=value" form. + # First normalize tokens to "key=value" form. normalized: list[str] = [] i = 0 while i < len(overrides): item = overrides[i] + if not isinstance(item, str): + normalized.append(str(item)) + i += 1 + continue + + item = item.strip() + if not item: + i += 1 + continue # Already in key=value form (including "--key=value") if "=" in item: @@ -64,43 +113,38 @@ def parse_cli_overrides(overrides: list) -> dict: continue # Handle "--key value" → "key=value" - if item.startswith("--") and i + 1 < len(overrides) and "=" not in overrides[i + 1]: + if ( + item.startswith("--") + and i + 1 < len(overrides) + and isinstance(overrides[i + 1], str) + and not overrides[i + 1].startswith("--") + ): key = item.lstrip("-") value = overrides[i + 1] normalized.append(f"{key}={value}") i += 2 continue + # Handle bare "--flag" as boolean true. + if item.startswith("--"): + key = item.lstrip("-") + normalized.append(f"{key}=true") + i += 1 + continue + # Fallback: invalid format, emit warning and skip print(f"[Primus] Warning: Skipping invalid override format: {item}") i += 1 + if type_mode not in ("modern", "legacy"): + raise ValueError(f"Unsupported type_mode: {type_mode}") + coerce_fn = _coerce_cli_value_modern if type_mode == "modern" else _coerce_cli_value_legacy + result: dict = {} for item in normalized: key, value = item.split("=", 1) key = key.strip() - value = value.strip() - - # Try to convert to appropriate type - try: - # Try boolean - if value.lower() in ("true", "false"): - value = value.lower() == "true" - # Try float (handles negative values as well) - elif "." in value: - try: - value = float(value) - except ValueError: - pass # Keep as string if float conversion fails - else: - # Fallback to integer parsing (including negative ints) - try: - value = int(value) - except ValueError: - pass # Keep as string if int conversion fails - except AttributeError: - # Non-string values are left as-is - pass + value = coerce_fn(value.strip()) # Handle nested keys (e.g., model.layers -> {"model": {"layers": ...}}) if "." in key: diff --git a/primus/core/utils/yaml_utils.py b/primus/core/utils/yaml_utils.py index 1b9aebed5..1f39b1f2f 100755 --- a/primus/core/utils/yaml_utils.py +++ b/primus/core/utils/yaml_utils.py @@ -10,6 +10,7 @@ import yaml +from primus.core.config.merge_utils import deep_merge from primus.core.config.yaml_loader import parse_yaml as _parse_yaml_core @@ -76,30 +77,54 @@ def set_value_by_key(namespace: SimpleNamespace, key: str, value, allow_override return setattr(namespace, key, value) +def _assign_namespace_from_dict_inplace(namespace: SimpleNamespace, src_dict: dict): + """ + Assign nested dict values into a namespace in place. + + Existing nested SimpleNamespace objects are updated recursively to keep + references stable where possible. + """ + for key, value in src_dict.items(): + if isinstance(value, dict): + current = getattr(namespace, key, None) + if isinstance(current, SimpleNamespace): + _assign_namespace_from_dict_inplace(current, value) + else: + setattr(namespace, key, dict_to_nested_namespace(value)) + else: + set_value_by_key(namespace, key, dict_to_nested_namespace(value), allow_override=True) + + +def deep_merge_namespace(namespace: SimpleNamespace, override_dict: dict): + """ + Apply dict-style deep merge into namespace in place. + """ + base_dict = nested_namespace_to_dict(namespace) + merged_dict = deep_merge(base_dict, override_dict) + _assign_namespace_from_dict_inplace(namespace, merged_dict) + return namespace + + def override_namespace(original_ns: SimpleNamespace, overrides_ns: SimpleNamespace): if overrides_ns is None: return - - for key in vars(overrides_ns): - # if not has_key_in_namespace(original_ns, key): - # raise Exception(f"Override namespace failed: can't find key({key}) in namespace {original_ns}") - new_value = get_value_by_key(overrides_ns, key) - if isinstance(new_value, SimpleNamespace): - override_namespace(get_value_by_key(original_ns, key), new_value) - else: - set_value_by_key(original_ns, key, new_value, allow_override=True) + deep_merge_namespace(original_ns, nested_namespace_to_dict(overrides_ns)) def merge_namespace(dst: SimpleNamespace, src: SimpleNamespace, allow_override=False, excepts: list = None): - src_dict = vars(src) - dst_dict = vars(dst) - excepts = excepts or [] + src_dict = nested_namespace_to_dict(src) + dst_dict = nested_namespace_to_dict(dst) + excepts = set(excepts or []) + + effective_override = {} for key, value in src_dict.items(): if key in excepts: continue if key in dst_dict and not allow_override: continue # Skip duplicate keys, keep dst value - setattr(dst, key, value) + effective_override[key] = value + + deep_merge_namespace(dst, effective_override) def dump_namespace_to_yaml(ns: SimpleNamespace, file_path: str):