Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 57 additions & 50 deletions primus/core/config/primus_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,60 @@

from __future__ import annotations

from copy import copy, deepcopy
from pathlib import Path
from types import SimpleNamespace
from typing import Any

from primus.core.launcher.parser import PrimusParser
from primus.core.utils import constant_vars
from primus.core.utils.yaml_utils import dict_to_nested_namespace
from primus.core.utils.yaml_utils import (
dict_to_nested_namespace,
nested_namespace_to_dict,
)


def _to_plain_dict(value: Any) -> dict:
if value is None:
return {}
if isinstance(value, SimpleNamespace):
return nested_namespace_to_dict(value)
return dict(value)


def _normalize_module_for_runtime(module_cfg: SimpleNamespace, module_name: str) -> SimpleNamespace:
"""
Normalize a legacy module namespace into the runtime shape:
- Keep a small set of reserved top-level attributes:
* name, framework, config, model
* (and any existing 'params' if present)
- Move all other public attributes into a `params` dict:
module_cfg.params[key] = <original value>
"""
normalized = deepcopy(module_cfg)
# Ensure each module has a stable `.name` attribute.
if not getattr(normalized, "name", None):
setattr(normalized, "name", module_name)

reserved_keys = {"name", "framework", "config", "model", "params"}
# Start from any existing params dict/namespace if provided.
existing_params = getattr(normalized, "params", {})
params = _to_plain_dict(existing_params)

for key, value in list(vars(normalized).items()):
# Skip reserved and private attributes.
if key in reserved_keys or key.startswith("_"):
continue
# Move this attribute into params and remove it from the namespace.
params[key] = value
delattr(normalized, key)

# Convert params dict to nested namespace for attribute-style access.
normalized.params = dict_to_nested_namespace(params)
# Duplicate `model` under `params.model` for downstream compatibility.
if hasattr(normalized, "model"):
normalized.params.model = normalized.model
return normalized


def load_primus_config(config_path: Path, cli_args: Any | None = None) -> SimpleNamespace:
Expand All @@ -43,15 +90,12 @@ def load_primus_config(config_path: Path, cli_args: Any | None = None) -> Simple
# PrimusParser.parse() only requires a `.config` attribute.
if cli_args is None:
args_for_parser = SimpleNamespace(config=config_path_str)
else:
elif getattr(cli_args, "config", None) != config_path_str:
# Avoid mutating the original args object.
if getattr(cli_args, "config", None) != config_path_str:
from copy import copy

args_for_parser = copy(cli_args)
setattr(args_for_parser, "config", config_path_str)
else:
args_for_parser = cli_args
args_for_parser = copy(cli_args)
setattr(args_for_parser, "config", config_path_str)
else:
args_for_parser = cli_args

# Use legacy PrimusParser to build a PrimusConfig instance.
legacy_cfg = PrimusParser().parse(args_for_parser)
Expand All @@ -77,47 +121,10 @@ def load_primus_config(config_path: Path, cli_args: Any | None = None) -> Simple
cfg.platform = platform_config

# Build modules list from legacy PrimusConfig.module_keys/get_module_config.
modules: list[SimpleNamespace] = []
for module_name in getattr(legacy_cfg, "module_keys", []):
module_cfg = legacy_cfg.get_module_config(module_name)

# Ensure each module has a stable `.name` attribute.
if not getattr(module_cfg, "name", None):
setattr(module_cfg, "name", module_name)

# Normalize module configuration layout for the new core runtime:
#
# - Keep a small set of reserved top-level attributes:
# * name, framework, config, model
# * (and any existing 'params' if present)
# - Move all other public attributes into a `params` dict:
# module_cfg.params[key] = <original value>
#
# This matches the expectation of downstream components which
# treat `module_cfg.params` as the flat parameter dictionary.
reserved_keys = {"name", "framework", "config", "model", "params"}

# Start from any existing params dict if present.
params = dict(getattr(module_cfg, "params", {}))

for key, value in list(vars(module_cfg).items()):
# Skip reserved and private attributes.
if key in reserved_keys or key.startswith("_"):
continue

# Move this attribute into params and remove it from the namespace.
params[key] = value
delattr(module_cfg, key)

# Convert params dict to a SimpleNamespace tree for attribute-style access
module_cfg.params = dict_to_nested_namespace(params)
# Duplicate `model` under `params.model` because some downstream components
# expect the model configuration to be available via `params.model`.
module_cfg.params.model = module_cfg.model

modules.append(module_cfg)

cfg.modules = modules
cfg.modules = [
_normalize_module_for_runtime(legacy_cfg.get_module_config(module_name), module_name)
for module_name in getattr(legacy_cfg, "module_keys", [])
]

return cfg

Expand Down
131 changes: 40 additions & 91 deletions primus/core/launcher/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,30 @@
from primus.core.config.preset_loader import PresetLoader
from primus.core.launcher.config import PrimusConfig
from primus.core.utils import constant_vars, yaml_utils
from primus.core.utils.arg_utils import parse_cli_overrides


def add_pretrain_parser(parser: argparse.ArgumentParser):
def _add_common_train_args(
parser: argparse.ArgumentParser, include_data_path: bool
) -> argparse.ArgumentParser:
"""
Register shared train CLI arguments.
"""
parser.add_argument(
"--config",
"--exp",
dest="config",
type=str,
required=True,
help="Path to experiment YAML config file (alias: --exp)",
)
parser.add_argument(
"--data_path",
type=str,
default="./data",
help="Path to data directory [default: ./data]",
)
if include_data_path:
parser.add_argument(
"--data_path",
type=str,
default="./data",
help="Path to data directory [default: ./data]",
)
parser.add_argument(
"--backend_path",
nargs="?",
Expand All @@ -41,6 +50,10 @@ def add_pretrain_parser(parser: argparse.ArgumentParser):
return parser


def add_pretrain_parser(parser: argparse.ArgumentParser):
return _add_common_train_args(parser, include_data_path=True)


def add_posttrain_parser(parser: argparse.ArgumentParser):
"""
Post-training (SFT / alignment) workflow parser.
Expand All @@ -53,29 +66,7 @@ def add_posttrain_parser(parser: argparse.ArgumentParser):

def _parse_args(extra_args_provider=None, ignore_unknown_args=False) -> Tuple[argparse.Namespace, List[str]]:
parser = argparse.ArgumentParser(description="Primus Arguments", allow_abbrev=False)

parser.add_argument(
"--config",
"--exp",
dest="exp",
type=str,
required=True,
help="Path to experiment YAML config file (alias: --exp)",
)
parser.add_argument(
"--backend_path",
nargs="?",
default=None,
help=(
"Optional backend import path for Megatron or TorchTitan. "
"If provided, it will be appended to PYTHONPATH dynamically."
),
)
parser.add_argument(
"--export_config",
type=str,
help="Optional path to export the final merged config to a file.",
)
parser = _add_common_train_args(parser, include_data_path=False)

# Custom arguments.
if extra_args_provider is not None:
Expand All @@ -86,69 +77,18 @@ def _parse_args(extra_args_provider=None, ignore_unknown_args=False) -> Tuple[ar

def _parse_kv_overrides(args: list[str]) -> dict:
"""
Parse CLI arguments of the form:
--key=value
--key value
--flag (boolean True)
into a nested dictionary structure.
Backward-compatible wrapper around the shared CLI override parser.

Supports nested keys using dot notation, e.g., --a.b.c=1.
Keep this symbol for existing callers/tests in legacy paths.
"""
overrides = {}
i = 0
while i < len(args):
arg = args[i]
# Ignore non-option arguments (not starting with "--")
if not arg.startswith("--"):
i += 1
continue

# Strip the "--" prefix
key = arg[2:]

if "=" in key:
# Format: --key=value
key, val = key.split("=", 1)
elif i + 1 < len(args) and not args[i + 1].startswith("--"):
# Format: --key value
val = args[i + 1]
i += 1
else:
# Format: --flag (boolean True)
val = True

# Normalize common lowercase booleans before eval, e.g. "true"/"false".
if isinstance(val, str):
lower_val = val.lower()
if lower_val == "true":
val = True
elif lower_val == "false":
val = False
else:
# Try to evaluate the value to correct type (int, float, etc.)
try:
val = eval(val, {}, {})
except Exception:
pass # Leave as string if evaluation fails

# Handle nested keys, e.g., modules.pre_trainer.lr
d = overrides
keys = key.split(".")
for k in keys[:-1]:
d = d.setdefault(k, {})
d[keys[-1]] = val

i += 1

return overrides
return parse_cli_overrides(args, type_mode="legacy")


def _deep_merge_namespace(ns, override_dict):
for k, v in override_dict.items():
if hasattr(ns, k) and isinstance(getattr(ns, k), SimpleNamespace) and isinstance(v, dict):
_deep_merge_namespace(getattr(ns, k), v)
else:
setattr(ns, k, v)
"""
Merge overrides into SimpleNamespace via unified yaml merge path.
"""
yaml_utils.deep_merge_namespace(ns, override_dict)


def _check_keys_exist(ns: SimpleNamespace, overrides: dict, prefix=""):
Expand Down Expand Up @@ -193,15 +133,15 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
config_parser = PrimusParser()
primus_config = config_parser.parse(args)

overrides = _parse_kv_overrides(unknown_args)
overrides = parse_cli_overrides(unknown_args, type_mode="legacy")
pre_trainer_cfg = primus_config.get_module_config("pre_trainer")
_check_keys_exist(pre_trainer_cfg, overrides)
_deep_merge_namespace(pre_trainer_cfg, overrides)

return primus_config


def load_primus_config(args: argparse.Namespace, overrides: List[str]) -> Tuple[Any, Dict[str, Any]]:
def _load_legacy_primus_config(args: argparse.Namespace, overrides: List[str]) -> Tuple[Any, Dict[str, Any]]:
"""
Build the Primus configuration with optional command-line overrides.

Expand All @@ -218,7 +158,7 @@ def load_primus_config(args: argparse.Namespace, overrides: List[str]) -> Tuple[
primus_config = config_parser.parse(args)

# 2 Parse overrides from flat list to dict/namespace
override_ns = _parse_kv_overrides(overrides)
override_ns = parse_cli_overrides(overrides, type_mode="legacy")

# 3 Apply overrides to pre_trainer module config
pre_trainer_cfg = primus_config.get_module_config("pre_trainer")
Expand All @@ -237,6 +177,15 @@ def load_primus_config(args: argparse.Namespace, overrides: List[str]) -> Tuple[
return primus_config, unknown_overrides


def load_primus_config(args: argparse.Namespace, overrides: List[str]) -> Tuple[Any, Dict[str, Any]]:
"""
Legacy compatibility API.

Prefer `primus.core.config.primus_config.load_primus_config` in new code.
"""
return _load_legacy_primus_config(args, overrides)


class PrimusParser(object):
def __init__(self):
pass
Expand Down
Loading