From 56c589a12e1954d300c439995aafcf4f944d9b50 Mon Sep 17 00:00:00 2001 From: ilml Date: Wed, 4 Mar 2026 05:07:31 +0000 Subject: [PATCH 1/2] Add Engram model structure integration (v1) Integrate DeepSeek's Engram n-gram hash embedding module into Megatron-LM. This initial version focuses on model structure only, extending GPTModel with Engram-augmented transformer layers that inject gated n-gram embeddings before self-attention at configurable layer positions. Key components: - EngramModule: n-gram hashing, multi-head embedding, gated value projection, causal short convolution with hyper-connection multiplier - EngramTransformerLayer: extends TransformerLayer with pre-attention Engram - EngramGPTModel: extends GPTModel with hash pre-computation from input_ids - Layer specs, builder, and pretrain entry point following Mcore patterns Made-with: Cursor --- engram_builders.py | 100 ++++ megatron/core/models/engram/__init__.py | 5 + megatron/core/models/engram/engram_layer.py | 89 +++ .../core/models/engram/engram_layer_specs.py | 120 ++++ megatron/core/models/engram/engram_model.py | 161 ++++++ megatron/core/models/engram/engram_module.py | 511 ++++++++++++++++++ pretrain_engram.py | 295 ++++++++++ 7 files changed, 1281 insertions(+) create mode 100644 engram_builders.py create mode 100644 megatron/core/models/engram/__init__.py create mode 100644 megatron/core/models/engram/engram_layer.py create mode 100644 megatron/core/models/engram/engram_layer_specs.py create mode 100644 megatron/core/models/engram/engram_model.py create mode 100644 megatron/core/models/engram/engram_module.py create mode 100644 pretrain_engram.py diff --git a/engram_builders.py b/engram_builders.py new file mode 100644 index 00000000000..f929aefb584 --- /dev/null +++ b/engram_builders.py @@ -0,0 +1,100 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +from model_provider import count_parameters_in_layer +from megatron.core.models.engram import EngramGPTModel +from megatron.core.models.engram.engram_module import EngramConfig, NgramHashMapping +from megatron.core.models.engram.engram_layer_specs import get_engram_layer_local_spec +from megatron.core.transformer import TransformerConfig +from megatron.training import get_args, print_rank_0 +from megatron.training.arguments import core_transformer_config_from_args + + +def _build_engram_config(args) -> EngramConfig: + """Build EngramConfig from training arguments.""" + engram_layer_ids = getattr(args, 'engram_layer_ids', [1, 15]) + if isinstance(engram_layer_ids, str): + engram_layer_ids = [int(x) for x in engram_layer_ids.split(',')] + + return EngramConfig( + engram_vocab_size=getattr( + args, 'engram_vocab_size', [129280 * 5, 129280 * 5] + ), + max_ngram_size=getattr(args, 'engram_max_ngram_size', 3), + n_embed_per_ngram=getattr(args, 'engram_n_embed_per_ngram', 512), + n_head_per_ngram=getattr(args, 'engram_n_head_per_ngram', 8), + engram_layer_ids=engram_layer_ids, + pad_id=getattr(args, 'engram_pad_id', 2), + seed=getattr(args, 'engram_seed', 0), + kernel_size=getattr(args, 'engram_kernel_size', 4), + hc_mult=getattr(args, 'engram_hc_mult', 4), + tokenizer_name_or_path=getattr( + args, 'engram_tokenizer', 'deepseek-ai/DeepSeek-V3' + ), + ) + + +def engram_builder( + args, pre_process, post_process, vp_stage=None, config=None, pg_collection=None +): + """Build an Engram-augmented GPT model. 
+ + Constructs the EngramConfig, pre-computes the n-gram vocabulary sizes via + NgramHashMapping, builds the layer spec with Engram support, and returns + an EngramGPTModel instance. + """ + print_rank_0('building Engram GPT model ...') + if config is None: + config = core_transformer_config_from_args(args, TransformerConfig) + assert args.use_legacy_models is False, "Engram only supported in Mcore!" + + engram_config = _build_engram_config(args) + + # Pre-compute vocab sizes for the hash embedding tables. + # NgramHashMapping also initializes the CompressedTokenizer, which is + # moderately expensive, so we do it once here and pass the results through. + ngram_hash_mapping = NgramHashMapping( + engram_vocab_size=engram_config.engram_vocab_size, + max_ngram_size=engram_config.max_ngram_size, + n_embed_per_ngram=engram_config.n_embed_per_ngram, + n_head_per_ngram=engram_config.n_head_per_ngram, + layer_ids=engram_config.engram_layer_ids, + tokenizer_name_or_path=engram_config.tokenizer_name_or_path, + pad_id=engram_config.pad_id, + seed=engram_config.seed, + ) + vocab_size_across_layers = ngram_hash_mapping.vocab_size_across_layers + + transformer_layer_spec = get_engram_layer_local_spec( + engram_config=engram_config, + vocab_size_across_layers=vocab_size_across_layers, + num_experts=getattr(args, 'num_experts', None), + moe_grouped_gemm=getattr(args, 'moe_grouped_gemm', False), + qk_layernorm=getattr(args, 'qk_layernorm', False), + normalization=getattr(args, 'normalization', None), + qk_l2_norm=getattr(args, 'qk_l2_norm', False), + ) + + model = EngramGPTModel( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=args.padded_vocab_size, + max_sequence_length=args.max_position_embeddings, + engram_config=engram_config, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, + parallel_output=True, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base, + rope_scaling=args.use_rope_scaling, + vp_stage=vp_stage, + pg_collection=pg_collection, + ) + + for l in range(model.decoder.num_layers_per_pipeline_rank): + layer_params = count_parameters_in_layer(model, f'decoder.layers.{l}.') + print_rank_0(f" == params layer {l}: {layer_params}") + + return model diff --git a/megatron/core/models/engram/__init__.py b/megatron/core/models/engram/__init__.py new file mode 100644 index 00000000000..83f2a3e5783 --- /dev/null +++ b/megatron/core/models/engram/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +from megatron.core.models.engram.engram_model import EngramGPTModel + +__all__ = ["EngramGPTModel"] diff --git a/megatron/core/models/engram/engram_layer.py b/megatron/core/models/engram/engram_layer.py new file mode 100644 index 00000000000..ac42f701908 --- /dev/null +++ b/megatron/core/models/engram/engram_layer.py @@ -0,0 +1,89 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +""" +Engram-enabled Transformer Layer. + +Extends the standard TransformerLayer to inject an Engram module that runs +before the self-attention computation. The engram output is added as a residual +to the hidden states before they enter the standard attention + MLP pipeline. 
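+
+In rough pseudocode, the per-layer data flow is (an illustrative sketch only; the
+actual wiring is in EngramTransformerLayer.forward below):
+
+    if the layer has an Engram module:
+        hidden_states = hidden_states + engram(hidden_states)   # pre-attention residual
+    output = standard_transformer_layer_forward(hidden_states)  # layernorm -> attention -> MLP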
+""" + +from __future__ import annotations + +from typing import Any, Optional + +from torch import Tensor + +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import ( + TransformerLayer, + TransformerLayerSubmodules, +) + +from megatron.core.models.engram.engram_module import EngramConfig, EngramModule + + +class EngramTransformerLayer(TransformerLayer): + """A transformer layer augmented with an Engram module. + + The Engram module is applied as a residual before the standard self-attention + computation. The pre-computed engram embeddings must be set via the + ``engram.precompute_embeddings()`` method before each forward pass (handled + by ``EngramGPTModel``). + + For layers that are not in the engram_layer_ids list, this behaves identically + to a standard TransformerLayer. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: TransformerLayerSubmodules, + layer_number: int = 1, + hidden_dropout: Optional[float] = None, + pg_collection: Optional[ProcessGroupCollection] = None, + vp_stage: Optional[int] = None, + is_mtp_layer: bool = False, + add_layer_offset: bool = True, + pp_layer_offset: Optional[int] = None, + engram_config: Optional[EngramConfig] = None, + engram_vocab_size_across_layers: Optional[dict] = None, + ): + super().__init__( + config=config, + submodules=submodules, + layer_number=layer_number, + hidden_dropout=hidden_dropout, + pg_collection=pg_collection, + vp_stage=vp_stage, + is_mtp_layer=is_mtp_layer, + add_layer_offset=add_layer_offset, + pp_layer_offset=pp_layer_offset, + ) + + self.engram: Optional[EngramModule] = None + if ( + engram_config is not None + and engram_vocab_size_across_layers is not None + and self.layer_number in engram_config.engram_layer_ids + and self.layer_number in engram_vocab_size_across_layers + ): + self.engram = EngramModule( + layer_id=self.layer_number, + hidden_size=config.hidden_size, + engram_config=engram_config, + vocab_size_for_layer=engram_vocab_size_across_layers[self.layer_number], + ) + + def forward(self, hidden_states: Tensor, *args: Any, **kwargs: Any): + """Forward pass with optional Engram injection before attention. + + The Engram output is added as a residual to hidden_states before the + standard TransformerLayer forward (layernorm → attention → MLP). + """ + if self.engram is not None: + engram_output = self.engram(hidden_states) + hidden_states = engram_output + hidden_states + + return super().forward(hidden_states, *args, **kwargs) diff --git a/megatron/core/models/engram/engram_layer_specs.py b/megatron/core/models/engram/engram_layer_specs.py new file mode 100644 index 00000000000..53497f93a69 --- /dev/null +++ b/megatron/core/models/engram/engram_layer_specs.py @@ -0,0 +1,120 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +""" +Layer specifications for the Engram-augmented GPT model. + +Provides factory functions that produce ModuleSpec objects pointing to +EngramTransformerLayer (instead of the standard TransformerLayer) while +reusing the standard GPT submodule wiring for attention and MLP. 
+""" + +from __future__ import annotations + +from typing import Dict, List, Optional + +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.models.backends import LocalSpecProvider +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_layer import TransformerLayerSubmodules + +from megatron.core.models.engram.engram_layer import EngramTransformerLayer +from megatron.core.models.engram.engram_module import EngramConfig + +try: + import apex # type: ignore + from megatron.core.fusions.fused_layer_norm import FusedLayerNorm + LNImpl = FusedLayerNorm +except ImportError: + from megatron.core.transformer.torch_norm import WrappedTorchNorm + LNImpl = WrappedTorchNorm + + +def get_engram_layer_local_spec( + engram_config: EngramConfig, + vocab_size_across_layers: Dict[int, List[List[int]]], + num_experts: Optional[int] = None, + moe_grouped_gemm: Optional[bool] = False, + qk_layernorm: Optional[bool] = False, + normalization: Optional[str] = None, + qk_l2_norm: Optional[bool] = False, +) -> ModuleSpec: + """Build a ModuleSpec for EngramTransformerLayer using Megatron-Core-only modules. + + This mirrors ``get_gpt_layer_local_spec`` but substitutes EngramTransformerLayer + and passes the engram configuration via the spec's ``params`` dict so that + ``build_module`` can forward them to the layer constructor. + + Args: + engram_config: Engram-specific configuration. + vocab_size_across_layers: Mapping from layer_id to per-ngram-level head + vocab sizes, as computed by NgramHashMapping. + Other args match those of ``get_gpt_layer_local_spec``. + + Returns: + ModuleSpec targeting EngramTransformerLayer. 
+ """ + backend = LocalSpecProvider() + + if normalization == "RMSNorm": + layer_norm = backend.layer_norm(rms_norm=True, for_qk=False) + qk_norm = backend.layer_norm(rms_norm=True, for_qk=True) + else: + layer_norm = backend.layer_norm(rms_norm=False, for_qk=False) + qk_norm = backend.layer_norm(rms_norm=False, for_qk=True) + + mlp = _get_mlp_spec(backend, num_experts, moe_grouped_gemm) + + submodules = TransformerLayerSubmodules( + input_layernorm=layer_norm, + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=backend.column_parallel_linear(), + core_attention=backend.core_attention(), + linear_proj=backend.row_parallel_linear(), + q_layernorm=qk_norm if qk_layernorm else IdentityOp, + k_layernorm=qk_norm if qk_layernorm else IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=layer_norm, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + ) + + return ModuleSpec( + module=EngramTransformerLayer, + submodules=submodules, + params={ + "engram_config": engram_config, + "engram_vocab_size_across_layers": vocab_size_across_layers, + }, + ) + + +def _get_mlp_spec( + backend: LocalSpecProvider, + num_experts: Optional[int] = None, + moe_grouped_gemm: Optional[bool] = False, +) -> ModuleSpec: + """Build a dense MLP spec using the backend provider.""" + if num_experts is None: + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=backend.column_parallel_linear(), + linear_fc2=backend.row_parallel_linear(), + ), + ) + else: + from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec_for_backend + return get_mlp_module_spec_for_backend( + backend=backend, + num_experts=num_experts, + moe_grouped_gemm=moe_grouped_gemm, + ) diff --git a/megatron/core/models/engram/engram_model.py b/megatron/core/models/engram/engram_model.py new file mode 100644 index 00000000000..14359c32b29 --- /dev/null +++ b/megatron/core/models/engram/engram_model.py @@ -0,0 +1,161 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +""" +Engram-augmented GPT Model. + +Extends GPTModel to support DeepSeek's Engram n-gram hash embedding mechanism. +Before each forward pass, the model pre-computes n-gram hash embeddings from +input_ids and distributes them to the relevant EngramTransformerLayers. +""" + +from __future__ import annotations + +import logging +from typing import Literal, Optional + +from torch import Tensor + +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_config import TransformerConfig + +from megatron.core.models.engram.engram_module import EngramConfig, NgramHashMapping + +logger = logging.getLogger(__name__) + + +class EngramGPTModel(GPTModel): + """GPT model augmented with Engram n-gram hash embeddings. + + This model extends GPTModel by: + 1. Maintaining a shared NgramHashMapping for deterministic n-gram hashing. + 2. Before each forward pass, pre-computing hash embeddings from input_ids + and distributing them to EngramTransformerLayers in the decoder. + + Args: + engram_config: Configuration for the Engram module. + All other args are forwarded to GPTModel. 
+ """ + + def __init__( + self, + config: TransformerConfig, + transformer_layer_spec: ModuleSpec, + vocab_size: int, + max_sequence_length: int, + engram_config: EngramConfig, + pre_process: bool = True, + post_process: bool = True, + fp16_lm_cross_entropy: bool = False, + parallel_output: bool = True, + share_embeddings_and_output_weights: bool = False, + position_embedding_type: Literal[ + 'learned_absolute', 'rope', 'mrope', 'yarn', 'none' + ] = 'learned_absolute', + rotary_percent: float = 1.0, + rotary_base: int = 10000, + rope_scaling: bool = False, + rope_scaling_factor: float = 8.0, + scatter_embedding_sequence_parallel: bool = True, + seq_len_interpolation_factor: Optional[float] = None, + mtp_block_spec: Optional[ModuleSpec] = None, + pg_collection: Optional[ProcessGroupCollection] = None, + vp_stage: Optional[int] = None, + ) -> None: + super().__init__( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=fp16_lm_cross_entropy, + parallel_output=parallel_output, + share_embeddings_and_output_weights=share_embeddings_and_output_weights, + position_embedding_type=position_embedding_type, + rotary_percent=rotary_percent, + rotary_base=rotary_base, + rope_scaling=rope_scaling, + rope_scaling_factor=rope_scaling_factor, + scatter_embedding_sequence_parallel=scatter_embedding_sequence_parallel, + seq_len_interpolation_factor=seq_len_interpolation_factor, + mtp_block_spec=mtp_block_spec, + pg_collection=pg_collection, + vp_stage=vp_stage, + ) + + self.engram_config = engram_config + + self.ngram_hash_mapping = NgramHashMapping( + engram_vocab_size=engram_config.engram_vocab_size, + max_ngram_size=engram_config.max_ngram_size, + n_embed_per_ngram=engram_config.n_embed_per_ngram, + n_head_per_ngram=engram_config.n_head_per_ngram, + layer_ids=engram_config.engram_layer_ids, + tokenizer_name_or_path=engram_config.tokenizer_name_or_path, + pad_id=engram_config.pad_id, + seed=engram_config.seed, + ) + + def _precompute_engram_hashes(self, input_ids: Tensor) -> None: + """Pre-compute n-gram hash embeddings and distribute to engram layers. + + Args: + input_ids: [B, S] token IDs tensor. + """ + if input_ids is None: + return + + # input_ids is [B, S] — compute hashes for all engram layers at once + hash_ids_all_layers = self.ngram_hash_mapping.hash(input_ids) + + device = next(self.decoder.parameters()).device + + for layer in self.decoder.layers: + if hasattr(layer, 'engram') and layer.engram is not None: + layer_id = layer.engram.layer_id + if layer_id in hash_ids_all_layers: + layer.engram.precompute_embeddings( + hash_ids_all_layers[layer_id], device + ) + + def forward( + self, + input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor, + decoder_input: Tensor = None, + labels: Tensor = None, + inference_context=None, + packed_seq_params=None, + extra_block_kwargs: dict = None, + runtime_gather_output: Optional[bool] = None, + *, + inference_params=None, + loss_mask: Optional[Tensor] = None, + padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """Forward pass with Engram pre-computation. + + Pre-computes n-gram hash embeddings from input_ids before running the + standard GPT forward pass. The embeddings are distributed to each + EngramTransformerLayer and consumed during their forward calls. 
+ """ + if self.pre_process: + self._precompute_engram_hashes(input_ids) + + return super().forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + decoder_input=decoder_input, + labels=labels, + inference_context=inference_context, + packed_seq_params=packed_seq_params, + extra_block_kwargs=extra_block_kwargs, + runtime_gather_output=runtime_gather_output, + inference_params=inference_params, + loss_mask=loss_mask, + padding_mask=padding_mask, + ) diff --git a/megatron/core/models/engram/engram_module.py b/megatron/core/models/engram/engram_module.py new file mode 100644 index 00000000000..aa3f5f39101 --- /dev/null +++ b/megatron/core/models/engram/engram_module.py @@ -0,0 +1,511 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +""" +Engram module for Megatron-LM. + +Ported from DeepSeek's Engram reference implementation. The Engram module augments +transformer layers with n-gram hash-based embeddings that are gated against the +hidden states via a multi-head key-query mechanism. + +The module operates in two phases: + 1. Pre-compute: Hash input_ids into n-gram IDs and look up embeddings (called once + per forward pass at the model level). + 2. Forward: Gate the pre-computed embeddings against hidden states and apply a + short causal convolution (called per layer). +""" + +from __future__ import annotations + +import math +from dataclasses import dataclass, field +from typing import Dict, List, Optional + +import numpy as np +import torch +import torch.nn as nn + + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +@dataclass +class EngramConfig: + """Configuration for the Engram module.""" + + engram_vocab_size: List[int] = field( + default_factory=lambda: [129280 * 5, 129280 * 5] + ) + max_ngram_size: int = 3 + n_embed_per_ngram: int = 512 + n_head_per_ngram: int = 8 + engram_layer_ids: List[int] = field(default_factory=lambda: [1, 15]) + pad_id: int = 2 + seed: int = 0 + kernel_size: int = 4 + hc_mult: int = 4 + tokenizer_name_or_path: str = "deepseek-ai/DeepSeek-V3" + + +# --------------------------------------------------------------------------- +# Compressed Tokenizer +# --------------------------------------------------------------------------- + +class CompressedTokenizer: + """Normalizes tokens into a compressed vocabulary via unicode normalization. + + Reduces the effective vocabulary size by mapping visually/semantically similar + tokens to the same compressed ID. This is used as a preprocessing step before + n-gram hashing. 
+ """ + + def __init__(self, tokenizer_name_or_path: str): + from tokenizers import Regex, normalizers + from transformers import AutoTokenizer + + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name_or_path, trust_remote_code=True + ) + + SENTINEL = "\uE000" + self.normalizer = normalizers.Sequence([ + normalizers.NFKC(), + normalizers.NFD(), + normalizers.StripAccents(), + normalizers.Lowercase(), + normalizers.Replace(Regex(r"[ \t\r\n]+"), " "), + normalizers.Replace(Regex(r"^ $"), SENTINEL), + normalizers.Strip(), + normalizers.Replace(SENTINEL, " "), + ]) + + self.lookup_table, self.num_new_token = self._build_lookup_table() + + def __len__(self): + return self.num_new_token + + def _build_lookup_table(self): + old2new: Dict[int, int] = {} + key2new: Dict[str, int] = {} + new_tokens: List[str] = [] + + vocab_size = len(self.tokenizer) + for tid in range(vocab_size): + text = self.tokenizer.decode([tid], skip_special_tokens=False) + + if "\ufffd" in text: + key = self.tokenizer.convert_ids_to_tokens(tid) + else: + norm = self.normalizer.normalize_str(text) + key = norm if norm else text + + nid = key2new.get(key) + if nid is None: + nid = len(new_tokens) + key2new[key] = nid + new_tokens.append(key) + old2new[tid] = nid + + lookup = np.empty(vocab_size, dtype=np.int64) + for tid in range(vocab_size): + lookup[tid] = old2new[tid] + + return lookup, len(new_tokens) + + def __call__(self, input_ids): + arr = np.asarray(input_ids, dtype=np.int64) + pos_mask = arr >= 0 + out = arr.copy() + valid_ids = arr[pos_mask] + out[pos_mask] = self.lookup_table[valid_ids] + return out + + +# --------------------------------------------------------------------------- +# Primality utilities (avoids sympy dependency) +# --------------------------------------------------------------------------- + +def _is_prime(n: int) -> bool: + if n < 2: + return False + if n < 4: + return True + if n % 2 == 0 or n % 3 == 0: + return False + i = 5 + while i * i <= n: + if n % i == 0 or n % (i + 2) == 0: + return False + i += 6 + return True + + +def _find_next_prime(start: int, seen_primes: set) -> int: + candidate = start + 1 + while True: + if _is_prime(candidate) and candidate not in seen_primes: + return candidate + candidate += 1 + + +# --------------------------------------------------------------------------- +# N-gram Hash Mapping +# --------------------------------------------------------------------------- + +class NgramHashMapping: + """Deterministic n-gram hash mapping for Engram. + + Computes hash IDs for n-grams (bigrams, trigrams, ...) using random multipliers + and modular arithmetic with prime-sized hash tables. Each n-gram level has multiple + heads, each mapping to a distinct prime-sized vocabulary. 
+ """ + + def __init__( + self, + engram_vocab_size: List[int], + max_ngram_size: int, + n_embed_per_ngram: int, + n_head_per_ngram: int, + layer_ids: List[int], + tokenizer_name_or_path: str, + pad_id: int, + seed: int, + ): + self.vocab_size_per_ngram = engram_vocab_size + self.max_ngram_size = max_ngram_size + self.n_embed_per_ngram = n_embed_per_ngram + self.n_head_per_ngram = n_head_per_ngram + self.pad_id = pad_id + self.layer_ids = layer_ids + + self.compressed_tokenizer = CompressedTokenizer( + tokenizer_name_or_path=tokenizer_name_or_path, + ) + self.tokenizer_vocab_size = len(self.compressed_tokenizer) + if self.pad_id is not None: + self.pad_id = int(self.compressed_tokenizer.lookup_table[self.pad_id]) + + max_long = np.iinfo(np.int64).max + M_max = int(max_long // self.tokenizer_vocab_size) + half_bound = max(1, M_max // 2) + PRIME_1 = 10007 + + self.layer_multipliers: Dict[int, np.ndarray] = {} + for layer_id in self.layer_ids: + base_seed = int(seed + PRIME_1 * int(layer_id)) + g = np.random.default_rng(base_seed) + r = g.integers( + low=0, high=half_bound, size=(self.max_ngram_size,), dtype=np.int64 + ) + multipliers = r * 2 + 1 + self.layer_multipliers[layer_id] = multipliers + + self.vocab_size_across_layers = self._calculate_vocab_size_across_layers() + + def _calculate_vocab_size_across_layers(self) -> Dict[int, List[List[int]]]: + seen_primes: set = set() + vocab_size_across_layers: Dict[int, List[List[int]]] = {} + + for layer_id in self.layer_ids: + all_ngram_vocab_sizes: List[List[int]] = [] + for ngram in range(2, self.max_ngram_size + 1): + current_ngram_heads_sizes: List[int] = [] + vocab_size = self.vocab_size_per_ngram[ngram - 2] + current_prime_search_start = vocab_size - 1 + + for _ in range(self.n_head_per_ngram): + found_prime = _find_next_prime( + current_prime_search_start, seen_primes + ) + seen_primes.add(found_prime) + current_ngram_heads_sizes.append(found_prime) + current_prime_search_start = found_prime + + all_ngram_vocab_sizes.append(current_ngram_heads_sizes) + vocab_size_across_layers[layer_id] = all_ngram_vocab_sizes + + return vocab_size_across_layers + + def _get_ngram_hashes( + self, input_ids: np.ndarray, layer_id: int + ) -> np.ndarray: + x = np.asarray(input_ids, dtype=np.int64) + B, T = x.shape + + multipliers = self.layer_multipliers[layer_id] + + def shift_k(k: int) -> np.ndarray: + if k == 0: + return x + shifted = np.pad( + x, ((0, 0), (k, 0)), mode='constant', constant_values=self.pad_id + )[:, :T] + return shifted + + base_shifts = [shift_k(k) for k in range(self.max_ngram_size)] + + all_hashes = [] + for n in range(2, self.max_ngram_size + 1): + n_gram_index = n - 2 + tokens = base_shifts[:n] + mix = tokens[0] * multipliers[0] + for k in range(1, n): + mix = np.bitwise_xor(mix, tokens[k] * multipliers[k]) + + head_vocab_sizes = self.vocab_size_across_layers[layer_id][n_gram_index] + for j in range(self.n_head_per_ngram): + mod = int(head_vocab_sizes[j]) + head_hash = mix % mod + all_hashes.append(head_hash.astype(np.int64, copy=False)) + + return np.stack(all_hashes, axis=2) + + def hash(self, input_ids) -> Dict[int, np.ndarray]: + """Compute n-gram hash IDs for all engram layers. + + Args: + input_ids: Token IDs of shape [B, T] (numpy array or tensor). + + Returns: + Dict mapping layer_id -> hash IDs of shape [B, T, num_total_heads]. 
+ """ + if isinstance(input_ids, torch.Tensor): + input_ids = input_ids.cpu().numpy() + input_ids = self.compressed_tokenizer(input_ids) + hash_ids_for_all_layers = {} + for layer_id in self.layer_ids: + hash_ids_for_all_layers[layer_id] = self._get_ngram_hashes( + input_ids, layer_id=layer_id + ) + return hash_ids_for_all_layers + + +# --------------------------------------------------------------------------- +# Multi-Head Embedding +# --------------------------------------------------------------------------- + +class MultiHeadEmbedding(nn.Module): + """Packs multiple embedding tables (one per hash head) into a single table. + + Uses per-head offsets so that each head's IDs map to a non-overlapping region + of a single large nn.Embedding. + """ + + def __init__(self, list_of_N: List[int], D: int): + super().__init__() + self.num_heads = len(list_of_N) + self.embedding_dim = D + + offsets = [0] + for n in list_of_N[:-1]: + offsets.append(offsets[-1] + n) + + self.register_buffer("offsets", torch.tensor(offsets, dtype=torch.long)) + + total_N = sum(list_of_N) + self.embedding = nn.Embedding(num_embeddings=total_N, embedding_dim=D) + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + """ + Args: + input_ids: [B, T, num_heads] hash IDs. + + Returns: + [B, T, num_heads, D] embeddings. + """ + shifted_input_ids = input_ids + self.offsets + return self.embedding(shifted_input_ids) + + +# --------------------------------------------------------------------------- +# Short Convolution +# --------------------------------------------------------------------------- + +class ShortConv(nn.Module): + """Depthwise causal 1D convolution with per-group RMSNorm. + + Operates on tensors of shape [B, L, HC_MULT, D], applying a grouped depthwise + convolution across the sequence dimension with causal padding. + """ + + def __init__( + self, + hidden_size: int, + kernel_size: int = 4, + dilation: int = 1, + norm_eps: float = 1e-5, + hc_mult: int = 4, + activation: bool = True, + ): + super().__init__() + self.hc_mult = hc_mult + self.activation = activation + + total_channels = hidden_size * hc_mult + self.conv = nn.Conv1d( + in_channels=total_channels, + out_channels=total_channels, + kernel_size=kernel_size, + groups=total_channels, + bias=False, + padding=(kernel_size - 1) * dilation, + dilation=dilation, + ) + + self.norms = nn.ModuleList([ + nn.RMSNorm(hidden_size, eps=norm_eps) for _ in range(hc_mult) + ]) + + if self.activation: + self.act_fn = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: [B, L, HC_MULT, D] + + Returns: + [B, L, HC_MULT, D] + """ + B, T, G, C = x.shape + assert G == self.hc_mult, f"Input groups {G} != hc_mult {self.hc_mult}" + + normed_chunks = [] + for i in range(G): + chunk = x[:, :, i, :] + normed_chunks.append(self.norms[i](chunk)) + + x_norm = torch.cat(normed_chunks, dim=-1) # [B, T, G*C] + x_bct = x_norm.transpose(1, 2) # [B, G*C, T] + y_bct = self.conv(x_bct) + y_bct = y_bct[..., :T] # causal: trim future padding + + if self.activation: + y_bct = self.act_fn(y_bct) + y = y_bct.transpose(1, 2).view(B, T, G, C).contiguous() + + return y + + +# --------------------------------------------------------------------------- +# Engram Module +# --------------------------------------------------------------------------- + +class EngramModule(nn.Module): + """Core Engram module that augments transformer hidden states with n-gram + hash-based embeddings via multi-head gating and causal convolution. 
+ + This module operates in two phases: + 1. ``precompute_embeddings(hash_ids, device)``: Called once per forward pass at + the model level to convert hash IDs into embeddings. + 2. ``forward(hidden_states)``: Called per layer to gate embeddings against + hidden states. + + The hidden states are internally expanded to [B, S, HC_MULT, D] for multi-slot + gating, then collapsed back by averaging across the HC dimension. + """ + + def __init__( + self, + layer_id: int, + hidden_size: int, + engram_config: EngramConfig, + vocab_size_for_layer: List[List[int]], + ): + super().__init__() + self.layer_id = layer_id + self.hidden_size = hidden_size + self.hc_mult = engram_config.hc_mult + self.max_ngram_size = engram_config.max_ngram_size + self.n_embed_per_ngram = engram_config.n_embed_per_ngram + self.n_head_per_ngram = engram_config.n_head_per_ngram + + head_vocab_sizes = [x for ngram_sizes in vocab_size_for_layer for x in ngram_sizes] + per_head_dim = engram_config.n_embed_per_ngram // engram_config.n_head_per_ngram + + self.multi_head_embedding = MultiHeadEmbedding( + list_of_N=head_vocab_sizes, + D=per_head_dim, + ) + + self.short_conv = ShortConv( + hidden_size=hidden_size, + kernel_size=engram_config.kernel_size, + dilation=engram_config.max_ngram_size, + hc_mult=engram_config.hc_mult, + ) + + engram_hidden_size = (engram_config.max_ngram_size - 1) * engram_config.n_embed_per_ngram + self.value_proj = nn.Linear(engram_hidden_size, hidden_size) + self.key_projs = nn.ModuleList([ + nn.Linear(engram_hidden_size, hidden_size) + for _ in range(engram_config.hc_mult) + ]) + self.norm1 = nn.ModuleList([ + nn.RMSNorm(hidden_size) for _ in range(engram_config.hc_mult) + ]) + self.norm2 = nn.ModuleList([ + nn.RMSNorm(hidden_size) for _ in range(engram_config.hc_mult) + ]) + + self._cached_embeddings: Optional[torch.Tensor] = None + + def precompute_embeddings( + self, hash_ids: np.ndarray, device: torch.device + ) -> None: + """Pre-compute embeddings from hash IDs. Called once per forward pass. + + Args: + hash_ids: [B, T, num_heads] numpy array of hash IDs for this layer. + device: Target device for the embedding tensor. + """ + hash_tensor = torch.from_numpy(hash_ids).to(device) + embeddings = self.multi_head_embedding(hash_tensor) # [B, T, num_heads, D_head] + self._cached_embeddings = embeddings.flatten(start_dim=-2) # [B, T, engram_hidden] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Apply engram gating to hidden states using pre-computed embeddings. + + Args: + hidden_states: [S, B, H] tensor in Megatron's sequence-first format. + + Returns: + [S, B, H] engram output to be added as a residual. 
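+
+        Per hyper-connection slot c, the gating below amounts to (a sketch of the
+        code, with e the pre-computed engram embedding, h the hidden state and H
+        the hidden size):
+
+            s_c = <RMSNorm(K_c e), RMSNorm(h)> / sqrt(H)
+            g_c = sigmoid(sign(s_c) * sqrt(|s_c|))        # |s_c| floored at 1e-6
+            v_c = g_c * (V e)
+            out = mean_c(v_c + ShortConv(v)_c)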
+ """ + assert self._cached_embeddings is not None, ( + "Must call precompute_embeddings() before forward()" + ) + embeddings = self._cached_embeddings # [B, T, engram_hidden] + + # Megatron uses [S, B, H]; convert to [B, S, H] for engram processing + hidden_states_bsh = hidden_states.transpose(0, 1).contiguous() + + # Expand to HC_MULT slots: [B, S, HC_MULT, H] + hidden_states_hc = hidden_states_bsh.unsqueeze(2).expand( + -1, -1, self.hc_mult, -1 + ) + + # Compute gates per HC slot + gates = [] + for hc_idx in range(self.hc_mult): + key = self.key_projs[hc_idx](embeddings) + normed_key = self.norm1[hc_idx](key) + query = hidden_states_hc[:, :, hc_idx, :] + normed_query = self.norm2[hc_idx](query) + gate = (normed_key * normed_query).sum(dim=-1) / math.sqrt(self.hidden_size) + gate = gate.abs().clamp_min(1e-6).sqrt() * gate.sign() + gate = gate.sigmoid().unsqueeze(-1) + gates.append(gate) + gates = torch.stack(gates, dim=2) # [B, S, HC_MULT, 1] + + # Gated value with short convolution + value = gates * self.value_proj(embeddings).unsqueeze(2) # [B, S, HC_MULT, H] + output = value + self.short_conv(value) # [B, S, HC_MULT, H] + + # Collapse HC dimension by averaging, convert back to [S, B, H] + output = output.mean(dim=2) # [B, S, H] + output = output.transpose(0, 1).contiguous() # [S, B, H] + + self._cached_embeddings = None + + return output diff --git a/pretrain_engram.py b/pretrain_engram.py new file mode 100644 index 00000000000..74101de17b7 --- /dev/null +++ b/pretrain_engram.py @@ -0,0 +1,295 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +"""Pretrain Engram-augmented GPT model.""" + +import time +_PROGRAM_START_TIME = time.time() + +import json +import os +import warnings + +rank = int(os.environ.get('RANK', 0)) +if rank != 0: + warnings.filterwarnings("ignore", category=UserWarning) + warnings.filterwarnings("ignore", category=FutureWarning) + +from functools import partial +from typing import List, Optional, Tuple + +import torch + +from engram_builders import engram_builder +from megatron.core import mpu +from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder +from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset +from megatron.core.enums import ModelType +from megatron.core.models.engram import EngramGPTModel +from megatron.core.rerun_state_machine import get_rerun_state_machine +from megatron.core.tokenizers.utils.build_tokenizer import build_tokenizer +from megatron.core.utils import get_attr_wrapped_model, StragglerDetector +from megatron.training import ( + get_args, + get_timers, + inprocess_restart, + pretrain, + print_rank_0, + set_startup_timestamps, +) +from megatron.training.utils import ( + get_batch_on_this_cp_rank, + get_batch_on_this_tp_rank, + get_blend_and_blend_per_split, + is_first_or_last_pipeline_stage, +) +from model_provider import model_provider + + +stimer = StragglerDetector() + + +def get_batch(data_iterator, vp_stage=None): + """Generate a batch.""" + empty_batch = { + 'tokens': None, + 'labels': None, + 'loss_mask': None, + 'attention_mask': None, + 'position_ids': None, + } + + if not is_first_or_last_pipeline_stage(vp_stage): + return empty_batch.values() + + batch = get_batch_on_this_tp_rank(data_iterator) + + if mpu.is_pipeline_first_stage(ignore_virtual=(vp_stage is None), vp_stage=vp_stage): + total_tokens = batch['tokens'].size(1) + elif mpu.is_pipeline_last_stage(ignore_virtual=(vp_stage is None), vp_stage=vp_stage): + total_tokens = 
batch['labels'].size(1) + else: + return empty_batch.values() + + batch = get_batch_on_this_cp_rank(batch) + + return batch.values() + + +SPIKY_LOSS_FACTOR = 10 + + +def loss_func( + loss_mask: torch.Tensor, + output_tensor: torch.Tensor, + model: Optional[EngramGPTModel] = None, +): + """Loss function for Engram GPT training.""" + args = get_args() + + losses = output_tensor.view(-1).float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses * loss_mask) + + num_tokens = loss_mask.sum().clone().detach().to(torch.int) + report = {'lm loss': torch.cat([loss.clone().detach().view(1), num_tokens.view(1)])} + + rerun_state_machine = get_rerun_state_machine() + if args.check_for_nan_in_loss_and_grad: + rerun_state_machine.validate_result( + result=loss, + rejection_func=torch.isnan, + message="found NaN in local forward loss calculation", + tolerance=0.0, + fatal=True, + ) + rerun_state_machine.validate_result( + result=loss, + rejection_func=torch.isinf, + message="found Inf in local forward loss calculation", + tolerance=0.0, + fatal=True, + ) + if args.check_for_spiky_loss: + rerun_state_machine.validate_result( + result=loss, + rejection_func=partial( + rerun_state_machine.is_unexpectedly_large, + threshold=SPIKY_LOSS_FACTOR, + context="loss", + ), + message="Spiky loss", + tolerance=0.0, + fatal=False, + ) + + return loss, num_tokens, report + + +def forward_step(data_iterator, model: EngramGPTModel): + """Forward training step.""" + timers = get_timers() + + timers('batch-generator', log_level=2).start() + + global stimer + + with stimer(bdata=True): + vp_stage = get_attr_wrapped_model(model, "vp_stage") + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + data_iterator, vp_stage + ) + + timers('batch-generator').stop() + + with stimer: + output_tensor = model( + tokens, + position_ids, + attention_mask, + labels=labels, + loss_mask=loss_mask, + ) + + return output_tensor, partial(loss_func, loss_mask, model=model) + + +def is_dataset_built_on_rank(vp_stage=None): + if mpu.get_tensor_model_parallel_rank() != 0: + return False + return is_first_or_last_pipeline_stage(vp_stage) + + +def core_gpt_dataset_config_from_args(args): + tokenizer = build_tokenizer(args) + + blend, blend_per_split = get_blend_and_blend_per_split(args) + + sequences_per_dataset = None + if args.per_dataset_sequences_path is not None: + with open(args.per_dataset_sequences_path, "r") as f: + sequences_per_dataset = json.load(f) + + return GPTDatasetConfig( + random_seed=args.seed, + sequence_length=args.seq_length, + blend=blend, + blend_per_split=blend_per_split, + split=args.split, + num_dataset_builder_threads=args.num_dataset_builder_threads, + path_to_cache=args.data_cache_path, + mmap_bin_files=args.mmap_bin_files, + tokenizer=tokenizer, + reset_position_ids=args.reset_position_ids, + reset_attention_mask=args.reset_attention_mask, + eod_mask_loss=args.eod_mask_loss, + create_attention_mask=args.create_attention_mask_in_dataloader, + object_storage_cache_path=args.object_storage_cache_path, + mid_level_dataset_surplus=args.mid_level_dataset_surplus, + allow_ambiguous_pad_tokens=args.allow_ambiguous_pad_tokens, + fast_cache_load=args.dataloader_fast_cache_load, + sequences_per_dataset=sequences_per_dataset, + defer_npy_index_mmap=args.dataloader_defer_npy_index_mmap, + context_parallel_size=args.context_parallel_size, + ) + + +def train_valid_test_datasets_provider(train_val_test_num_samples, vp_stage=None): + """Build the train, validation, and test datasets.""" + args = 
get_args() + config = core_gpt_dataset_config_from_args(args) + + if args.mock_data: + dataset_type = MockGPTDataset + else: + dataset_type = GPTDataset + + print_rank_0("> building train, validation, and test datasets for Engram GPT ...") + + train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder( + dataset_type, + train_val_test_num_samples, + partial(is_dataset_built_on_rank, vp_stage=vp_stage), + config, + ).build() + + print_rank_0("> finished creating Engram GPT datasets ...") + + return train_ds, valid_ds, test_ds + + +def add_engram_args(parser): + """Add Engram-specific command line arguments.""" + group = parser.add_argument_group(title='Engram') + group.add_argument( + '--engram-layer-ids', + type=str, + default='1,15', + help='Comma-separated list of 1-based layer IDs that get Engram modules.', + ) + group.add_argument( + '--engram-max-ngram-size', + type=int, + default=3, + help='Maximum n-gram size for Engram hashing.', + ) + group.add_argument( + '--engram-n-embed-per-ngram', + type=int, + default=512, + help='Embedding dimension per n-gram level in Engram.', + ) + group.add_argument( + '--engram-n-head-per-ngram', + type=int, + default=8, + help='Number of hash heads per n-gram level.', + ) + group.add_argument( + '--engram-kernel-size', + type=int, + default=4, + help='Kernel size for Engram short convolution.', + ) + group.add_argument( + '--engram-hc-mult', + type=int, + default=4, + help='Hyper-connection multiplier for Engram gating.', + ) + group.add_argument( + '--engram-pad-id', + type=int, + default=2, + help='Pad token ID for Engram hash computation.', + ) + group.add_argument( + '--engram-seed', + type=int, + default=0, + help='Random seed for Engram hash multiplier generation.', + ) + group.add_argument( + '--engram-tokenizer', + type=str, + default='deepseek-ai/DeepSeek-V3', + help='Tokenizer name/path for Engram compressed tokenizer.', + ) + return parser + + +if __name__ == "__main__": + _MAIN_ENTRY_TIME = time.time() + set_startup_timestamps(program_start=_PROGRAM_START_TIME, main_entry=_MAIN_ENTRY_TIME) + + train_valid_test_datasets_provider.is_distributed = True + + pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain) + + pretrain( + train_valid_test_datasets_provider, + partial(model_provider, engram_builder), + ModelType.encoder_or_decoder, + forward_step, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + store=store, + extra_args_provider=add_engram_args, + ) From db7bd21d9715f7b5d0117a7d0d565ed52d491c8e Mon Sep 17 00:00:00 2001 From: ilml Date: Thu, 5 Mar 2026 02:32:41 +0000 Subject: [PATCH 2/2] add unit test --- .../core/models/engram/test_engram_shapes.py | 191 ++++++++++++ pretrain_engram.py | 295 ------------------ 2 files changed, 191 insertions(+), 295 deletions(-) create mode 100644 megatron/core/models/engram/test_engram_shapes.py delete mode 100644 pretrain_engram.py diff --git a/megatron/core/models/engram/test_engram_shapes.py b/megatron/core/models/engram/test_engram_shapes.py new file mode 100644 index 00000000000..51c5819b234 --- /dev/null +++ b/megatron/core/models/engram/test_engram_shapes.py @@ -0,0 +1,191 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +""" +Unit tests for Engram module shapes. + +Verifies that all Engram components can be instantiated and produce correct +output shapes. Tests 1-3 use synthetic data (no tokenizer needed). The +end-to-end test requires the HuggingFace tokenizer from EngramConfig. 
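+
+The synthetic-data tests can be run on their own by deselecting the end-to-end
+class, e.g. ``pytest megatron/core/models/engram/test_engram_shapes.py -k "not EndToEnd"``.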
+ +Usage: + pytest megatron/core/models/engram/test_engram_shapes.py -v +""" + +import numpy as np +import pytest +import torch + +from megatron.core.models.engram.engram_module import ( + EngramConfig, + EngramModule, + MultiHeadEmbedding, + ShortConv, +) + +B, S = 2, 16 +HIDDEN_SIZE = 128 +HC_MULT = 4 + + +def _make_engram_config(**overrides) -> EngramConfig: + defaults = dict( + engram_vocab_size=[1000, 1000], + max_ngram_size=3, + n_embed_per_ngram=512, + n_head_per_ngram=8, + engram_layer_ids=[1], + pad_id=2, + seed=0, + kernel_size=4, + hc_mult=HC_MULT, + ) + defaults.update(overrides) + return EngramConfig(**defaults) + + +def _make_vocab_sizes(cfg: EngramConfig): + """Build synthetic per-head vocab sizes (no tokenizer needed).""" + num_ngram_levels = cfg.max_ngram_size - 1 + sizes = [] + base = 1009 + for _ in range(num_ngram_levels): + head_sizes = [base + h * 10 for h in range(cfg.n_head_per_ngram)] + sizes.append(head_sizes) + base += cfg.n_head_per_ngram * 10 + return sizes + + +class TestShortConv: + def test_output_shape(self): + conv = ShortConv( + hidden_size=HIDDEN_SIZE, kernel_size=4, dilation=3, hc_mult=HC_MULT, + ) + x = torch.randn(B, S, HC_MULT, HIDDEN_SIZE) + y = conv(x) + assert y.shape == (B, S, HC_MULT, HIDDEN_SIZE) + + def test_causal_padding(self): + """Output length must equal input length regardless of kernel/dilation.""" + for kernel, dilation in [(3, 1), (4, 3), (7, 2)]: + conv = ShortConv( + hidden_size=HIDDEN_SIZE, + kernel_size=kernel, + dilation=dilation, + hc_mult=HC_MULT, + ) + x = torch.randn(B, S, HC_MULT, HIDDEN_SIZE) + y = conv(x) + assert y.shape[1] == S, f"kernel={kernel}, dilation={dilation}" + + +class TestMultiHeadEmbedding: + def test_output_shape(self): + list_of_N = [1009, 1013, 1019, 1021, 1031, 1033, 1039, 1049, + 1051, 1061, 1063, 1069, 1087, 1091, 1093, 1097] + D = 64 + num_heads = len(list_of_N) + emb = MultiHeadEmbedding(list_of_N=list_of_N, D=D) + + ids = torch.stack( + [torch.randint(0, n, (B, S)) for n in list_of_N], dim=2 + ) + out = emb(ids) + assert out.shape == (B, S, num_heads, D) + + def test_offset_isolation(self): + """Each head's IDs should index a non-overlapping region.""" + list_of_N = [100, 200, 300] + emb = MultiHeadEmbedding(list_of_N=list_of_N, D=8) + assert emb.offsets.tolist() == [0, 100, 300] + + +class TestEngramModule: + def _build_module(self, cfg=None): + cfg = cfg or _make_engram_config() + vocab_sizes = _make_vocab_sizes(cfg) + return EngramModule( + layer_id=cfg.engram_layer_ids[0], + hidden_size=HIDDEN_SIZE, + engram_config=cfg, + vocab_size_for_layer=vocab_sizes, + ), cfg + + def test_output_shape(self): + module, cfg = self._build_module() + num_heads = (cfg.max_ngram_size - 1) * cfg.n_head_per_ngram + hash_ids = np.random.randint( + 0, 100, size=(B, S, num_heads), + ).astype(np.int64) + + module.precompute_embeddings(hash_ids, device=torch.device("cpu")) + hidden = torch.randn(S, B, HIDDEN_SIZE) + out = module(hidden) + assert out.shape == (S, B, HIDDEN_SIZE) + + def test_cache_cleared_after_forward(self): + module, cfg = self._build_module() + num_heads = (cfg.max_ngram_size - 1) * cfg.n_head_per_ngram + hash_ids = np.random.randint( + 0, 100, size=(B, S, num_heads), + ).astype(np.int64) + + module.precompute_embeddings(hash_ids, device=torch.device("cpu")) + module(torch.randn(S, B, HIDDEN_SIZE)) + assert module._cached_embeddings is None + + def test_forward_without_precompute_raises(self): + module, _ = self._build_module() + with pytest.raises(AssertionError, match="precompute_embeddings"): + 
module(torch.randn(S, B, HIDDEN_SIZE)) + + +class TestEndToEnd: + """End-to-end test through NgramHashMapping. Requires tokenizer.""" + + @pytest.fixture() + def mapping_and_cfg(self): + from megatron.core.models.engram.engram_module import NgramHashMapping + + cfg = _make_engram_config(engram_layer_ids=[1, 15]) + try: + mapping = NgramHashMapping( + engram_vocab_size=cfg.engram_vocab_size, + max_ngram_size=cfg.max_ngram_size, + n_embed_per_ngram=cfg.n_embed_per_ngram, + n_head_per_ngram=cfg.n_head_per_ngram, + layer_ids=cfg.engram_layer_ids, + tokenizer_name_or_path=cfg.tokenizer_name_or_path, + pad_id=cfg.pad_id, + seed=cfg.seed, + ) + except Exception as e: + pytest.skip(f"tokenizer not available: {e}") + return mapping, cfg + + def test_hash_shapes(self, mapping_and_cfg): + mapping, cfg = mapping_and_cfg + num_heads = (cfg.max_ngram_size - 1) * cfg.n_head_per_ngram + fake_ids = np.random.randint(0, 1000, size=(B, S)) + result = mapping.hash(fake_ids) + for layer_id in cfg.engram_layer_ids: + assert result[layer_id].shape == (B, S, num_heads) + + def test_full_forward(self, mapping_and_cfg): + mapping, cfg = mapping_and_cfg + fake_ids = np.random.randint(0, 1000, size=(B, S)) + hash_all = mapping.hash(fake_ids) + + hidden = torch.randn(S, B, HIDDEN_SIZE) + for layer_id in cfg.engram_layer_ids: + mod = EngramModule( + layer_id=layer_id, + hidden_size=HIDDEN_SIZE, + engram_config=cfg, + vocab_size_for_layer=mapping.vocab_size_across_layers[layer_id], + ) + mod.precompute_embeddings(hash_all[layer_id], device=torch.device("cpu")) + out = mod(hidden) + assert out.shape == (S, B, HIDDEN_SIZE) + hidden = out + hidden + + assert hidden.shape == (S, B, HIDDEN_SIZE) diff --git a/pretrain_engram.py b/pretrain_engram.py deleted file mode 100644 index 74101de17b7..00000000000 --- a/pretrain_engram.py +++ /dev/null @@ -1,295 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
-"""Pretrain Engram-augmented GPT model.""" - -import time -_PROGRAM_START_TIME = time.time() - -import json -import os -import warnings - -rank = int(os.environ.get('RANK', 0)) -if rank != 0: - warnings.filterwarnings("ignore", category=UserWarning) - warnings.filterwarnings("ignore", category=FutureWarning) - -from functools import partial -from typing import List, Optional, Tuple - -import torch - -from engram_builders import engram_builder -from megatron.core import mpu -from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder -from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset -from megatron.core.enums import ModelType -from megatron.core.models.engram import EngramGPTModel -from megatron.core.rerun_state_machine import get_rerun_state_machine -from megatron.core.tokenizers.utils.build_tokenizer import build_tokenizer -from megatron.core.utils import get_attr_wrapped_model, StragglerDetector -from megatron.training import ( - get_args, - get_timers, - inprocess_restart, - pretrain, - print_rank_0, - set_startup_timestamps, -) -from megatron.training.utils import ( - get_batch_on_this_cp_rank, - get_batch_on_this_tp_rank, - get_blend_and_blend_per_split, - is_first_or_last_pipeline_stage, -) -from model_provider import model_provider - - -stimer = StragglerDetector() - - -def get_batch(data_iterator, vp_stage=None): - """Generate a batch.""" - empty_batch = { - 'tokens': None, - 'labels': None, - 'loss_mask': None, - 'attention_mask': None, - 'position_ids': None, - } - - if not is_first_or_last_pipeline_stage(vp_stage): - return empty_batch.values() - - batch = get_batch_on_this_tp_rank(data_iterator) - - if mpu.is_pipeline_first_stage(ignore_virtual=(vp_stage is None), vp_stage=vp_stage): - total_tokens = batch['tokens'].size(1) - elif mpu.is_pipeline_last_stage(ignore_virtual=(vp_stage is None), vp_stage=vp_stage): - total_tokens = batch['labels'].size(1) - else: - return empty_batch.values() - - batch = get_batch_on_this_cp_rank(batch) - - return batch.values() - - -SPIKY_LOSS_FACTOR = 10 - - -def loss_func( - loss_mask: torch.Tensor, - output_tensor: torch.Tensor, - model: Optional[EngramGPTModel] = None, -): - """Loss function for Engram GPT training.""" - args = get_args() - - losses = output_tensor.view(-1).float() - loss_mask = loss_mask.view(-1).float() - loss = torch.sum(losses * loss_mask) - - num_tokens = loss_mask.sum().clone().detach().to(torch.int) - report = {'lm loss': torch.cat([loss.clone().detach().view(1), num_tokens.view(1)])} - - rerun_state_machine = get_rerun_state_machine() - if args.check_for_nan_in_loss_and_grad: - rerun_state_machine.validate_result( - result=loss, - rejection_func=torch.isnan, - message="found NaN in local forward loss calculation", - tolerance=0.0, - fatal=True, - ) - rerun_state_machine.validate_result( - result=loss, - rejection_func=torch.isinf, - message="found Inf in local forward loss calculation", - tolerance=0.0, - fatal=True, - ) - if args.check_for_spiky_loss: - rerun_state_machine.validate_result( - result=loss, - rejection_func=partial( - rerun_state_machine.is_unexpectedly_large, - threshold=SPIKY_LOSS_FACTOR, - context="loss", - ), - message="Spiky loss", - tolerance=0.0, - fatal=False, - ) - - return loss, num_tokens, report - - -def forward_step(data_iterator, model: EngramGPTModel): - """Forward training step.""" - timers = get_timers() - - timers('batch-generator', log_level=2).start() - - global stimer - - with stimer(bdata=True): - 
vp_stage = get_attr_wrapped_model(model, "vp_stage") - tokens, labels, loss_mask, attention_mask, position_ids = get_batch( - data_iterator, vp_stage - ) - - timers('batch-generator').stop() - - with stimer: - output_tensor = model( - tokens, - position_ids, - attention_mask, - labels=labels, - loss_mask=loss_mask, - ) - - return output_tensor, partial(loss_func, loss_mask, model=model) - - -def is_dataset_built_on_rank(vp_stage=None): - if mpu.get_tensor_model_parallel_rank() != 0: - return False - return is_first_or_last_pipeline_stage(vp_stage) - - -def core_gpt_dataset_config_from_args(args): - tokenizer = build_tokenizer(args) - - blend, blend_per_split = get_blend_and_blend_per_split(args) - - sequences_per_dataset = None - if args.per_dataset_sequences_path is not None: - with open(args.per_dataset_sequences_path, "r") as f: - sequences_per_dataset = json.load(f) - - return GPTDatasetConfig( - random_seed=args.seed, - sequence_length=args.seq_length, - blend=blend, - blend_per_split=blend_per_split, - split=args.split, - num_dataset_builder_threads=args.num_dataset_builder_threads, - path_to_cache=args.data_cache_path, - mmap_bin_files=args.mmap_bin_files, - tokenizer=tokenizer, - reset_position_ids=args.reset_position_ids, - reset_attention_mask=args.reset_attention_mask, - eod_mask_loss=args.eod_mask_loss, - create_attention_mask=args.create_attention_mask_in_dataloader, - object_storage_cache_path=args.object_storage_cache_path, - mid_level_dataset_surplus=args.mid_level_dataset_surplus, - allow_ambiguous_pad_tokens=args.allow_ambiguous_pad_tokens, - fast_cache_load=args.dataloader_fast_cache_load, - sequences_per_dataset=sequences_per_dataset, - defer_npy_index_mmap=args.dataloader_defer_npy_index_mmap, - context_parallel_size=args.context_parallel_size, - ) - - -def train_valid_test_datasets_provider(train_val_test_num_samples, vp_stage=None): - """Build the train, validation, and test datasets.""" - args = get_args() - config = core_gpt_dataset_config_from_args(args) - - if args.mock_data: - dataset_type = MockGPTDataset - else: - dataset_type = GPTDataset - - print_rank_0("> building train, validation, and test datasets for Engram GPT ...") - - train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder( - dataset_type, - train_val_test_num_samples, - partial(is_dataset_built_on_rank, vp_stage=vp_stage), - config, - ).build() - - print_rank_0("> finished creating Engram GPT datasets ...") - - return train_ds, valid_ds, test_ds - - -def add_engram_args(parser): - """Add Engram-specific command line arguments.""" - group = parser.add_argument_group(title='Engram') - group.add_argument( - '--engram-layer-ids', - type=str, - default='1,15', - help='Comma-separated list of 1-based layer IDs that get Engram modules.', - ) - group.add_argument( - '--engram-max-ngram-size', - type=int, - default=3, - help='Maximum n-gram size for Engram hashing.', - ) - group.add_argument( - '--engram-n-embed-per-ngram', - type=int, - default=512, - help='Embedding dimension per n-gram level in Engram.', - ) - group.add_argument( - '--engram-n-head-per-ngram', - type=int, - default=8, - help='Number of hash heads per n-gram level.', - ) - group.add_argument( - '--engram-kernel-size', - type=int, - default=4, - help='Kernel size for Engram short convolution.', - ) - group.add_argument( - '--engram-hc-mult', - type=int, - default=4, - help='Hyper-connection multiplier for Engram gating.', - ) - group.add_argument( - '--engram-pad-id', - type=int, - default=2, - help='Pad token ID for Engram hash 
computation.', - ) - group.add_argument( - '--engram-seed', - type=int, - default=0, - help='Random seed for Engram hash multiplier generation.', - ) - group.add_argument( - '--engram-tokenizer', - type=str, - default='deepseek-ai/DeepSeek-V3', - help='Tokenizer name/path for Engram compressed tokenizer.', - ) - return parser - - -if __name__ == "__main__": - _MAIN_ENTRY_TIME = time.time() - set_startup_timestamps(program_start=_PROGRAM_START_TIME, main_entry=_MAIN_ENTRY_TIME) - - train_valid_test_datasets_provider.is_distributed = True - - pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain) - - pretrain( - train_valid_test_datasets_provider, - partial(model_provider, engram_builder), - ModelType.encoder_or_decoder, - forward_step, - args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, - store=store, - extra_args_provider=add_engram_args, - )