100 changes: 100 additions & 0 deletions engram_builders.py
@@ -0,0 +1,100 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

from model_provider import count_parameters_in_layer
from megatron.core.models.engram import EngramGPTModel
from megatron.core.models.engram.engram_module import EngramConfig, NgramHashMapping
from megatron.core.models.engram.engram_layer_specs import get_engram_layer_local_spec
from megatron.core.transformer import TransformerConfig
from megatron.training import get_args, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args


def _build_engram_config(args) -> EngramConfig:
"""Build EngramConfig from training arguments."""
engram_layer_ids = getattr(args, 'engram_layer_ids', [1, 15])
if isinstance(engram_layer_ids, str):
engram_layer_ids = [int(x) for x in engram_layer_ids.split(',')]

return EngramConfig(
engram_vocab_size=getattr(
args, 'engram_vocab_size', [129280 * 5, 129280 * 5]
),
max_ngram_size=getattr(args, 'engram_max_ngram_size', 3),
n_embed_per_ngram=getattr(args, 'engram_n_embed_per_ngram', 512),
n_head_per_ngram=getattr(args, 'engram_n_head_per_ngram', 8),
engram_layer_ids=engram_layer_ids,
pad_id=getattr(args, 'engram_pad_id', 2),
seed=getattr(args, 'engram_seed', 0),
kernel_size=getattr(args, 'engram_kernel_size', 4),
hc_mult=getattr(args, 'engram_hc_mult', 4),
tokenizer_name_or_path=getattr(
args, 'engram_tokenizer', 'deepseek-ai/DeepSeek-V3'
),
)


def engram_builder(
args, pre_process, post_process, vp_stage=None, config=None, pg_collection=None
):
"""Build an Engram-augmented GPT model.

Constructs the EngramConfig, pre-computes the n-gram vocabulary sizes via
NgramHashMapping, builds the layer spec with Engram support, and returns
an EngramGPTModel instance.
"""
print_rank_0('building Engram GPT model ...')
if config is None:
config = core_transformer_config_from_args(args, TransformerConfig)
assert args.use_legacy_models is False, "Engram only supported in Mcore!"

engram_config = _build_engram_config(args)

# Pre-compute vocab sizes for the hash embedding tables.
# NgramHashMapping also initializes the CompressedTokenizer, which is
# moderately expensive, so we do it once here and pass the results through.
ngram_hash_mapping = NgramHashMapping(
engram_vocab_size=engram_config.engram_vocab_size,
max_ngram_size=engram_config.max_ngram_size,
n_embed_per_ngram=engram_config.n_embed_per_ngram,
n_head_per_ngram=engram_config.n_head_per_ngram,
layer_ids=engram_config.engram_layer_ids,
tokenizer_name_or_path=engram_config.tokenizer_name_or_path,
pad_id=engram_config.pad_id,
seed=engram_config.seed,
)
vocab_size_across_layers = ngram_hash_mapping.vocab_size_across_layers
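    # Illustrative shape only (an assumption, not computed here): with the
    # default engram_layer_ids=[1, 15], vocab_size_across_layers maps each
    # engram layer id to its per-ngram-level head vocab sizes, e.g.
    # {1: [[...], ...], 15: [[...], ...]}.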

transformer_layer_spec = get_engram_layer_local_spec(
engram_config=engram_config,
vocab_size_across_layers=vocab_size_across_layers,
num_experts=getattr(args, 'num_experts', None),
moe_grouped_gemm=getattr(args, 'moe_grouped_gemm', False),
qk_layernorm=getattr(args, 'qk_layernorm', False),
normalization=getattr(args, 'normalization', None),
qk_l2_norm=getattr(args, 'qk_l2_norm', False),
)

model = EngramGPTModel(
config=config,
transformer_layer_spec=transformer_layer_spec,
vocab_size=args.padded_vocab_size,
max_sequence_length=args.max_position_embeddings,
engram_config=engram_config,
pre_process=pre_process,
post_process=post_process,
fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
parallel_output=True,
share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
position_embedding_type=args.position_embedding_type,
rotary_percent=args.rotary_percent,
rotary_base=args.rotary_base,
rope_scaling=args.use_rope_scaling,
vp_stage=vp_stage,
pg_collection=pg_collection,
)

    for layer_idx in range(model.decoder.num_layers_per_pipeline_rank):
        layer_params = count_parameters_in_layer(model, f'decoder.layers.{layer_idx}.')
        print_rank_0(f" == params layer {layer_idx}: {layer_params}")

return model
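
A quick sanity check of the argument plumbing (not part of the diff): a minimal sketch using `argparse.Namespace` as a stand-in for Megatron's parsed args, assuming `EngramConfig` exposes its constructor kwargs as attributes. Only the `engram_*` attributes are read; everything else falls back to the `getattr()` defaults above.

    from argparse import Namespace

    # Comma-separated layer ids are parsed into a list of ints.
    args = Namespace(engram_layer_ids='1,15')
    cfg = _build_engram_config(args)
    assert cfg.engram_layer_ids == [1, 15]
    assert cfg.seed == 0  # default from getattr(args, 'engram_seed', 0)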
5 changes: 5 additions & 0 deletions megatron/core/models/engram/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

from megatron.core.models.engram.engram_model import EngramGPTModel

__all__ = ["EngramGPTModel"]
89 changes: 89 additions & 0 deletions megatron/core/models/engram/engram_layer.py
@@ -0,0 +1,89 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

"""
Engram-enabled Transformer Layer.

Extends the standard TransformerLayer to inject an Engram module that runs
before the self-attention computation. The engram output is added as a residual
to the hidden states before they enter the standard attention + MLP pipeline.
"""

from __future__ import annotations

from typing import Any, Optional

from torch import Tensor

from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import (
TransformerLayer,
TransformerLayerSubmodules,
)

from megatron.core.models.engram.engram_module import EngramConfig, EngramModule


class EngramTransformerLayer(TransformerLayer):
"""A transformer layer augmented with an Engram module.

The Engram module is applied as a residual before the standard self-attention
computation. The pre-computed engram embeddings must be set via the
``engram.precompute_embeddings()`` method before each forward pass (handled
by ``EngramGPTModel``).

For layers that are not in the engram_layer_ids list, this behaves identically
to a standard TransformerLayer.
"""

def __init__(
self,
config: TransformerConfig,
submodules: TransformerLayerSubmodules,
layer_number: int = 1,
hidden_dropout: Optional[float] = None,
pg_collection: Optional[ProcessGroupCollection] = None,
vp_stage: Optional[int] = None,
is_mtp_layer: bool = False,
add_layer_offset: bool = True,
pp_layer_offset: Optional[int] = None,
engram_config: Optional[EngramConfig] = None,
engram_vocab_size_across_layers: Optional[dict] = None,
):
super().__init__(
config=config,
submodules=submodules,
layer_number=layer_number,
hidden_dropout=hidden_dropout,
pg_collection=pg_collection,
vp_stage=vp_stage,
is_mtp_layer=is_mtp_layer,
add_layer_offset=add_layer_offset,
pp_layer_offset=pp_layer_offset,
)

self.engram: Optional[EngramModule] = None
if (
engram_config is not None
and engram_vocab_size_across_layers is not None
and self.layer_number in engram_config.engram_layer_ids
and self.layer_number in engram_vocab_size_across_layers
):
self.engram = EngramModule(
layer_id=self.layer_number,
hidden_size=config.hidden_size,
engram_config=engram_config,
vocab_size_for_layer=engram_vocab_size_across_layers[self.layer_number],
)

def forward(self, hidden_states: Tensor, *args: Any, **kwargs: Any):
"""Forward pass with optional Engram injection before attention.

The Engram output is added as a residual to hidden_states before the
standard TransformerLayer forward (layernorm → attention → MLP).
"""
if self.engram is not None:
engram_output = self.engram(hidden_states)
hidden_states = engram_output + hidden_states

return super().forward(hidden_states, *args, **kwargs)
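
The injection above is just a pre-attention residual. A self-contained toy (hypothetical names, plain PyTorch, independent of Megatron) that exercises the same pattern:

    import torch
    from torch import nn

    class ToyEngram(nn.Module):
        """Stand-in for EngramModule: any map from hidden states to a residual."""
        def __init__(self, hidden_size: int):
            super().__init__()
            self.proj = nn.Linear(hidden_size, hidden_size)

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            return self.proj(hidden_states)

    hidden = torch.randn(8, 2, 16)  # [seq, batch, hidden], Megatron's layout
    engram = ToyEngram(16)
    hidden = engram(hidden) + hidden  # residual add, as in forward() above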
120 changes: 120 additions & 0 deletions megatron/core/models/engram/engram_layer_specs.py
@@ -0,0 +1,120 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

"""
Layer specifications for the Engram-augmented GPT model.
Provides factory functions that produce ModuleSpec objects pointing to
EngramTransformerLayer (instead of the standard TransformerLayer) while
reusing the standard GPT submodule wiring for attention and MLP.
"""

from __future__ import annotations

from typing import Dict, List, Optional

from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.backends import LocalSpecProvider
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayerSubmodules

from megatron.core.models.engram.engram_layer import EngramTransformerLayer
from megatron.core.models.engram.engram_module import EngramConfig

try:
import apex # type: ignore
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
LNImpl = FusedLayerNorm
except ImportError:
from megatron.core.transformer.torch_norm import WrappedTorchNorm
LNImpl = WrappedTorchNorm
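
# Note: the spec built below takes its norms from backend.layer_norm(); LNImpl
# is only the apex-vs-torch fallback selection and is not referenced further.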


def get_engram_layer_local_spec(
engram_config: EngramConfig,
vocab_size_across_layers: Dict[int, List[List[int]]],
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
qk_layernorm: Optional[bool] = False,
normalization: Optional[str] = None,
qk_l2_norm: Optional[bool] = False,
) -> ModuleSpec:
"""Build a ModuleSpec for EngramTransformerLayer using Megatron-Core-only modules.
This mirrors ``get_gpt_layer_local_spec`` but substitutes EngramTransformerLayer
and passes the engram configuration via the spec's ``params`` dict so that
``build_module`` can forward them to the layer constructor.
Args:
engram_config: Engram-specific configuration.
vocab_size_across_layers: Mapping from layer_id to per-ngram-level head
vocab sizes, as computed by NgramHashMapping.
Other args match those of ``get_gpt_layer_local_spec``.
Returns:
ModuleSpec targeting EngramTransformerLayer.
"""
backend = LocalSpecProvider()

if normalization == "RMSNorm":
layer_norm = backend.layer_norm(rms_norm=True, for_qk=False)
qk_norm = backend.layer_norm(rms_norm=True, for_qk=True)
else:
layer_norm = backend.layer_norm(rms_norm=False, for_qk=False)
qk_norm = backend.layer_norm(rms_norm=False, for_qk=True)

mlp = _get_mlp_spec(backend, num_experts, moe_grouped_gemm)

submodules = TransformerLayerSubmodules(
input_layernorm=layer_norm,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=backend.column_parallel_linear(),
core_attention=backend.core_attention(),
linear_proj=backend.row_parallel_linear(),
q_layernorm=qk_norm if qk_layernorm else IdentityOp,
k_layernorm=qk_norm if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=layer_norm,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
)

return ModuleSpec(
module=EngramTransformerLayer,
submodules=submodules,
params={
"engram_config": engram_config,
"engram_vocab_size_across_layers": vocab_size_across_layers,
},
)


def _get_mlp_spec(
backend: LocalSpecProvider,
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
"""Build a dense MLP spec using the backend provider."""
if num_experts is None:
return ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=backend.column_parallel_linear(),
linear_fc2=backend.row_parallel_linear(),
),
)
else:
from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec_for_backend
return get_mlp_module_spec_for_backend(
backend=backend,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
)
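
A minimal usage sketch (config values borrowed from the defaults in `_build_engram_config`; the `vocab_size_across_layers` entries are made-up placeholders, normally produced by NgramHashMapping):

    engram_config = EngramConfig(
        engram_vocab_size=[129280 * 5, 129280 * 5],
        max_ngram_size=3,
        n_embed_per_ngram=512,
        n_head_per_ngram=8,
        engram_layer_ids=[1, 15],
        pad_id=2,
        seed=0,
        kernel_size=4,
        hc_mult=4,
        tokenizer_name_or_path='deepseek-ai/DeepSeek-V3',
    )
    spec = get_engram_layer_local_spec(
        engram_config=engram_config,
        vocab_size_across_layers={1: [[1000] * 8], 15: [[1000] * 8]},  # placeholders
    )
    assert spec.module is EngramTransformerLayer
    assert spec.params["engram_config"] is engram_config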