Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/Megatron-LM
Submodule Megatron-LM updated 249 files
10 changes: 5 additions & 5 deletions src/megatron/bridge/models/mamba/mamba_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,12 @@ class MambaModelConfig(ModelConfig):
on the embedded ``transformer`` config are accessible directly on this object
via ``__getattr__``/``__setattr__`` proxying.
Supports hybrid SSM/attention architectures via ``hybrid_attention_ratio``,
``hybrid_mlp_ratio``, and ``hybrid_override_pattern``.
Supports hybrid SSM/attention architectures via ``hybrid_layer_pattern``
Note:
``vocab_size`` must be set before passing this config to ``MambaModelBuilder``.
``hybrid_attention_ratio``,``hybrid_mlp_ratio``, and
``hybrid_override_pattern`` are deprecated and will be removed in a future release.
"""

builder: ClassVar[str] = "megatron.bridge.models.mamba.MambaModelBuilder"
Expand All @@ -108,6 +109,7 @@ class MambaModelConfig(ModelConfig):
hybrid_attention_ratio: float = 0.0
hybrid_mlp_ratio: float = 0.0
hybrid_override_pattern: str | None = None
hybrid_layer_pattern: str | None = None
seq_length: int = 8192
# Mamba with no attention has no need for position embeddings, so none is default
position_embedding_type: Literal["learned_absolute", "rope", "none"] = "none"
Expand Down Expand Up @@ -222,9 +224,7 @@ def build_model(
mamba_stack_spec=mamba_stack_spec,
vocab_size=padded_vocab_size,
max_sequence_length=self._model_config.seq_length,
hybrid_attention_ratio=self._model_config.hybrid_attention_ratio,
hybrid_mlp_ratio=self._model_config.hybrid_mlp_ratio,
hybrid_override_pattern=self._model_config.hybrid_override_pattern,
hybrid_layer_pattern=self._model_config.hybrid_layer_pattern,
fp16_lm_cross_entropy=self._model_config.fp16_lm_cross_entropy,
parallel_output=self._model_config.parallel_output,
share_embeddings_and_output_weights=self._model_config.share_embeddings_and_output_weights,
Expand Down
52 changes: 48 additions & 4 deletions src/megatron/bridge/models/mamba/mamba_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import logging
import warnings
from dataclasses import dataclass
from typing import Callable, Literal, Optional, Union

Expand All @@ -22,11 +23,13 @@
from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage
from megatron.core.post_training.modelopt.mamba.model_specs import get_mamba_stack_modelopt_spec
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.ssm.mamba_hybrid_layer_allocation import get_hybrid_total_layer_count
from megatron.core.transformer import ModuleSpec
from megatron.core.transformer.enums import AttnBackend

from megatron.bridge.models.model_provider import ModelProviderMixin
from megatron.bridge.models.transformer_config import TransformerConfig
from megatron.bridge.utils.common_utils import get_rank_safe
from megatron.bridge.utils.vocab_utils import calculate_padded_vocab_size


Expand Down Expand Up @@ -91,12 +94,13 @@ class MambaModelProvider(TransformerConfig, ModelProviderMixin[MCoreMambaModel])
params_dtype: torch.dtype = torch.bfloat16
fp16: bool = False
bf16: bool = True
num_layers: int = 2
num_layers: int = None
mamba_num_groups: int = 8
num_attention_heads: int = 1
hybrid_attention_ratio: float = 0.0
hybrid_mlp_ratio: float = 0.0
hybrid_override_pattern: Optional[str] = None
hybrid_layer_pattern: Optional[str] = None
seq_length: int = 8192
# Mamba with no attention has no need for position embeddings, so none is default
position_embedding_type: Literal["learned_absolute", "rope", "none"] = "none"
Expand Down Expand Up @@ -127,6 +131,48 @@ class MambaModelProvider(TransformerConfig, ModelProviderMixin[MCoreMambaModel])
# If True, restore the modelopt_state that contains quantization, sparsity, speculative decoding transformation state.
restore_modelopt_state: bool = False

def finalize(self) -> None:
"""Finalize the Mamba model provider.
Calculates the number of layers from the hybrid_layer_pattern.
Executes the deferred MCore post-init logic.
"""
# Check if hybrid_override_pattern is specified and throw deprecation warning
used_hybrid_override_pattern = False
if self.hybrid_override_pattern is not None:
assert self.hybrid_layer_pattern is None, (
"hybrid_override_pattern and hybrid_layer_pattern cannot both be specified. "
"hybrid_override_pattern is deprecated; use hybrid_layer_pattern instead."
)
if get_rank_safe() == 0:
warnings.warn(
"hybrid_override_pattern is deprecated. Use hybrid_layer_pattern instead.",
DeprecationWarning,
stacklevel=2,
)
self.hybrid_layer_pattern = self.hybrid_override_pattern
used_hybrid_override_pattern = True

# Check if hybrid_layer_pattern is specified and derive num_layers from pattern
if self.hybrid_layer_pattern is not None:
# Derive num_layers from pattern
num_layers_in_pattern = get_hybrid_total_layer_count(self.hybrid_layer_pattern)
if self.num_layers is not None:
if used_hybrid_override_pattern:
assert self.num_layers == num_layers_in_pattern, (
f"num_layers ({self.num_layers}) does not match the number of layers "
f"derived from hybrid_override_pattern ({num_layers_in_pattern}). "
f"Please correct num_layers or the pattern."
)
else:
assert self.num_layers == num_layers_in_pattern, (
f"num_layers ({self.num_layers}) does not match the number of layers "
f"derived from hybrid_layer_pattern ({num_layers_in_pattern}). "
f"Please correct num_layers or the pattern."
)
self.num_layers = num_layers_in_pattern

super().finalize()

def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreMambaModel:
"""Configure and instantiate a Megatron Core Mamba model based on this configuration.
Expand Down Expand Up @@ -166,9 +212,7 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreMa
mamba_stack_spec=mamba_stack_spec,
vocab_size=padded_vocab_size,
max_sequence_length=self.seq_length,
hybrid_attention_ratio=self.hybrid_attention_ratio,
hybrid_mlp_ratio=self.hybrid_mlp_ratio,
hybrid_override_pattern=self.hybrid_override_pattern,
hybrid_layer_pattern=self.hybrid_layer_pattern,
fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
parallel_output=self.parallel_output,
share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,
Expand Down
9 changes: 9 additions & 0 deletions src/megatron/bridge/models/nemotron_vl/nemotron_vl_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@
class NemotronVLBridge(MegatronModelBridge):
"""Conversion utilities between HF Nemotron-VL and Megatron-Core format."""

# Extend CONFIG_MAPPING with Nemotron-VL specific fields
CONFIG_MAPPING = MegatronModelBridge.CONFIG_MAPPING + [
# Mamba-specific fields
("hybrid_override_pattern", "hybrid_layer_pattern"),
]

# ------------------------------------------------------------------
# Provider translation
# ------------------------------------------------------------------
Expand All @@ -49,6 +55,9 @@ def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> NemotronNano12Bv2VLMo
# Use base class helper for common config mapping
provider_kwargs = self.hf_config_to_provider_kwargs(llm_config)

# Remove num_layers from provider as it is derived from hybrid_layer_pattern
provider_kwargs["num_layers"] = None

# Handle vocab size divisibility
provider_kwargs["make_vocab_size_divisible_by"] = self.make_vocab_size_divisible_by(llm_config.vocab_size)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,7 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None): # noqa:
img_h=512,
img_w=512,
patch_dim=16,
hybrid_attention_ratio=0.0,
hybrid_mlp_ratio=0.0,
hybrid_override_pattern=self.hybrid_override_pattern,
hybrid_layer_pattern=self.hybrid_layer_pattern,
image_token_index=131072,
pixel_shuffle=True,
max_num_tiles=12,
Expand Down
7 changes: 5 additions & 2 deletions src/megatron/bridge/models/nemotronh/nemotron_h_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class NemotronHBridge(MegatronModelBridge):
("mamba_num_heads", "mamba_num_heads"),
("n_groups", "mamba_num_groups"),
("ssm_state_size", "mamba_state_dim"),
("hybrid_override_pattern", "hybrid_override_pattern"),
("hybrid_override_pattern", "hybrid_layer_pattern"),
("residual_in_fp32", "fp32_residual_connection"),
("use_bias", "add_bias_linear"),
("layer_norm_epsilon", "layernorm_epsilon"),
Expand All @@ -78,6 +78,9 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> MambaModelProvid
provider = super().provider_bridge(hf_pretrained)
hf_config = hf_pretrained.config

# Remove num_layers from provider as it is derived from hybrid_layer_pattern
provider.num_layers = None

# Nemotron-H specific defaults
provider.activation_func = squared_relu
provider.masked_softmax_fusion = True
Expand All @@ -88,7 +91,7 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> MambaModelProvid
provider.is_hybrid_model = True

# MoE-specific defaults (only if MoE is enabled)
if hasattr(hf_config, "n_routed_experts") and hf_config.n_routed_experts > 0:
if hasattr(hf_config, "n_routed_experts") and hf_config.n_routed_experts is not None:
provider.moe_aux_loss_coeff = 0.0001
provider.moe_router_score_function = "sigmoid"
provider.moe_router_enable_expert_bias = True
Expand Down
23 changes: 9 additions & 14 deletions src/megatron/bridge/models/nemotronh/nemotron_h_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,14 @@ class NemotronHModelProvider(MambaModelProvider):
moe_permute_fusion: bool = True
moe_shared_expert_overlap: bool = True

# Num layers i


@dataclass
class NemotronHModelProvider4B(NemotronHModelProvider):
"""Configuration for a 4B parameter Nemotron-H model."""

hybrid_override_pattern: str = "M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"
num_layers: int = 52
hybrid_layer_pattern: str = "M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Deprecation is currently a hard break for provider constructor kwargs.

These renames remove acceptance of legacy kwargs at dataclass init time. Any caller still passing hybrid_override_pattern (or the other deprecated hybrid knobs mentioned in the PR) will now fail with TypeError instead of getting a deprecation warning.

Proposed compatibility bridge for deprecated config fields
 `@dataclass`
 class NemotronHModelProvider(MambaModelProvider):
     """Configuration for Nemotron-H models."""
+    hybrid_layer_pattern: str | None = None
+    hybrid_override_pattern: str | None = None
+    hybrid_attention_ratio: float | None = None
+    hybrid_mlp_ratio: float | None = None
+
+    def __post_init__(self) -> None:
+        if self.hybrid_override_pattern is not None:
+            warnings.warn(
+                "hybrid_override_pattern is deprecated; use hybrid_layer_pattern.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            if self.hybrid_layer_pattern is not None and self.hybrid_layer_pattern != self.hybrid_override_pattern:
+                raise ValueError(
+                    "Both hybrid_layer_pattern and hybrid_override_pattern were provided with different values."
+                )
+            self.hybrid_layer_pattern = self.hybrid_override_pattern
+
+        if self.hybrid_attention_ratio is not None or self.hybrid_mlp_ratio is not None:
+            warnings.warn(
+                "hybrid_attention_ratio and hybrid_mlp_ratio are deprecated; use hybrid_layer_pattern.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
+        super().__post_init__()

Also applies to: 78-78, 91-91, 106-106, 124-124, 140-140, 158-158

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/models/nemotronh/nemotron_h_provider.py` at line 63, The
dataclass now uses renamed fields like hybrid_layer_pattern which breaks callers
passing legacy kwargs (e.g., hybrid_override_pattern); add
backward-compatibility by accepting the old deprecated kwargs as optional fields
(e.g., hybrid_override_pattern: Optional[str] = None) and in the class
__post_init__ map any deprecated field values onto the new fields (set
hybrid_layer_pattern from hybrid_override_pattern when present), emit a
DeprecationWarning via warnings.warn, and repeat this pattern for the other
renamed hybrid knobs (the fields referenced at the other commented locations) so
legacy callers do not get a TypeError.

hidden_size: int = 3072
mamba_num_heads: int = 112
kv_channels: int = 128
Expand All @@ -75,8 +76,7 @@ class NemotronHModelProvider4B(NemotronHModelProvider):
class NemotronHModelProvider8B(NemotronHModelProvider):
"""Configuration for a 8B parameter Nemotron-H model."""

hybrid_override_pattern: str = "M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"
num_layers: int = 52
hybrid_layer_pattern: str = "M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"
hidden_size: int = 4096
mamba_state_dim: int = 128
mamba_num_heads: int = 128
Expand All @@ -88,10 +88,9 @@ class NemotronHModelProvider8B(NemotronHModelProvider):
class NemotronHModelProvider47B(NemotronHModelProvider):
"""Configuration for a 47B parameter Nemotron-H model."""

hybrid_override_pattern: str = (
hybrid_layer_pattern: str = (
"M-M-M-M-M-M-M-M-M*-M-M-M-M-M-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-M-M---MM---M-M*-M-M-M-M-M-"
)
num_layers: int = 98
hidden_size: int = 8192
mamba_state_dim: int = 256
mamba_num_heads: int = 256
Expand All @@ -103,11 +102,10 @@ class NemotronHModelProvider47B(NemotronHModelProvider):
class NemotronHModelProvider56B(NemotronHModelProvider):
"""Configuration for a 56B parameter Nemotron-H model."""

hybrid_override_pattern: str = (
hybrid_layer_pattern: str = (
"M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-"
"M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"
)
num_layers: int = 118
hidden_size: int = 8192
mamba_state_dim: int = 256
mamba_num_heads: int = 256
Expand All @@ -121,8 +119,7 @@ class NemotronHModelProvider56B(NemotronHModelProvider):
class NemotronNanoModelProvider9Bv2(NemotronHModelProvider):
"""Configuration for a 9B parameter Nemotron Nano v2 model."""

hybrid_override_pattern: str = "M-M-M-MM-M-M-M*-M-M-M*-M-M-M-M*-M-M-M-M*-M-MM-M-M-M-M-M-"
num_layers: int = 56
hybrid_layer_pattern: str = "M-M-M-MM-M-M-M*-M-M-M*-M-M-M-M*-M-M-M-M*-M-MM-M-M-M-M-M-"
hidden_size: int = 4480
mamba_num_heads: int = 128
kv_channels: int = 128
Expand All @@ -137,8 +134,7 @@ class NemotronNanoModelProvider9Bv2(NemotronHModelProvider):
class NemotronNanoModelProvider12Bv2(NemotronHModelProvider):
"""Configuration for the Nemotron Nano v2 12B model."""

hybrid_override_pattern: str = "M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M-"
num_layers: int = 62
hybrid_layer_pattern: str = "M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M*-M-M-M-M-"
hidden_size: int = 5120
mamba_num_heads: int = 128
kv_channels: int = 128
Expand All @@ -155,8 +151,7 @@ class Nemotron3NanoProvider(NemotronHModelProvider):

seq_length: int = 262144
num_query_groups: int = 2
hybrid_override_pattern: str = "MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME"
num_layers: int = 52
hybrid_layer_pattern: str = "MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME"
hidden_size: int = 2688
mamba_num_heads: int = 64
kv_channels: int = 128
Expand Down
4 changes: 1 addition & 3 deletions src/megatron/bridge/training/mlm_compat/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,7 @@ def _mamba_provider(
vocab_size=args.padded_vocab_size,
max_sequence_length=args.max_position_embeddings,
pre_process=pre_process,
hybrid_attention_ratio=args.hybrid_attention_ratio,
hybrid_mlp_ratio=args.hybrid_mlp_ratio,
hybrid_override_pattern=args.hybrid_override_pattern,
hybrid_layer_pattern=args.hybrid_layer_pattern,
post_process=post_process,
fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
parallel_output=True,
Expand Down
8 changes: 4 additions & 4 deletions src/megatron/bridge/training/utils/flop_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ def num_floating_point_operations(cfg: ConfigContainer, batch_size: int = 1):

def calculate_layer_counts():
"""Calculate the number of attention, Mamba, MLP, and MoE layers."""
if hasattr(cfg.model, "hybrid_override_pattern") and cfg.model.hybrid_override_pattern:
if hasattr(cfg.model, "hybrid_layer_pattern") and cfg.model.hybrid_layer_pattern:
counts = {"M": 0, "*": 0, "-": 0, "E": 0}
try:
parse_hybrid_pattern = importlib.import_module(
"megatron.core.ssm.mamba_hybrid_layer_allocation"
).parse_hybrid_pattern
parsed = parse_hybrid_pattern(cfg.model.hybrid_override_pattern)
parsed = parse_hybrid_pattern(cfg.model.hybrid_layer_pattern)
if parsed.main_pattern:
for layer_type in parsed.main_pattern:
if layer_type in counts:
Expand All @@ -52,7 +52,7 @@ def calculate_layer_counts():
if layer_type in counts:
counts[layer_type] += parsed.mtp_num_depths
except (ImportError, ModuleNotFoundError):
for layer_type in cfg.model.hybrid_override_pattern:
for layer_type in cfg.model.hybrid_layer_pattern:
if layer_type in counts:
counts[layer_type] += 1
return counts["*"], counts["M"], counts["-"], counts["E"]
Expand Down Expand Up @@ -431,7 +431,7 @@ def transformer_flops():
mtp_num_layers = getattr(cfg.model, "mtp_num_layers", None)
if mtp_num_layers is None:
# When using unified hybrid patterns, infer MTP depth count from the pattern.
hybrid_pattern = getattr(cfg.model, "hybrid_override_pattern", None)
hybrid_pattern = getattr(cfg.model, "hybrid_layer_pattern", None)
if hybrid_pattern:
try:
parse_hybrid_pattern = importlib.import_module(
Expand Down
2 changes: 1 addition & 1 deletion src/megatron/bridge/training/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ def training_log(
track_names.append("z_loss")

if config.model.is_hybrid_model:
layers = config.model.hybrid_override_pattern.count("E")
layers = config.model.hybrid_layer_pattern.count("E")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Guard hybrid pattern access during deprecation window.

Line 627 assumes hybrid_layer_pattern is always populated. Hybrid configs that still only set the deprecated key can fail in logging.

Proposed fix
-        if config.model.is_hybrid_model:
-            layers = config.model.hybrid_layer_pattern.count("E")
+        if config.model.is_hybrid_model:
+            hybrid_pattern = getattr(config.model, "hybrid_layer_pattern", None) or getattr(
+                config.model, "hybrid_override_pattern", None
+            )
+            layers = hybrid_pattern.count("E") if hybrid_pattern else config.model.num_layers
         else:
             layers = config.model.num_layers
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
layers = config.model.hybrid_layer_pattern.count("E")
if config.model.is_hybrid_model:
hybrid_pattern = getattr(config.model, "hybrid_layer_pattern", None) or getattr(
config.model, "hybrid_override_pattern", None
)
layers = hybrid_pattern.count("E") if hybrid_pattern else config.model.num_layers
else:
layers = config.model.num_layers
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/training/utils/train_utils.py` at line 627, The current
assignment to layers assumes config.model.hybrid_layer_pattern exists; change it
to safely access the attribute (e.g., pattern = getattr(config.model,
"hybrid_layer_pattern", "") or check hasattr(config.model,
"hybrid_layer_pattern") and fall back to an empty string or the deprecated key
if present) and then compute layers = pattern.count("E") so logging won't fail
for configs still using the deprecated key.

else:
layers = config.model.num_layers

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,24 @@ def test_nemotron_vl_conversion_parallelism(self, nemotron_vl_toy_model_path, tm
test_output_dir = tmp_path / f"nemotron_vl_{test_name}"
test_output_dir.mkdir(exist_ok=True)

# Modify config.json to add | separator for hybrid_override_pattern to be able to run PP > 1
config_file = Path(nemotron_vl_toy_model_path) / "config.json"
assert config_file.exists(), f"config.json not found at {config_file}"
with open(config_file) as f:
config_data = json.load(f)

if pp > 1:
config_data["hybrid_override_pattern"] = (
HF_NEMOTRON_VL_TOY_MODEL_OVERRIDES["hybrid_override_pattern"][:2]
+ "|"
+ HF_NEMOTRON_VL_TOY_MODEL_OVERRIDES["hybrid_override_pattern"][2:]
)
else:
config_data["hybrid_override_pattern"] = HF_NEMOTRON_VL_TOY_MODEL_OVERRIDES["hybrid_override_pattern"]

with open(config_file, "w") as f:
json.dump(config_data, f, indent=2)

# Run hf_megatron_roundtrip_multi_gpu.py with specified parallelism configuration on our toy model
cmd = [
"python",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def test_toy_model_creation(self, nemotronh_toy_model_path):
"tp,pp,test_name",
[
(2, 1, "TP"),
(1, 2, "PP"),
pytest.param(1, 2, "PP", marks=pytest.mark.skip(reason="Skipping until a better resolution for | pattern is found")),
],
)
def test_nemotronh_conversion_parallelism(self, nemotronh_toy_model_path, tmp_path, tp, pp, test_name):
Expand All @@ -209,6 +209,24 @@ def test_nemotronh_conversion_parallelism(self, nemotronh_toy_model_path, tmp_pa
test_output_dir = tmp_path / f"nemotronh_{test_name}"
test_output_dir.mkdir(exist_ok=True)

# Modify config.json to add | separator for hybrid_override_pattern to be able to run PP > 1
config_file = Path(nemotronh_toy_model_path) / "config.json"
assert config_file.exists(), f"config.json not found at {config_file}"
with open(config_file) as f:
config_data = json.load(f)

if pp > 1:
config_data["hybrid_override_pattern"] = (
HF_NEMOTRONH_TOY_MODEL_OVERRIDES["hybrid_override_pattern"][:2]
+ "|"
+ HF_NEMOTRONH_TOY_MODEL_OVERRIDES["hybrid_override_pattern"][2:]
)
else:
config_data["hybrid_override_pattern"] = HF_NEMOTRONH_TOY_MODEL_OVERRIDES["hybrid_override_pattern"]

with open(config_file, "w") as f:
json.dump(config_data, f, indent=2)

# Run hf_megatron_roundtrip_multi_gpu.py with specified parallelism configuration on our toy model
cmd = [
"python",
Expand Down
Loading