Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,6 @@
"editor.rulers": [
120
],
"autoDocstring.docstringFormat": "google-notypes"
"autoDocstring.docstringFormat": "google-notypes",
"search.exclude": { "**/logs/**": true },
}
102 changes: 89 additions & 13 deletions bionemo-recipes/models/esm2/modeling_esm_te.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
Adapted from `modeling_esm.py` in huggingface/transformers.
"""

from contextlib import nullcontext
from typing import ClassVar, Literal, Optional, Unpack

# TODO: put import guard around transformer_engine here, with an informative error message around
# installation and the nvidia docker container.
import torch
import torch.cuda.nvtx as nvtx
import transformer_engine.common.recipe
import transformer_engine.pytorch
from torch import nn
from torch.nn import CrossEntropyLoss
Expand Down Expand Up @@ -54,6 +57,15 @@
"AutoModelForTokenClassification": "esm_nv.NVEsmForTokenClassification",
}

# From https://github.com/NVIDIA/TransformerEngine/blob/3ceb248e01a2c0dc1215fe0f46ebc235f289ba0d/transformer_engine/common/recipe/__init__.py#L86
# Recipe *classes* (not instances) — presumably used for isinstance() checks on a
# user-supplied recipe object; verify against callers.
FP8_RECIPES = (
transformer_engine.common.recipe.MXFP8BlockScaling,
transformer_engine.common.recipe.DelayedScaling,
transformer_engine.common.recipe.Float8CurrentScaling,
transformer_engine.common.recipe.Float8BlockScaling,
)
# NOTE(review): unlike FP8_RECIPES this is a single class, not a tuple, despite the
# plural name. isinstance() accepts either form, but the asymmetry is easy to misread.
FP4_RECIPES = transformer_engine.common.recipe.NVFP4BlockScaling


class NVEsmConfig(EsmConfig):
"""NVEsmConfig is a configuration for the NVEsm model."""
Expand Down Expand Up @@ -164,6 +176,9 @@ def _init_method(x):
for i in range(config.num_hidden_layers)
]
)
self._fp8_recipe: object | None = None
self._fp4_recipe: object | None = None
self._layer_precision: dict[int, str | None] | None = None
self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm(
config.hidden_size,
eps=config.layer_norm_eps,
Expand All @@ -173,6 +188,61 @@ def _init_method(x):
if config.position_embedding_type == "rotary":
self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)

def initialize_quantization(
self,
fp8_layers: list[int] | None,
fp4_layers: list[int] | None,
fp8_recipe: object | None = None,
fp4_recipe: object | None = None,
) -> None:
"""Build the per-layer quantization precision map.

Must be called after model creation and sharding (FSDP/DDP/mFSDP) but before training.
Each layer is tagged as ``"fp8"``, ``"fp4"``, or ``None`` (BF16 fallback). The recipe
objects are stored once on the encoder rather than duplicated per-layer, ensuring the
map is trivially pickleable.

Args:
fp8_layers: 0-indexed layer numbers to run in FP8, or None.
fp4_layers: 0-indexed layer numbers to run in FP4, or None.
fp8_recipe: The FP8 recipe instance (e.g., MXFP8BlockScaling), or None.
fp4_recipe: The FP4 recipe instance (e.g., NVFP4BlockScaling), or None.
"""
self._fp8_recipe = fp8_recipe
self._fp4_recipe = fp4_recipe
fp8_layers_set = set(fp8_layers) if fp8_layers else set()
fp4_layers_set = set(fp4_layers) if fp4_layers else set()
self._layer_precision = {}
for layer_number in range(len(self.layers)):
if layer_number in fp8_layers_set:
self._layer_precision[layer_number] = "fp8"
elif layer_number in fp4_layers_set:
self._layer_precision[layer_number] = "fp4"
else:
self._layer_precision[layer_number] = None

def get_layer_autocast(self, layer_number: int):
    """Return the appropriate TE autocast context manager for a given layer.

    The context interacts with the outer FP8 autocast in the training script:
    - FP8 layer: nullcontext() -- lets the outer FP8 autocast take effect.
    - FP4 layer: te.pytorch.autocast(enabled=True, recipe=fp4_recipe) -- overrides to FP4.
    - BF16 layer: te.pytorch.autocast(enabled=False) -- disables quantized compute.

    Args:
        layer_number: The 0-indexed layer number.

    Returns:
        A context manager for the layer's quantization mode.
    """
    # Before initialize_quantization() runs, the map is None, so every layer
    # falls through to the BF16 (quantization disabled) branch.
    precision_map = self._layer_precision if self._layer_precision is not None else {}
    mode = precision_map.get(layer_number)
    if mode == "fp4":
        return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp4_recipe)
    if mode == "fp8":
        # No per-layer override: the training script's outer FP8 autocast applies.
        return nullcontext()
    return transformer_engine.pytorch.autocast(enabled=False)

def forward(
self,
hidden_states: torch.Tensor,
Expand All @@ -198,24 +268,30 @@ def forward(
te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings)
te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True)

for layer_module in self.layers:
# Per-layer quantization context (FP8, FP4, or BF16) is determined by get_layer_autocast().
for layer_number, layer_module in enumerate(self.layers):
if kwargs.get("output_hidden_states", False):
all_hidden_states = (*all_hidden_states, hidden_states)

hidden_states = layer_module(
hidden_states,
attention_mask,
rotary_pos_emb=te_rope_emb,
cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
max_seqlen_q=kwargs.get("max_length_q", None),
max_seqlen_kv=kwargs.get("max_length_k", None),
pad_between_seqs=kwargs.get("pad_between_seqs", None),
)
nvtx.range_push(f"encoder_layer_{layer_number}")
with self.get_layer_autocast(layer_number):
hidden_states = layer_module(
hidden_states,
attention_mask,
rotary_pos_emb=te_rope_emb,
cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
max_seqlen_q=kwargs.get("max_length_q", None),
max_seqlen_kv=kwargs.get("max_length_k", None),
pad_between_seqs=kwargs.get("pad_between_seqs", None),
)
nvtx.range_pop() # encoder_layer_N

nvtx.range_push("emb_layer_norm_after")
hidden_states = self.emb_layer_norm_after(hidden_states)
nvtx.range_pop() # emb_layer_norm_after

if kwargs.get("output_hidden_states", False):
all_hidden_states = (*all_hidden_states, hidden_states)
Expand Down
242 changes: 242 additions & 0 deletions bionemo-recipes/models/esm2/tests/test_layer_quantization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for NVEsmEncoder.initialize_quantization and get_layer_autocast."""

from contextlib import nullcontext
from unittest.mock import patch

import pytest
import transformer_engine.common.recipe
import transformer_engine.pytorch

from modeling_esm_te import NVEsmConfig, NVEsmEncoder


@pytest.fixture
def encoder():
    """Provide a small six-layer NVEsmEncoder instance for the tests below."""
    # Keep the model tiny (ESM2-8M-like dimensions) so construction stays fast.
    model_kwargs = dict(
        hidden_size=320,
        intermediate_size=1280,
        num_hidden_layers=6,
        num_attention_heads=20,
        max_position_embeddings=1026,
    )
    return NVEsmEncoder(NVEsmConfig(**model_kwargs))


class TestInitializeQuantization:
    """Tests for NVEsmEncoder.initialize_quantization."""

    def test_all_fp8(self, encoder):
        recipe = transformer_engine.common.recipe.DelayedScaling()
        encoder.initialize_quantization(
            fp8_layers=[0, 1, 2, 3, 4, 5], fp4_layers=None, fp8_recipe=recipe, fp4_recipe=None
        )
        assert encoder._fp8_recipe is recipe
        assert encoder._fp4_recipe is None
        for layer in range(6):
            assert encoder._layer_precision[layer] == "fp8"

    def test_all_fp4(self, encoder):
        recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
        encoder.initialize_quantization(
            fp8_layers=None, fp4_layers=[0, 1, 2, 3, 4, 5], fp8_recipe=None, fp4_recipe=recipe
        )
        assert encoder._fp8_recipe is None
        assert encoder._fp4_recipe is recipe
        for layer in range(6):
            assert encoder._layer_precision[layer] == "fp4"

    def test_all_bf16(self, encoder):
        encoder.initialize_quantization(fp8_layers=None, fp4_layers=None, fp8_recipe=None, fp4_recipe=None)
        for layer in range(6):
            assert encoder._layer_precision[layer] is None

    def test_mixed_fp8_fp4(self, encoder):
        encoder.initialize_quantization(
            fp8_layers=[0, 1, 2],
            fp4_layers=[3, 4, 5],
            fp8_recipe=transformer_engine.common.recipe.DelayedScaling(),
            fp4_recipe=transformer_engine.common.recipe.NVFP4BlockScaling(),
        )
        # First half FP8, second half FP4.
        for layer in range(3):
            assert encoder._layer_precision[layer] == "fp8"
        for layer in range(3, 6):
            assert encoder._layer_precision[layer] == "fp4"

    def test_mixed_fp8_bf16(self, encoder):
        encoder.initialize_quantization(
            fp8_layers=[0, 2, 4],
            fp4_layers=None,
            fp8_recipe=transformer_engine.common.recipe.DelayedScaling(),
            fp4_recipe=None,
        )
        # Even layers FP8, odd layers fall back to BF16.
        expected = {0: "fp8", 1: None, 2: "fp8", 3: None, 4: "fp8", 5: None}
        for layer, precision in expected.items():
            assert encoder._layer_precision[layer] == precision

    def test_mixed_all_three(self, encoder):
        encoder.initialize_quantization(
            fp8_layers=[0, 1],
            fp4_layers=[4, 5],
            fp8_recipe=transformer_engine.common.recipe.DelayedScaling(),
            fp4_recipe=transformer_engine.common.recipe.NVFP4BlockScaling(),
        )
        # Layers 2 and 3 were listed nowhere, so they run in BF16.
        expected = {0: "fp8", 1: "fp8", 2: None, 3: None, 4: "fp4", 5: "fp4"}
        for layer, precision in expected.items():
            assert encoder._layer_precision[layer] == precision

    def test_empty_lists_treated_as_none(self, encoder):
        encoder.initialize_quantization(fp8_layers=[], fp4_layers=[], fp8_recipe=None, fp4_recipe=None)
        for layer in range(6):
            assert encoder._layer_precision[layer] is None

    def test_covers_all_layers(self, encoder):
        encoder.initialize_quantization(
            fp8_layers=[0],
            fp4_layers=None,
            fp8_recipe=transformer_engine.common.recipe.DelayedScaling(),
            fp4_recipe=None,
        )
        # The map must have an entry for every encoder layer, not just the listed ones.
        assert len(encoder._layer_precision) == 6

    def test_recipes_stored_as_attributes(self, encoder):
        fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
        fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
        encoder.initialize_quantization(
            fp8_layers=[0], fp4_layers=[1], fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe
        )
        # Recipes are stored once on the encoder, not duplicated per-layer in the map.
        assert encoder._fp8_recipe is fp8_recipe
        assert encoder._fp4_recipe is fp4_recipe
        # The map itself holds only strings/None, never recipe objects.
        for value in encoder._layer_precision.values():
            assert value is None or isinstance(value, str)


class TestGetLayerAutocast:
    """Tests for NVEsmEncoder.get_layer_autocast."""

    def test_fp8_layer_returns_nullcontext(self, encoder):
        encoder.initialize_quantization(
            fp8_layers=[0],
            fp4_layers=None,
            fp8_recipe=transformer_engine.common.recipe.DelayedScaling(),
            fp4_recipe=None,
        )
        assert isinstance(encoder.get_layer_autocast(0), nullcontext)

    def test_fp4_layer_returns_te_autocast(self, encoder):
        recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
        encoder.initialize_quantization(fp8_layers=None, fp4_layers=[0], fp8_recipe=None, fp4_recipe=recipe)
        with patch.object(transformer_engine.pytorch, "autocast", return_value="fp4_context") as mock_autocast:
            assert encoder.get_layer_autocast(0) == "fp4_context"
        mock_autocast.assert_called_once_with(enabled=True, recipe=recipe)

    def test_bf16_layer_returns_te_autocast_disabled(self, encoder):
        encoder.initialize_quantization(fp8_layers=None, fp4_layers=None, fp8_recipe=None, fp4_recipe=None)
        with patch.object(transformer_engine.pytorch, "autocast", return_value="bf16_context") as mock_autocast:
            assert encoder.get_layer_autocast(0) == "bf16_context"
        mock_autocast.assert_called_once_with(enabled=False)

    def test_uninitialized_defaults_to_bf16(self, encoder):
        """When initialize_quantization was never called, all layers default to BF16."""
        with patch.object(transformer_engine.pytorch, "autocast", return_value="bf16_context") as mock_autocast:
            assert encoder.get_layer_autocast(0) == "bf16_context"
        mock_autocast.assert_called_once_with(enabled=False)

    def test_mixed_layers_return_correct_contexts(self, encoder):
        fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
        encoder.initialize_quantization(
            fp8_layers=[0, 1],
            fp4_layers=[2, 3],
            fp8_recipe=transformer_engine.common.recipe.DelayedScaling(),
            fp4_recipe=fp4_recipe,
        )
        # FP8 layers yield nullcontext so the outer autocast takes effect.
        for layer in (0, 1):
            assert isinstance(encoder.get_layer_autocast(layer), nullcontext)

        # FP4 layer routes through te.pytorch.autocast with the FP4 recipe.
        with patch.object(transformer_engine.pytorch, "autocast", return_value="fp4_context") as mock_autocast:
            encoder.get_layer_autocast(2)
        mock_autocast.assert_called_with(enabled=True, recipe=fp4_recipe)

        # Unlisted layer routes through te.pytorch.autocast with quantization disabled.
        with patch.object(transformer_engine.pytorch, "autocast", return_value="bf16_context") as mock_autocast:
            encoder.get_layer_autocast(4)
        mock_autocast.assert_called_with(enabled=False)

    def test_layer_precision_map_is_pickleable(self, encoder):
        """The _layer_precision map should be trivially pickleable (only strings/None)."""
        import pickle

        encoder.initialize_quantization(
            fp8_layers=[0, 1],
            fp4_layers=[2, 3],
            fp8_recipe=transformer_engine.common.recipe.DelayedScaling(),
            fp4_recipe=transformer_engine.common.recipe.NVFP4BlockScaling(),
        )
        restored = pickle.loads(pickle.dumps(encoder._layer_precision))
        assert restored == encoder._layer_precision
Comment on lines +16 to +242
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Add at least one TE-vs-reference golden-value parity test in this module.

These tests cover routing/context behavior, but they do not assert numerical parity between the TE model and the reference ESM model for a fixed input/seed.

As per coding guidelines: "In bionemo-recipes/models/, create golden value tests proving that the TransformerEngine model matches the reference model".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/esm2/tests/test_layer_quantization.py` (around lines 16
- 242), add a golden-value parity test that runs the TransformerEngine ESM
encoder (NVEsmEncoder) and the reference ESM encoder on the same deterministic
input and seed, and asserts numerical parity (e.g., final token logits or pooled
embeddings) within a small tolerance. Create a new test function (e.g.,
`test_te_vs_reference_golden_value_parity`) in this module that uses
`torch.manual_seed`, builds a small random input tensor on CUDA, constructs an
NVEsmEncoder via NVEsmConfig alongside the reference ESM model (import the
reference model used elsewhere in the repo), runs both forward passes with
identical settings, and asserts the outputs are close using `pytest.approx` or
`torch.allclose`. Reuse the existing encoder fixture pattern and device, and
keep the comparison deterministic and tolerant of tiny numeric differences.

Loading