From 550412e25fb5f48d402cc2ae2b1c09d41c271339 Mon Sep 17 00:00:00 2001 From: WangLingxun Date: Tue, 3 Mar 2026 10:42:28 +0000 Subject: [PATCH] refactor(config): unify override parsing and simplify runtime config normalization Consolidate CLI override parsing into shared utilities and align parser behavior with legacy compatibility mode. Simplify runtime config assembly by extracting module normalization logic, improve namespace deep-merge semantics for YAML overrides, and preserve backward-compatible loader APIs while reducing duplicated parsing code. --- primus/core/config/primus_config.py | 107 ++++++++++++----------- primus/core/launcher/parser.py | 131 +++++++++------------------- primus/core/utils/arg_utils.py | 102 ++++++++++++++++------ primus/core/utils/yaml_utils.py | 51 ++++++++--- 4 files changed, 208 insertions(+), 183 deletions(-) diff --git a/primus/core/config/primus_config.py b/primus/core/config/primus_config.py index 0477e91b3..888f9c997 100644 --- a/primus/core/config/primus_config.py +++ b/primus/core/config/primus_config.py @@ -18,13 +18,60 @@ from __future__ import annotations +from copy import copy, deepcopy from pathlib import Path from types import SimpleNamespace from typing import Any from primus.core.launcher.parser import PrimusParser from primus.core.utils import constant_vars -from primus.core.utils.yaml_utils import dict_to_nested_namespace +from primus.core.utils.yaml_utils import ( + dict_to_nested_namespace, + nested_namespace_to_dict, +) + + +def _to_plain_dict(value: Any) -> dict: + if value is None: + return {} + if isinstance(value, SimpleNamespace): + return nested_namespace_to_dict(value) + return dict(value) + + +def _normalize_module_for_runtime(module_cfg: SimpleNamespace, module_name: str) -> SimpleNamespace: + """ + Normalize a legacy module namespace into the runtime shape: + - Keep a small set of reserved top-level attributes: + * name, framework, config, model + * (and any existing 'params' if present) + - Move all other public attributes into a `params` dict: + module_cfg.params[key] = + """ + normalized = deepcopy(module_cfg) + # Ensure each module has a stable `.name` attribute. + if not getattr(normalized, "name", None): + setattr(normalized, "name", module_name) + + reserved_keys = {"name", "framework", "config", "model", "params"} + # Start from any existing params dict/namespace if provided. + existing_params = getattr(normalized, "params", {}) + params = _to_plain_dict(existing_params) + + for key, value in list(vars(normalized).items()): + # Skip reserved and private attributes. + if key in reserved_keys or key.startswith("_"): + continue + # Move this attribute into params and remove it from the namespace. + params[key] = value + delattr(normalized, key) + + # Convert params dict to nested namespace for attribute-style access. + normalized.params = dict_to_nested_namespace(params) + # Duplicate `model` under `params.model` for downstream compatibility. + if hasattr(normalized, "model"): + normalized.params.model = normalized.model + return normalized def load_primus_config(config_path: Path, cli_args: Any | None = None) -> SimpleNamespace: @@ -43,15 +90,12 @@ def load_primus_config(config_path: Path, cli_args: Any | None = None) -> Simple # PrimusParser.parse() only requires a `.config` attribute. if cli_args is None: args_for_parser = SimpleNamespace(config=config_path_str) - else: + elif getattr(cli_args, "config", None) != config_path_str: # Avoid mutating the original args object. - if getattr(cli_args, "config", None) != config_path_str: - from copy import copy - - args_for_parser = copy(cli_args) - setattr(args_for_parser, "config", config_path_str) - else: - args_for_parser = cli_args + args_for_parser = copy(cli_args) + setattr(args_for_parser, "config", config_path_str) + else: + args_for_parser = cli_args # Use legacy PrimusParser to build a PrimusConfig instance. legacy_cfg = PrimusParser().parse(args_for_parser) @@ -77,47 +121,10 @@ def load_primus_config(config_path: Path, cli_args: Any | None = None) -> Simple cfg.platform = platform_config # Build modules list from legacy PrimusConfig.module_keys/get_module_config. - modules: list[SimpleNamespace] = [] - for module_name in getattr(legacy_cfg, "module_keys", []): - module_cfg = legacy_cfg.get_module_config(module_name) - - # Ensure each module has a stable `.name` attribute. - if not getattr(module_cfg, "name", None): - setattr(module_cfg, "name", module_name) - - # Normalize module configuration layout for the new core runtime: - # - # - Keep a small set of reserved top-level attributes: - # * name, framework, config, model - # * (and any existing 'params' if present) - # - Move all other public attributes into a `params` dict: - # module_cfg.params[key] = - # - # This matches the expectation of downstream components which - # treat `module_cfg.params` as the flat parameter dictionary. - reserved_keys = {"name", "framework", "config", "model", "params"} - - # Start from any existing params dict if present. - params = dict(getattr(module_cfg, "params", {})) - - for key, value in list(vars(module_cfg).items()): - # Skip reserved and private attributes. - if key in reserved_keys or key.startswith("_"): - continue - - # Move this attribute into params and remove it from the namespace. - params[key] = value - delattr(module_cfg, key) - - # Convert params dict to a SimpleNamespace tree for attribute-style access - module_cfg.params = dict_to_nested_namespace(params) - # Duplicate `model` under `params.model` because some downstream components - # expect the model configuration to be available via `params.model`. - module_cfg.params.model = module_cfg.model - - modules.append(module_cfg) - - cfg.modules = modules + cfg.modules = [ + _normalize_module_for_runtime(legacy_cfg.get_module_config(module_name), module_name) + for module_name in getattr(legacy_cfg, "module_keys", []) + ] return cfg diff --git a/primus/core/launcher/parser.py b/primus/core/launcher/parser.py index 73337f118..2a0d0f218 100644 --- a/primus/core/launcher/parser.py +++ b/primus/core/launcher/parser.py @@ -9,21 +9,30 @@ from primus.core.config.preset_loader import PresetLoader from primus.core.launcher.config import PrimusConfig from primus.core.utils import constant_vars, yaml_utils +from primus.core.utils.arg_utils import parse_cli_overrides -def add_pretrain_parser(parser: argparse.ArgumentParser): +def _add_common_train_args( + parser: argparse.ArgumentParser, include_data_path: bool +) -> argparse.ArgumentParser: + """ + Register shared train CLI arguments. + """ parser.add_argument( "--config", + "--exp", + dest="config", type=str, required=True, help="Path to experiment YAML config file (alias: --exp)", ) - parser.add_argument( - "--data_path", - type=str, - default="./data", - help="Path to data directory [default: ./data]", - ) + if include_data_path: + parser.add_argument( + "--data_path", + type=str, + default="./data", + help="Path to data directory [default: ./data]", + ) parser.add_argument( "--backend_path", nargs="?", @@ -41,6 +50,10 @@ def add_pretrain_parser(parser: argparse.ArgumentParser): return parser +def add_pretrain_parser(parser: argparse.ArgumentParser): + return _add_common_train_args(parser, include_data_path=True) + + def add_posttrain_parser(parser: argparse.ArgumentParser): """ Post-training (SFT / alignment) workflow parser. @@ -53,29 +66,7 @@ def add_posttrain_parser(parser: argparse.ArgumentParser): def _parse_args(extra_args_provider=None, ignore_unknown_args=False) -> Tuple[argparse.Namespace, List[str]]: parser = argparse.ArgumentParser(description="Primus Arguments", allow_abbrev=False) - - parser.add_argument( - "--config", - "--exp", - dest="exp", - type=str, - required=True, - help="Path to experiment YAML config file (alias: --exp)", - ) - parser.add_argument( - "--backend_path", - nargs="?", - default=None, - help=( - "Optional backend import path for Megatron or TorchTitan. " - "If provided, it will be appended to PYTHONPATH dynamically." - ), - ) - parser.add_argument( - "--export_config", - type=str, - help="Optional path to export the final merged config to a file.", - ) + parser = _add_common_train_args(parser, include_data_path=False) # Custom arguments. if extra_args_provider is not None: @@ -86,69 +77,18 @@ def _parse_args(extra_args_provider=None, ignore_unknown_args=False) -> Tuple[ar def _parse_kv_overrides(args: list[str]) -> dict: """ - Parse CLI arguments of the form: - --key=value - --key value - --flag (boolean True) - into a nested dictionary structure. + Backward-compatible wrapper around the shared CLI override parser. - Supports nested keys using dot notation, e.g., --a.b.c=1. + Keep this symbol for existing callers/tests in legacy paths. """ - overrides = {} - i = 0 - while i < len(args): - arg = args[i] - # Ignore non-option arguments (not starting with "--") - if not arg.startswith("--"): - i += 1 - continue - - # Strip the "--" prefix - key = arg[2:] - - if "=" in key: - # Format: --key=value - key, val = key.split("=", 1) - elif i + 1 < len(args) and not args[i + 1].startswith("--"): - # Format: --key value - val = args[i + 1] - i += 1 - else: - # Format: --flag (boolean True) - val = True - - # Normalize common lowercase booleans before eval, e.g. "true"/"false". - if isinstance(val, str): - lower_val = val.lower() - if lower_val == "true": - val = True - elif lower_val == "false": - val = False - else: - # Try to evaluate the value to correct type (int, float, etc.) - try: - val = eval(val, {}, {}) - except Exception: - pass # Leave as string if evaluation fails - - # Handle nested keys, e.g., modules.pre_trainer.lr - d = overrides - keys = key.split(".") - for k in keys[:-1]: - d = d.setdefault(k, {}) - d[keys[-1]] = val - - i += 1 - - return overrides + return parse_cli_overrides(args, type_mode="legacy") def _deep_merge_namespace(ns, override_dict): - for k, v in override_dict.items(): - if hasattr(ns, k) and isinstance(getattr(ns, k), SimpleNamespace) and isinstance(v, dict): - _deep_merge_namespace(getattr(ns, k), v) - else: - setattr(ns, k, v) + """ + Merge overrides into SimpleNamespace via unified yaml merge path. + """ + yaml_utils.deep_merge_namespace(ns, override_dict) def _check_keys_exist(ns: SimpleNamespace, overrides: dict, prefix=""): @@ -193,7 +133,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False): config_parser = PrimusParser() primus_config = config_parser.parse(args) - overrides = _parse_kv_overrides(unknown_args) + overrides = parse_cli_overrides(unknown_args, type_mode="legacy") pre_trainer_cfg = primus_config.get_module_config("pre_trainer") _check_keys_exist(pre_trainer_cfg, overrides) _deep_merge_namespace(pre_trainer_cfg, overrides) @@ -201,7 +141,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False): return primus_config -def load_primus_config(args: argparse.Namespace, overrides: List[str]) -> Tuple[Any, Dict[str, Any]]: +def _load_legacy_primus_config(args: argparse.Namespace, overrides: List[str]) -> Tuple[Any, Dict[str, Any]]: """ Build the Primus configuration with optional command-line overrides. @@ -218,7 +158,7 @@ def load_primus_config(args: argparse.Namespace, overrides: List[str]) -> Tuple[ primus_config = config_parser.parse(args) # 2 Parse overrides from flat list to dict/namespace - override_ns = _parse_kv_overrides(overrides) + override_ns = parse_cli_overrides(overrides, type_mode="legacy") # 3 Apply overrides to pre_trainer module config pre_trainer_cfg = primus_config.get_module_config("pre_trainer") @@ -237,6 +177,15 @@ def load_primus_config(args: argparse.Namespace, overrides: List[str]) -> Tuple[ return primus_config, unknown_overrides +def load_primus_config(args: argparse.Namespace, overrides: List[str]) -> Tuple[Any, Dict[str, Any]]: + """ + Legacy compatibility API. + + Prefer `primus.core.config.primus_config.load_primus_config` in new code. + """ + return _load_legacy_primus_config(args, overrides) + + class PrimusParser(object): def __init__(self): pass diff --git a/primus/core/utils/arg_utils.py b/primus/core/utils/arg_utils.py index 90e41c207..03e7fb74f 100644 --- a/primus/core/utils/arg_utils.py +++ b/primus/core/utils/arg_utils.py @@ -9,7 +9,46 @@ """ -def parse_cli_overrides(overrides: list) -> dict: +def _coerce_cli_value_modern(raw_value): + """Convert common CLI literals to bool/int/float/string.""" + value = raw_value + try: + if value.lower() in ("true", "false"): + return value.lower() == "true" + if "." in value: + try: + return float(value) + except ValueError: + pass + try: + return int(value) + except ValueError: + return value + except AttributeError: + return value + + +def _coerce_cli_value_legacy(raw_value): + """Convert CLI literals using legacy `_parse_kv_overrides` behavior.""" + value = raw_value + if not isinstance(value, str): + return value + + lower_val = value.lower() + if lower_val == "true": + return True + if lower_val == "false": + return False + + # Keep compatibility with legacy `_parse_kv_overrides`, which used eval + # for non-boolean values (e.g., None, lists, dicts, quoted strings). + try: + return eval(value, {}, {}) + except Exception: + return value + + +def parse_cli_overrides(overrides: list, type_mode: str = "modern") -> dict: """ Parse CLI override arguments. @@ -23,6 +62,9 @@ def parse_cli_overrides(overrides: list) -> dict: overrides: List of raw CLI override tokens, e.g.: ["lr=0.001", "batch_size=32"] ["--train_iters", "10"] + type_mode: Type inference strategy. + - "modern": bool/int/float/string (original parse_cli_overrides behavior) + - "legacy": bool + eval fallback (old _parse_kv_overrides behavior) Returns: Dictionary with parsed key-value pairs @@ -38,21 +80,28 @@ def parse_cli_overrides(overrides: list) -> dict: {"use_cache": True, "verbose": False} Type Inference Rules: - - Boolean: "true"/"false" (case-insensitive) -> bool - - Integer: digits or negative digits -> int - - Float: contains decimal point -> float - - String: everything else remains as string + - modern: bool/int/float/string + - legacy: bool + eval fallback Nested Keys: - Dot notation creates nested dictionaries - "model.layers=24" becomes {"model": {"layers": 24}} - Multiple nested keys merge into the same parent dict """ - # First normalise tokens to "key=value" form. + # First normalize tokens to "key=value" form. normalized: list[str] = [] i = 0 while i < len(overrides): item = overrides[i] + if not isinstance(item, str): + normalized.append(str(item)) + i += 1 + continue + + item = item.strip() + if not item: + i += 1 + continue # Already in key=value form (including "--key=value") if "=" in item: @@ -64,43 +113,38 @@ def parse_cli_overrides(overrides: list) -> dict: continue # Handle "--key value" → "key=value" - if item.startswith("--") and i + 1 < len(overrides) and "=" not in overrides[i + 1]: + if ( + item.startswith("--") + and i + 1 < len(overrides) + and isinstance(overrides[i + 1], str) + and not overrides[i + 1].startswith("--") + ): key = item.lstrip("-") value = overrides[i + 1] normalized.append(f"{key}={value}") i += 2 continue + # Handle bare "--flag" as boolean true. + if item.startswith("--"): + key = item.lstrip("-") + normalized.append(f"{key}=true") + i += 1 + continue + # Fallback: invalid format, emit warning and skip print(f"[Primus] Warning: Skipping invalid override format: {item}") i += 1 + if type_mode not in ("modern", "legacy"): + raise ValueError(f"Unsupported type_mode: {type_mode}") + coerce_fn = _coerce_cli_value_modern if type_mode == "modern" else _coerce_cli_value_legacy + result: dict = {} for item in normalized: key, value = item.split("=", 1) key = key.strip() - value = value.strip() - - # Try to convert to appropriate type - try: - # Try boolean - if value.lower() in ("true", "false"): - value = value.lower() == "true" - # Try float (handles negative values as well) - elif "." in value: - try: - value = float(value) - except ValueError: - pass # Keep as string if float conversion fails - else: - # Fallback to integer parsing (including negative ints) - try: - value = int(value) - except ValueError: - pass # Keep as string if int conversion fails - except AttributeError: - # Non-string values are left as-is - pass + value = coerce_fn(value.strip()) # Handle nested keys (e.g., model.layers -> {"model": {"layers": ...}}) if "." in key: diff --git a/primus/core/utils/yaml_utils.py b/primus/core/utils/yaml_utils.py index 1b9aebed5..1f39b1f2f 100755 --- a/primus/core/utils/yaml_utils.py +++ b/primus/core/utils/yaml_utils.py @@ -10,6 +10,7 @@ import yaml +from primus.core.config.merge_utils import deep_merge from primus.core.config.yaml_loader import parse_yaml as _parse_yaml_core @@ -76,30 +77,54 @@ def set_value_by_key(namespace: SimpleNamespace, key: str, value, allow_override return setattr(namespace, key, value) +def _assign_namespace_from_dict_inplace(namespace: SimpleNamespace, src_dict: dict): + """ + Assign nested dict values into a namespace in place. + + Existing nested SimpleNamespace objects are updated recursively to keep + references stable where possible. + """ + for key, value in src_dict.items(): + if isinstance(value, dict): + current = getattr(namespace, key, None) + if isinstance(current, SimpleNamespace): + _assign_namespace_from_dict_inplace(current, value) + else: + setattr(namespace, key, dict_to_nested_namespace(value)) + else: + set_value_by_key(namespace, key, dict_to_nested_namespace(value), allow_override=True) + + +def deep_merge_namespace(namespace: SimpleNamespace, override_dict: dict): + """ + Apply dict-style deep merge into namespace in place. + """ + base_dict = nested_namespace_to_dict(namespace) + merged_dict = deep_merge(base_dict, override_dict) + _assign_namespace_from_dict_inplace(namespace, merged_dict) + return namespace + + def override_namespace(original_ns: SimpleNamespace, overrides_ns: SimpleNamespace): if overrides_ns is None: return - - for key in vars(overrides_ns): - # if not has_key_in_namespace(original_ns, key): - # raise Exception(f"Override namespace failed: can't find key({key}) in namespace {original_ns}") - new_value = get_value_by_key(overrides_ns, key) - if isinstance(new_value, SimpleNamespace): - override_namespace(get_value_by_key(original_ns, key), new_value) - else: - set_value_by_key(original_ns, key, new_value, allow_override=True) + deep_merge_namespace(original_ns, nested_namespace_to_dict(overrides_ns)) def merge_namespace(dst: SimpleNamespace, src: SimpleNamespace, allow_override=False, excepts: list = None): - src_dict = vars(src) - dst_dict = vars(dst) - excepts = excepts or [] + src_dict = nested_namespace_to_dict(src) + dst_dict = nested_namespace_to_dict(dst) + excepts = set(excepts or []) + + effective_override = {} for key, value in src_dict.items(): if key in excepts: continue if key in dst_dict and not allow_override: continue # Skip duplicate keys, keep dst value - setattr(dst, key, value) + effective_override[key] = value + + deep_merge_namespace(dst, effective_override) def dump_namespace_to_yaml(ns: SimpleNamespace, file_path: str):