Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,690 changes: 1,690 additions & 0 deletions scripts/translate_mlm_to_bridge.py

Large diffs are not rendered by default.

15 changes: 4 additions & 11 deletions src/megatron/bridge/models/conversion/model_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@
)

import torch
import torch.nn.functional as F
from megatron.core import parallel_state
from megatron.core.activations import fast_gelu, squared_relu
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import (
Expand All @@ -61,6 +59,7 @@
)
from megatron.bridge.models.decorators.dispatch import dispatch
from megatron.bridge.models.model_provider import ModelProviderMixin
from megatron.bridge.utils.activation_map import ACTIVATION_FUNC_MAP
from megatron.bridge.utils.common_utils import print_rank_0


Expand Down Expand Up @@ -315,15 +314,9 @@ def mapping_registry(self) -> MegatronMappingRegistry:
("mscale_all_dim", "mscale_all_dim"),
]

# Common bidirectional activation function mapping: hf_name <-> megatron_func
ACTIVATION_MAPPING = {
"silu": F.silu,
"gelu": F.gelu,
"relu": F.relu,
"relu2": squared_relu,
"tanh": torch.tanh,
"gelu_pytorch_tanh": fast_gelu,
}
# Shared activation function mapping (hf_name <-> megatron_func).
# Single source of truth lives in megatron.bridge.utils.activation_map.
ACTIVATION_MAPPING = ACTIVATION_FUNC_MAP

@classmethod
def hf_to_megatron_activation(cls, hidden_act: str):
Expand Down
22 changes: 22 additions & 0 deletions src/megatron/bridge/models/transformer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@ def _safe_asdict(obj, skip_keys: set[str]) -> dict:
return obj


def _resolve_string_fields(config: MCoreTransformerConfig) -> None:
"""Resolve string-valued fields to their runtime types.

Currently handles ``activation_func``: if it was set to a string
(e.g. via a CLI override like ``model.activation_func=silu``), it is
resolved to the corresponding callable before MCore post-init runs.
"""
if isinstance(config.activation_func, str):
from megatron.bridge.utils.activation_map import str_to_callable

config.activation_func = str_to_callable(config.activation_func)


@dataclass
class TransformerConfig(MCoreTransformerConfig):
"""Megatron Core TransformerConfig with deferred post-init.
Expand Down Expand Up @@ -87,8 +100,11 @@ def finalize(self) -> None:
to compute derived fields based on the current field values. It can be
called multiple times safely.
"""
_resolve_string_fields(self)
if self.pipeline_model_parallel_size > 1 and self.pipeline_dtype is None:
self.pipeline_dtype = self.params_dtype
if self.sequence_parallel and self.tensor_model_parallel_size <= 1:
self.sequence_parallel = False
MCoreTransformerConfig.__post_init__(self)

def __deepcopy__(self, memo):
Expand Down Expand Up @@ -152,8 +168,11 @@ def finalize(self) -> None:
to compute derived fields based on the current field values. It can be
called multiple times safely.
"""
_resolve_string_fields(self)
if self.pipeline_model_parallel_size > 1 and self.pipeline_dtype is None:
self.pipeline_dtype = self.params_dtype
if self.sequence_parallel and self.tensor_model_parallel_size <= 1:
self.sequence_parallel = False
MCoreMLATransformerConfig.__post_init__(self)


Expand Down Expand Up @@ -201,6 +220,9 @@ def finalize(self) -> None:
to compute derived fields and parse the heterogeneous block configurations.
It can be called multiple times safely.
"""
_resolve_string_fields(self)
if self.sequence_parallel and self.tensor_model_parallel_size <= 1:
self.sequence_parallel = False
MCoreHeterogeneousTransformerConfig.__post_init__(self)

def get_config_for_layer(self, layer_number: int) -> MCoreTransformerConfig:
Expand Down
7 changes: 7 additions & 0 deletions src/megatron/bridge/recipes/gpt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .vanilla_gpt import vanilla_gpt_pretrain_config


__all__ = [
"vanilla_gpt_pretrain_config",
]
125 changes: 125 additions & 0 deletions src/megatron/bridge/recipes/gpt/vanilla_gpt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Vanilla GPT recipe — a minimal baseline that mirrors Megatron-LM pretrain_gpt.py defaults.

Use this recipe for MLM <-> Bridge correlation testing. All architectural
and training knobs are left at their Megatron-Core / pretrain_gpt.py defaults
so that the *only* source of difference between the two frameworks is what you
explicitly override on the CLI.

Example
-------
torchrun --nproc_per_node=1 scripts/training/run_recipe.py \\
--recipe vanilla_gpt_pretrain_config \\
model.num_layers=2 model.hidden_size=256 model.num_attention_heads=4 \\
model.activation_func=silu model.gated_linear_unit=true \\
train.train_iters=10 train.global_batch_size=8 train.micro_batch_size=2
"""

import os

from megatron.core.distributed import DistributedDataParallelConfig

from megatron.bridge.models.gpt_provider import GPTModelProvider
from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing
from megatron.bridge.recipes.utils.tokenizer_utils import DEFAULT_NULL_TOKENIZER_VOCAB_SIZE
from megatron.bridge.training.config import (
CheckpointConfig,
ConfigContainer,
DistributedInitConfig,
GPTDatasetConfig,
LoggerConfig,
RNGConfig,
TokenizerConfig,
TrainingConfig,
ValidationConfig,
)


def vanilla_gpt_pretrain_config() -> ConfigContainer:
    """Build a minimal GPT pretrain config matching Megatron-LM pretrain_gpt.py defaults.

    The model provider is a bare ``GPTModelProvider`` — every field stays at
    its dataclass default (LayerNorm, GeLU, learned_absolute position
    embeddings, etc.), so the recipe carries **no** hidden model-specific
    assumptions. Anything you need — e.g. ``model.activation_func=silu`` plus
    ``model.gated_linear_unit=true`` for SwiGLU — is supplied via CLI override.

    Returns:
        ConfigContainer with Megatron-LM-compatible defaults.
    """
    # Optimizer/scheduler pair mirroring the standard MLM cosine schedule.
    optimizer_cfg, lr_scheduler = distributed_fused_adam_with_cosine_annealing(
        lr_warmup_iters=500,
        lr_decay_iters=None,
        max_lr=3e-4,
        min_lr=3e-5,
    )

    # All artifacts land under <cwd>/nemo_experiments/default/.
    experiments_root = os.path.join(os.getcwd(), "nemo_experiments")
    run_dir = os.path.join(experiments_root, "default")
    ckpt_dir = os.path.join(run_dir, "checkpoints")
    tb_dir = os.path.join(run_dir, "tb_logs")

    container = ConfigContainer(
        # Bare GPTModelProvider — all fields at their dataclass defaults.
        model=GPTModelProvider(),
        train=TrainingConfig(
            train_iters=300000,
            global_batch_size=32,
            micro_batch_size=2,
        ),
        validation=ValidationConfig(
            eval_interval=500,
            eval_iters=32,
        ),
        optimizer=optimizer_cfg,
        scheduler=lr_scheduler,
        ddp=DistributedDataParallelConfig(
            check_for_nan_in_grad=True,
            grad_reduce_in_fp32=True,
            overlap_grad_reduce=False,
            overlap_param_gather=False,
            use_distributed_optimizer=False,
        ),
        dataset=GPTDatasetConfig(
            random_seed=1234,
            sequence_length=1024,
            blend=None,
            blend_per_split=None,
            split="9999,8,2",
            dataloader_type="single",
            reset_position_ids=False,
            reset_attention_mask=False,
            eod_mask_loss=False,
        ),
        logger=LoggerConfig(
            log_interval=10,
            tensorboard_dir=tb_dir,
        ),
        tokenizer=TokenizerConfig(
            tokenizer_type="NullTokenizer",
            vocab_size=DEFAULT_NULL_TOKENIZER_VOCAB_SIZE,
        ),
        checkpoint=CheckpointConfig(
            save_interval=500,
            save=ckpt_dir,
            ckpt_format="torch_dist",
        ),
        rng=RNGConfig(seed=1234),
        dist=DistributedInitConfig(),
        mixed_precision="bf16_mixed",
    )

    return container
17 changes: 17 additions & 0 deletions src/megatron/bridge/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,10 +410,17 @@ class GPTDatasetConfig(MCoreGPTDatasetConfig, DataloaderConfig):
for field modifications after construction but before computed fields are calculated.
"""

data_path: str | list[str] | None = None
"""CLI-friendly alternative to ``blend``. Accepts a single path string,
a space-separated multi-path string, or a list of paths (with optional
interleaved weights, matching Megatron-LM ``--data-path`` semantics).
Converted to ``blend`` automatically during ``finalize()``."""

def __init__(
self,
seq_length: int | None = None,
skip_getting_attention_mask_from_dataset: bool = True,
data_path: str | list[str] | None = None,
*args,
**kwargs,
):
Expand All @@ -422,8 +429,10 @@ def __init__(
seq_length (int | None): the sequence length. If not provided, `sequence_length` must be in kwargs.
skip_getting_attention_mask_from_dataset (bool): if set, the dataset will pass a None attention mask
and the attention mask is autogenerated from the attn backend.
data_path: CLI-friendly data path(s). Converted to ``blend`` in ``finalize()``.
"""
self.skip_getting_attention_mask_from_dataset = skip_getting_attention_mask_from_dataset
self.data_path = data_path

if seq_length is not None:
kwargs["sequence_length"] = seq_length
Expand Down Expand Up @@ -456,6 +465,14 @@ def finalize(self) -> None:
This method calls the original Megatron Core GPTDatasetConfig.__post_init__()
and then performs Bridge-specific validation.
"""
if self.blend is None and self.data_path is not None:
from megatron.core.datasets.utils import get_blend_from_list

if isinstance(self.data_path, str):
paths = self.data_path.split()
else:
paths = list(self.data_path)
self.blend = get_blend_from_list(paths)

# Call MCore's post_init
super(MCoreGPTDatasetConfig, self).__post_init__()
Expand Down
23 changes: 20 additions & 3 deletions src/megatron/bridge/training/utils/omegaconf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from hydra.core.override_parser.overrides_parser import OverridesParser
from omegaconf import DictConfig, OmegaConf

# Re-export so existing callers (e.g. transformer_config.py) keep working.
from megatron.bridge.utils.activation_map import callable_to_str, str_to_callable # noqa: F401


logger = logging.getLogger(__name__)

Expand All @@ -36,6 +39,9 @@
# Sentinel object to distinguish between "exclude this field" and "field is legitimately None"
_EXCLUDE_FIELD = object()

# Fields whose callables should be serialized as strings (not excluded)
_SERIALIZABLE_CALLABLE_FIELDS: frozenset[str] = frozenset({"activation_func"})


def create_omegaconf_dict_config(config_container: Any) -> Tuple[DictConfig, Dict[str, Any]]:
"""Create OmegaConf while tracking excluded fields for later restoration.
Expand Down Expand Up @@ -260,8 +266,15 @@ def _dataclass_to_omegaconf_dict(val_to_convert: Any, path: str = "") -> Any:
logger.debug(f"Converting torch.dtype at {current_path}: {val_to_convert}")
return str(val_to_convert)

# Handle callables - exclude them completely
# Handle callables — serialize known activation functions as strings,
# exclude everything else.
if _is_omegaconf_problematic(val_to_convert):
field_name = current_path.rsplit(".", 1)[-1] if "." in current_path else current_path
if field_name in _SERIALIZABLE_CALLABLE_FIELDS:
str_name = callable_to_str(val_to_convert)
if str_name is not None:
logger.debug(f"Serializing callable at {current_path} as string: {str_name}")
return str_name
logger.debug(f"Excluding callable at {current_path}: {type(val_to_convert)} - {val_to_convert}")
return _EXCLUDE_FIELD

Expand Down Expand Up @@ -356,8 +369,12 @@ def _track_excluded_fields(obj: Any, path: str = "") -> Dict[str, Any]:
field_value = getattr(obj, field_name)

if _is_omegaconf_problematic(field_value):
excluded_fields[field_path] = field_value
logger.debug(f"Tracking excluded callable: {field_path}")
# Skip fields that are serialized as strings (not excluded)
if field_name in _SERIALIZABLE_CALLABLE_FIELDS and callable_to_str(field_value) is not None:
logger.debug(f"Skipping serializable callable (not excluded): {field_path}")
else:
excluded_fields[field_path] = field_value
logger.debug(f"Tracking excluded callable: {field_path}")
elif dataclasses.is_dataclass(field_value):
nested_excluded = _track_excluded_fields(field_value, field_path)
excluded_fields.update(nested_excluded)
Expand Down
Loading
Loading