Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions docs/training/logging.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,60 @@ When enabled, MLFlow receives:



### Comet ML

Megatron Bridge can log metrics and experiment metadata to Comet ML, following the same pattern as the W&B and MLFlow integrations.

#### What Gets Logged

When enabled, Comet ML receives:

- Training configuration as experiment parameters
- Scalar metrics (losses, learning rate, batch size, throughput, timers, memory, runtime, norms, energy, etc.)
- Validation loss and perplexity metrics
- Checkpoint save/load metadata

#### Enable Comet ML Logging

1) Install Comet ML:

```bash
pip install comet-ml
```

2) Authenticate:
- Either set `COMET_API_KEY` in the environment, or
- Pass an explicit `comet_api_key` in the logger config.

3) Configure logging in your training setup.

```python
from megatron.bridge.training.config import LoggerConfig

cfg.logger = LoggerConfig(
tensorboard_dir="./runs/tensorboard",
comet_project="my_project",
comet_experiment_name="llama32_1b_pretrain_run",
comet_workspace="my_workspace", # optional
comet_tags=["pretrain", "llama32"], # optional
)
```

```{note}
Comet ML is initialized lazily on the last rank when `comet_project` is set and `comet_experiment_name` is non-empty.
```

#### Comet ML Configuration with NeMo Run Launching

For users launching training scripts with NeMo Run, Comet ML can be optionally configured using the {py:class}`bridge.recipes.run_plugins.CometPlugin`.

The plugin automatically forwards the `COMET_API_KEY` environment variable to the executor and by default injects CLI overrides for the following logger parameters (the workspace and experiment name overrides are added only when those values are provided):

- `logger.comet_project`
- `logger.comet_workspace`
- `logger.comet_experiment_name`


#### Progress Log

When `logger.log_progress` is enabled, the framework generates a `progress.txt` file in the checkpoint save directory.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ dependencies = [
"timm",
"open-clip-torch>=3.2.0",
"mlflow>=3.5.0",
"comet-ml>=3.50.0",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think uv.lock need update, otherwise ci will fail.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks — added comet-ml to pyproject.toml but can't regenerate uv.lock locally since it needs the CI environment for GPU-specific package resolution. Happy to update if there's a way to run the lock step externally.

"torch>=2.6.0",
]

Expand Down
69 changes: 69 additions & 0 deletions src/megatron/bridge/recipes/run_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,75 @@ def setup(self, task: Union["run.Partial", "run.Script"], executor: "run.Executo
)


@dataclass
class CometPluginScriptArgs:
    """Arguments for CometPlugin to pass to run.Script."""

    # Comet ML project name; always emitted as a `logger.comet_project` CLI override.
    project: str
    # Comet ML workspace; the default converter emits it only when truthy.
    workspace: Optional[str]
    # Comet ML experiment name; the default converter emits it only when truthy.
    name: Optional[str]


def _default_comet_converter(args: CometPluginScriptArgs) -> List[str]:
    """Build the default CLI overrides for Comet ML logger settings.

    ``logger.comet_project`` is always included; the workspace and experiment
    name overrides are included only when their values are truthy.
    """
    overrides = [f"logger.comet_project={args.project}"]
    optional_fields = (
        ("logger.comet_workspace", args.workspace),
        ("logger.comet_experiment_name", args.name),
    )
    overrides.extend(f"{key}={value}" for key, value in optional_fields if value)
    return overrides


@dataclass(kw_only=True)
class CometPlugin(Plugin):
    """A plugin for setting up Comet ML configuration.

    This plugin sets up Comet ML logging configuration. The plugin is only activated
    if the ``COMET_API_KEY`` environment variable is set.
    The ``COMET_API_KEY`` environment variable will also be set in the executor's environment variables.
    Follow https://www.comet.com/docs/v2/guides/getting-started/quickstart/ to retrieve your ``COMET_API_KEY``.

    Args:
        project (str): The Comet ML project name.
        name (Optional[str]): The name for the Comet ML experiment.
        workspace (Optional[str]): The Comet ML workspace.
        script_args_converter_fn (Optional[Callable]): A function that takes CometPluginScriptArgs
            and returns a list of CLI arguments.
    """

    project: str
    name: Optional[str] = None
    workspace: Optional[str] = None
    script_args_converter_fn: Optional[Callable[[CometPluginScriptArgs], List[str]]] = None

    def setup(self, task: Union["run.Partial", "run.Script"], executor: "run.Executor"):
        """Attach Comet ML CLI overrides to the task and forward the API key.

        Args:
            task: The task to configure; only ``run.Script`` is supported.
            executor: The executor whose environment receives ``COMET_API_KEY``.

        Raises:
            ImportError: If NeMo Run is not installed.
            NotImplementedError: If ``task`` is not a ``run.Script``.
        """
        if not HAVE_NEMO_RUN:
            raise ImportError(MISSING_NEMO_RUN_MSG)

        # The plugin is intentionally a no-op when no API key is present:
        # without credentials the launched job could not log to Comet anyway.
        if "COMET_API_KEY" in os.environ:
            # Forward the key so the launched job can authenticate with Comet ML.
            executor.env_vars["COMET_API_KEY"] = os.environ["COMET_API_KEY"]

            if isinstance(task, Script):
                script_args = CometPluginScriptArgs(
                    project=self.project,
                    workspace=self.workspace,
                    name=self.name,
                )

                # Callers may supply a custom converter; otherwise use the
                # default logger.comet_* CLI override mapping.
                converter = self.script_args_converter_fn or _default_comet_converter
                cli_overrides = converter(script_args)

                task.args.extend(cli_overrides)
                logger.info(f"{self.__class__.__name__} added CLI overrides: {', '.join(cli_overrides)}")
            else:
                raise NotImplementedError("CometPlugin is only supported for run.Script tasks")
        else:
            logger.warning(
                f"Warning: The {self.__class__.__name__} will have no effect as COMET_API_KEY environment variable is not set."
            )


@dataclass
class PerfEnvPluginScriptArgs:
"""Arguments for PerfEnvPlugin to pass to run.Script."""
Expand Down
13 changes: 12 additions & 1 deletion src/megatron/bridge/training/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
from megatron.bridge.training.state import GlobalState, TrainState
from megatron.bridge.training.tokenizers.config import TokenizerConfig
from megatron.bridge.training.tokenizers.tokenizer import MegatronTokenizer
from megatron.bridge.training.utils import mlflow_utils, wandb_utils
from megatron.bridge.training.utils import comet_utils, mlflow_utils, wandb_utils
from megatron.bridge.training.utils.checkpoint_utils import (
checkpoint_exists,
ensure_directory_exists,
Expand Down Expand Up @@ -829,13 +829,23 @@ def mlflow_finalize_fn() -> None:
mlflow_logger=state.mlflow_logger,
)

def comet_finalize_fn() -> None:
comet_utils.on_save_checkpoint_success(
checkpoint_name,
save_dir,
train_state.step,
comet_logger=state.comet_logger,
)

if ckpt_cfg.async_save:
assert async_save_request is not None
async_save_request.add_finalize_fn(wandb_finalize_fn)
async_save_request.add_finalize_fn(mlflow_finalize_fn)
async_save_request.add_finalize_fn(comet_finalize_fn)
else:
wandb_finalize_fn()
mlflow_finalize_fn()
comet_finalize_fn()

if ckpt_cfg.async_save:
schedule_async_save(state, async_save_request)
Expand Down Expand Up @@ -1795,6 +1805,7 @@ def _load_checkpoint_from_path(
if not torch.distributed.is_initialized() or is_last_rank():
wandb_utils.on_load_checkpoint_success(checkpoint_name, load_dir, state.wandb_logger)
mlflow_utils.on_load_checkpoint_success(checkpoint_name, load_dir, state.mlflow_logger)
comet_utils.on_load_checkpoint_success(checkpoint_name, load_dir, state.comet_logger)

torch.cuda.empty_cache()

Expand Down
39 changes: 39 additions & 0 deletions src/megatron/bridge/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,21 @@ class LoggerConfig:
mlflow_tags: Optional[dict[str, str]] = None
"""Optional tags to apply to the MLFlow run."""

comet_project: Optional[str] = None
"""The Comet ML project name. Comet logging is disabled when this is None."""

comet_experiment_name: Optional[str] = None
"""The Comet ML experiment name."""

comet_workspace: Optional[str] = None
"""The Comet ML workspace. If not set, uses the default workspace for the API key."""

comet_api_key: Optional[str] = None
"""The Comet ML API key. Can also be set via COMET_API_KEY environment variable."""

comet_tags: Optional[list[str]] = None
"""Optional list of tags to apply to the Comet ML experiment."""

logging_level: int = logging.INFO
"""Set default logging level"""

Expand Down Expand Up @@ -1125,6 +1140,30 @@ def finalize(self) -> None:
"Install it via pip install mlflow or uv add mlflow"
) from exc

if self.comet_project and (self.comet_experiment_name is None or self.comet_experiment_name == ""):
raise ValueError("Set logger.comet_experiment_name when enabling Comet ML logging.")

using_comet = any(
[
self.comet_project,
self.comet_experiment_name,
self.comet_workspace,
self.comet_api_key,
self.comet_tags,
]
)

if using_comet:
try:
import importlib

importlib.import_module("comet_ml")
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"Comet ML logging is configured, but the 'comet_ml' package is not installed. "
"Install it via pip install comet-ml or uv add comet-ml"
) from exc


@dataclass(kw_only=True)
class ProfilingConfig:
Expand Down
8 changes: 8 additions & 0 deletions src/megatron/bridge/training/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def evaluate_and_print_results(
writer = None

wandb_writer = state.wandb_logger
comet_logger = state.comet_logger

if should_fire(callback_manager, start_event):
callback_manager.fire(
Expand Down Expand Up @@ -386,6 +387,13 @@ def evaluate_and_print_results(
if state.cfg.logger.log_validation_ppl_to_tensorboard:
wandb_writer.log({"{} validation ppl".format(key): ppl}, state.train_state.step)

if comet_logger and is_last_rank():
comet_logger.log_metrics(
{"{} validation".format(key): total_loss_dict[key].item()}, step=state.train_state.step
)
if state.cfg.logger.log_validation_ppl_to_tensorboard:
comet_logger.log_metrics({"{} validation ppl".format(key): ppl}, step=state.train_state.step)

if process_non_loss_data_func is not None and writer and is_last_rank():
process_non_loss_data_func(collected_non_loss_data, state.train_state.step, writer)

Expand Down
23 changes: 13 additions & 10 deletions src/megatron/bridge/training/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
)



class SetupOutput(NamedTuple):
"""Represents the output of the main setup function.

Expand Down Expand Up @@ -80,6 +79,7 @@ class SetupOutput(NamedTuple):
checkpointing_context: dict[str, Any]
pg_collection: ProcessGroupCollection


def setup(
state: GlobalState,
train_valid_test_datasets_provider: Callable[..., tuple[Optional[Any], Optional[Any], Optional[Any]]],
Expand Down Expand Up @@ -244,13 +244,16 @@ def modelopt_pre_wrap_hook(model):

# For PEFT, the pretrained checkpoint is loaded in the pre-wrap hook
if cfg.peft is not None:
should_load_checkpoint = (cfg.checkpoint.load is not None and checkpoint_exists(cfg.checkpoint.load))
should_load_checkpoint = cfg.checkpoint.load is not None and checkpoint_exists(cfg.checkpoint.load)
if should_load_checkpoint:
# The finetune toggle is explicitly set to True in order to avoid loading optimizer and RNG states
# This is switched off here in order to load these states from the checkpoint
cfg.checkpoint.finetune = False
else:
should_load_checkpoint = (cfg.checkpoint.load is not None and checkpoint_exists(cfg.checkpoint.load)) or (cfg.checkpoint.pretrained_checkpoint is not None and checkpoint_exists(cfg.checkpoint.pretrained_checkpoint))
should_load_checkpoint = (cfg.checkpoint.load is not None and checkpoint_exists(cfg.checkpoint.load)) or (
cfg.checkpoint.pretrained_checkpoint is not None
and checkpoint_exists(cfg.checkpoint.pretrained_checkpoint)
)

if should_load_checkpoint:
timers("load-checkpoint", log_level=0).start(barrier=True)
Expand All @@ -271,6 +274,7 @@ def modelopt_pre_wrap_hook(model):
model,
state.tensorboard_logger,
state.wandb_logger,
comet_logger=state.comet_logger,
current_training_step=state.train_state.step,
)

Expand All @@ -288,9 +292,7 @@ def modelopt_pre_wrap_hook(model):
if "tokenizer" in inspect.signature(train_valid_test_datasets_provider).parameters:
train_valid_test_datasets_provider = partial(train_valid_test_datasets_provider, tokenizer=tokenizer)
if "pg_collection" in inspect.signature(train_valid_test_datasets_provider).parameters:
train_valid_test_datasets_provider = partial(
train_valid_test_datasets_provider, pg_collection=pg_collection
)
train_valid_test_datasets_provider = partial(train_valid_test_datasets_provider, pg_collection=pg_collection)

train_data_iterator, valid_data_iterator, test_data_iterator = setup_data_iterators(
cfg=cfg,
Expand Down Expand Up @@ -351,13 +353,13 @@ def _update_model_config_funcs(
if len(model) == 1:
model_config.param_sync_func = model_config.param_sync_func[0]
if optimizer is not None:
model_config.finalize_model_grads_func = partial(
finalize_model_grads, pg_collection=pg_collection
)
model_config.finalize_model_grads_func = partial(finalize_model_grads, pg_collection=pg_collection)
model_config.grad_scale_func = optimizer.scale_loss


def _create_peft_pre_wrap_hook(cfg: ConfigContainer, state: GlobalState) -> Callable[[list[MegatronModule]], list[MegatronModule]]:
def _create_peft_pre_wrap_hook(
cfg: ConfigContainer, state: GlobalState
) -> Callable[[list[MegatronModule]], list[MegatronModule]]:
"""Create a pre-wrap hook that handles PEFT logic.

This hook is executed before the model is wrapped with DDP/FSDP and handles:
Expand All @@ -371,6 +373,7 @@ def _create_peft_pre_wrap_hook(cfg: ConfigContainer, state: GlobalState) -> Call
Returns:
A callable hook that can be registered with the model provider
"""

def peft_pre_wrap_hook(model: list[MegatronModule]) -> list[MegatronModule]:
"""Pre-wrap hook that handles PEFT transformation.

Expand Down
Loading