diff --git a/.vscode/settings.json b/.vscode/settings.json index 41bac6a7e9..e6a3603dab 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -26,5 +26,6 @@ "editor.rulers": [ 120 ], - "autoDocstring.docstringFormat": "google-notypes" + "autoDocstring.docstringFormat": "google-notypes", + "search.exclude": { "**/logs/**": true }, } diff --git a/bionemo-recipes/models/esm2/modeling_esm_te.py b/bionemo-recipes/models/esm2/modeling_esm_te.py index be4049df06..adf9921e12 100644 --- a/bionemo-recipes/models/esm2/modeling_esm_te.py +++ b/bionemo-recipes/models/esm2/modeling_esm_te.py @@ -22,11 +22,14 @@ Adapted from `modeling_esm.py` in huggingface/transformers. """ +from contextlib import nullcontext from typing import ClassVar, Literal, Optional, Unpack # TODO: put import guard around transformer_engine here, with an informative error message around # installation and the nvidia docker container. import torch +import torch.cuda.nvtx as nvtx +import transformer_engine.common.recipe import transformer_engine.pytorch from torch import nn from torch.nn import CrossEntropyLoss @@ -54,6 +57,15 @@ "AutoModelForTokenClassification": "esm_nv.NVEsmForTokenClassification", } +# From https://github.com/NVIDIA/TransformerEngine/blob/3ceb248e01a2c0dc1215fe0f46ebc235f289ba0d/transformer_engine/common/recipe/__init__.py#L86 +FP8_RECIPES = ( + transformer_engine.common.recipe.MXFP8BlockScaling, + transformer_engine.common.recipe.DelayedScaling, + transformer_engine.common.recipe.Float8CurrentScaling, + transformer_engine.common.recipe.Float8BlockScaling, +) +FP4_RECIPES = transformer_engine.common.recipe.NVFP4BlockScaling + class NVEsmConfig(EsmConfig): """NVEsmConfig is a configuration for the NVEsm model.""" @@ -164,6 +176,9 @@ def _init_method(x): for i in range(config.num_hidden_layers) ] ) + self._fp8_recipe: object | None = None + self._fp4_recipe: object | None = None + self._layer_precision: dict[int, str | None] | None = None self.emb_layer_norm_after = 
transformer_engine.pytorch.LayerNorm( config.hidden_size, eps=config.layer_norm_eps, @@ -173,6 +188,61 @@ def _init_method(x): if config.position_embedding_type == "rotary": self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) + def initialize_quantization( + self, + fp8_layers: list[int] | None, + fp4_layers: list[int] | None, + fp8_recipe: object | None = None, + fp4_recipe: object | None = None, + ) -> None: + """Build the per-layer quantization precision map. + + Must be called after model creation and sharding (FSDP/DDP/mFSDP) but before training. + Each layer is tagged as ``"fp8"``, ``"fp4"``, or ``None`` (BF16 fallback). The recipe + objects are stored once on the encoder rather than duplicated per-layer, ensuring the + map is trivially pickleable. + + Args: + fp8_layers: 0-indexed layer numbers to run in FP8, or None. + fp4_layers: 0-indexed layer numbers to run in FP4, or None. + fp8_recipe: The FP8 recipe instance (e.g., MXFP8BlockScaling), or None. + fp4_recipe: The FP4 recipe instance (e.g., NVFP4BlockScaling), or None. + """ + self._fp8_recipe = fp8_recipe + self._fp4_recipe = fp4_recipe + fp8_layers_set = set(fp8_layers) if fp8_layers else set() + fp4_layers_set = set(fp4_layers) if fp4_layers else set() + self._layer_precision = {} + for layer_number in range(len(self.layers)): + if layer_number in fp8_layers_set: + self._layer_precision[layer_number] = "fp8" + elif layer_number in fp4_layers_set: + self._layer_precision[layer_number] = "fp4" + else: + self._layer_precision[layer_number] = None + + def get_layer_autocast(self, layer_number: int): + """Return the appropriate TE autocast context manager for a given layer. + + The context interacts with the outer FP8 autocast in the training script: + - FP8 layer: nullcontext() -- lets the outer FP8 autocast take effect. + - FP4 layer: te.pytorch.autocast(enabled=True, recipe=fp4_recipe) -- overrides to FP4. 
+ - BF16 layer: te.pytorch.autocast(enabled=False) -- disables quantized compute. + + Args: + layer_number: The 0-indexed layer number. + + Returns: + A context manager for the layer's quantization mode. + """ + precision = self._layer_precision.get(layer_number) if self._layer_precision is not None else None + if precision == "fp8": + return nullcontext() + elif precision == "fp4": + return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp4_recipe) + else: + return transformer_engine.pytorch.autocast(enabled=False) + def forward( self, hidden_states: torch.Tensor, @@ -198,24 +268,30 @@ def forward( te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings) te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True) - for layer_module in self.layers: + # Per-layer quantization context (FP8, FP4, or BF16) is determined by get_layer_autocast(). + for layer_number, layer_module in enumerate(self.layers): if kwargs.get("output_hidden_states", False): all_hidden_states = (*all_hidden_states, hidden_states) - hidden_states = layer_module( - hidden_states, - attention_mask, - rotary_pos_emb=te_rope_emb, - cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), - cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), - cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), - cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), - max_seqlen_q=kwargs.get("max_length_q", None), - max_seqlen_kv=kwargs.get("max_length_k", None), - pad_between_seqs=kwargs.get("pad_between_seqs", None), - ) + nvtx.range_push(f"encoder_layer_{layer_number}") + with self.get_layer_autocast(layer_number): + hidden_states = layer_module( + hidden_states, + attention_mask, + rotary_pos_emb=te_rope_emb, + cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), + cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), + cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), + cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), + 
max_seqlen_q=kwargs.get("max_length_q", None), + max_seqlen_kv=kwargs.get("max_length_k", None), + pad_between_seqs=kwargs.get("pad_between_seqs", None), + ) + nvtx.range_pop() # encoder_layer_N + nvtx.range_push("emb_layer_norm_after") hidden_states = self.emb_layer_norm_after(hidden_states) + nvtx.range_pop() # emb_layer_norm_after if kwargs.get("output_hidden_states", False): all_hidden_states = (*all_hidden_states, hidden_states) diff --git a/bionemo-recipes/models/esm2/tests/test_layer_quantization.py b/bionemo-recipes/models/esm2/tests/test_layer_quantization.py new file mode 100644 index 0000000000..ba8fc30d7c --- /dev/null +++ b/bionemo-recipes/models/esm2/tests/test_layer_quantization.py @@ -0,0 +1,242 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for NVEsmEncoder.initialize_quantization and get_layer_autocast.""" + +from contextlib import nullcontext +from unittest.mock import patch + +import pytest +import transformer_engine.common.recipe +import transformer_engine.pytorch + +from modeling_esm_te import NVEsmConfig, NVEsmEncoder + + +@pytest.fixture +def encoder(): + """Create a small NVEsmEncoder on CUDA for testing.""" + config = NVEsmConfig( + hidden_size=320, + intermediate_size=1280, + num_hidden_layers=6, + num_attention_heads=20, + max_position_embeddings=1026, + ) + return NVEsmEncoder(config) + + +class TestInitializeQuantization: + """Tests for NVEsmEncoder.initialize_quantization.""" + + def test_all_fp8(self, encoder): + fp8_recipe = transformer_engine.common.recipe.DelayedScaling() + encoder.initialize_quantization( + fp8_layers=[0, 1, 2, 3, 4, 5], + fp4_layers=None, + fp8_recipe=fp8_recipe, + fp4_recipe=None, + ) + assert encoder._fp8_recipe is fp8_recipe + assert encoder._fp4_recipe is None + assert all(encoder._layer_precision[i] == "fp8" for i in range(6)) + + def test_all_fp4(self, encoder): + fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() + encoder.initialize_quantization( + fp8_layers=None, + fp4_layers=[0, 1, 2, 3, 4, 5], + fp8_recipe=None, + fp4_recipe=fp4_recipe, + ) + assert encoder._fp8_recipe is None + assert encoder._fp4_recipe is fp4_recipe + assert all(encoder._layer_precision[i] == "fp4" for i in range(6)) + + def test_all_bf16(self, encoder): + encoder.initialize_quantization( + fp8_layers=None, + fp4_layers=None, + fp8_recipe=None, + fp4_recipe=None, + ) + assert all(encoder._layer_precision[i] is None for i in range(6)) + + def test_mixed_fp8_fp4(self, encoder): + fp8_recipe = transformer_engine.common.recipe.DelayedScaling() + fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() + encoder.initialize_quantization( + fp8_layers=[0, 1, 2], + fp4_layers=[3, 4, 5], + fp8_recipe=fp8_recipe, + fp4_recipe=fp4_recipe, + ) + for i in 
range(3): + assert encoder._layer_precision[i] == "fp8" + for i in range(3, 6): + assert encoder._layer_precision[i] == "fp4" + + def test_mixed_fp8_bf16(self, encoder): + fp8_recipe = transformer_engine.common.recipe.DelayedScaling() + encoder.initialize_quantization( + fp8_layers=[0, 2, 4], + fp4_layers=None, + fp8_recipe=fp8_recipe, + fp4_recipe=None, + ) + assert encoder._layer_precision[0] == "fp8" + assert encoder._layer_precision[1] is None + assert encoder._layer_precision[2] == "fp8" + assert encoder._layer_precision[3] is None + assert encoder._layer_precision[4] == "fp8" + assert encoder._layer_precision[5] is None + + def test_mixed_all_three(self, encoder): + fp8_recipe = transformer_engine.common.recipe.DelayedScaling() + fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() + encoder.initialize_quantization( + fp8_layers=[0, 1], + fp4_layers=[4, 5], + fp8_recipe=fp8_recipe, + fp4_recipe=fp4_recipe, + ) + assert encoder._layer_precision[0] == "fp8" + assert encoder._layer_precision[1] == "fp8" + assert encoder._layer_precision[2] is None # BF16 + assert encoder._layer_precision[3] is None # BF16 + assert encoder._layer_precision[4] == "fp4" + assert encoder._layer_precision[5] == "fp4" + + def test_empty_lists_treated_as_none(self, encoder): + encoder.initialize_quantization( + fp8_layers=[], + fp4_layers=[], + fp8_recipe=None, + fp4_recipe=None, + ) + assert all(encoder._layer_precision[i] is None for i in range(6)) + + def test_covers_all_layers(self, encoder): + encoder.initialize_quantization( + fp8_layers=[0], + fp4_layers=None, + fp8_recipe=transformer_engine.common.recipe.DelayedScaling(), + fp4_recipe=None, + ) + assert len(encoder._layer_precision) == 6 + + def test_recipes_stored_as_attributes(self, encoder): + fp8_recipe = transformer_engine.common.recipe.DelayedScaling() + fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() + encoder.initialize_quantization( + fp8_layers=[0], + fp4_layers=[1], + 
fp8_recipe=fp8_recipe, + fp4_recipe=fp4_recipe, + ) + # Recipes are stored once, not duplicated per-layer in the map. + assert encoder._fp8_recipe is fp8_recipe + assert encoder._fp4_recipe is fp4_recipe + # The map only contains strings, not recipe objects. + for v in encoder._layer_precision.values(): + assert v is None or isinstance(v, str) + + +class TestGetLayerAutocast: + """Tests for NVEsmEncoder.get_layer_autocast.""" + + def test_fp8_layer_returns_nullcontext(self, encoder): + encoder.initialize_quantization( + fp8_layers=[0], + fp4_layers=None, + fp8_recipe=transformer_engine.common.recipe.DelayedScaling(), + fp4_recipe=None, + ) + ctx = encoder.get_layer_autocast(0) + assert isinstance(ctx, nullcontext) + + def test_fp4_layer_returns_te_autocast(self, encoder): + fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() + encoder.initialize_quantization( + fp8_layers=None, + fp4_layers=[0], + fp8_recipe=None, + fp4_recipe=fp4_recipe, + ) + with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast: + mock_autocast.return_value = "fp4_context" + ctx = encoder.get_layer_autocast(0) + mock_autocast.assert_called_once_with(enabled=True, recipe=fp4_recipe) + assert ctx == "fp4_context" + + def test_bf16_layer_returns_te_autocast_disabled(self, encoder): + encoder.initialize_quantization( + fp8_layers=None, + fp4_layers=None, + fp8_recipe=None, + fp4_recipe=None, + ) + with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast: + mock_autocast.return_value = "bf16_context" + ctx = encoder.get_layer_autocast(0) + mock_autocast.assert_called_once_with(enabled=False) + assert ctx == "bf16_context" + + def test_uninitialized_defaults_to_bf16(self, encoder): + """When initialize_quantization was never called, all layers default to BF16.""" + with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast: + mock_autocast.return_value = "bf16_context" + ctx = encoder.get_layer_autocast(0) + 
mock_autocast.assert_called_once_with(enabled=False) + assert ctx == "bf16_context" + + def test_mixed_layers_return_correct_contexts(self, encoder): + fp8_recipe = transformer_engine.common.recipe.DelayedScaling() + fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() + encoder.initialize_quantization( + fp8_layers=[0, 1], + fp4_layers=[2, 3], + fp8_recipe=fp8_recipe, + fp4_recipe=fp4_recipe, + ) + # FP8 layers -> nullcontext + assert isinstance(encoder.get_layer_autocast(0), nullcontext) + assert isinstance(encoder.get_layer_autocast(1), nullcontext) + + # FP4 and BF16 layers -> te.pytorch.autocast (not nullcontext) + with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast: + mock_autocast.return_value = "fp4_context" + encoder.get_layer_autocast(2) + mock_autocast.assert_called_with(enabled=True, recipe=fp4_recipe) + + with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast: + mock_autocast.return_value = "bf16_context" + encoder.get_layer_autocast(4) + mock_autocast.assert_called_with(enabled=False) + + def test_layer_precision_map_is_pickleable(self, encoder): + """The _layer_precision map should be trivially pickleable (only strings/None).""" + import pickle + + fp8_recipe = transformer_engine.common.recipe.DelayedScaling() + fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() + encoder.initialize_quantization( + fp8_layers=[0, 1], + fp4_layers=[2, 3], + fp8_recipe=fp8_recipe, + fp4_recipe=fp4_recipe, + ) + roundtripped = pickle.loads(pickle.dumps(encoder._layer_precision)) + assert roundtripped == encoder._layer_precision diff --git a/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py index be4049df06..adf9921e12 100644 --- a/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py +++ b/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py @@ -22,11 +22,14 
@@ Adapted from `modeling_esm.py` in huggingface/transformers. """ +from contextlib import nullcontext from typing import ClassVar, Literal, Optional, Unpack # TODO: put import guard around transformer_engine here, with an informative error message around # installation and the nvidia docker container. import torch +import torch.cuda.nvtx as nvtx +import transformer_engine.common.recipe import transformer_engine.pytorch from torch import nn from torch.nn import CrossEntropyLoss @@ -54,6 +57,15 @@ "AutoModelForTokenClassification": "esm_nv.NVEsmForTokenClassification", } +# From https://github.com/NVIDIA/TransformerEngine/blob/3ceb248e01a2c0dc1215fe0f46ebc235f289ba0d/transformer_engine/common/recipe/__init__.py#L86 +FP8_RECIPES = ( + transformer_engine.common.recipe.MXFP8BlockScaling, + transformer_engine.common.recipe.DelayedScaling, + transformer_engine.common.recipe.Float8CurrentScaling, + transformer_engine.common.recipe.Float8BlockScaling, +) +FP4_RECIPES = transformer_engine.common.recipe.NVFP4BlockScaling + class NVEsmConfig(EsmConfig): """NVEsmConfig is a configuration for the NVEsm model.""" @@ -164,6 +176,9 @@ def _init_method(x): for i in range(config.num_hidden_layers) ] ) + self._fp8_recipe: object | None = None + self._fp4_recipe: object | None = None + self._layer_precision: dict[int, str | None] | None = None self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm( config.hidden_size, eps=config.layer_norm_eps, @@ -173,6 +188,61 @@ def _init_method(x): if config.position_embedding_type == "rotary": self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) + def initialize_quantization( + self, + fp8_layers: list[int] | None, + fp4_layers: list[int] | None, + fp8_recipe: object | None = None, + fp4_recipe: object | None = None, + ) -> None: + """Build the per-layer quantization precision map. + + Must be called after model creation and sharding (FSDP/DDP/mFSDP) but before training. 
+ Each layer is tagged as ``"fp8"``, ``"fp4"``, or ``None`` (BF16 fallback). The recipe + objects are stored once on the encoder rather than duplicated per-layer, ensuring the + map is trivially pickleable. + + Args: + fp8_layers: 0-indexed layer numbers to run in FP8, or None. + fp4_layers: 0-indexed layer numbers to run in FP4, or None. + fp8_recipe: The FP8 recipe instance (e.g., MXFP8BlockScaling), or None. + fp4_recipe: The FP4 recipe instance (e.g., NVFP4BlockScaling), or None. + """ + self._fp8_recipe = fp8_recipe + self._fp4_recipe = fp4_recipe + fp8_layers_set = set(fp8_layers) if fp8_layers else set() + fp4_layers_set = set(fp4_layers) if fp4_layers else set() + self._layer_precision = {} + for layer_number in range(len(self.layers)): + if layer_number in fp8_layers_set: + self._layer_precision[layer_number] = "fp8" + elif layer_number in fp4_layers_set: + self._layer_precision[layer_number] = "fp4" + else: + self._layer_precision[layer_number] = None + + def get_layer_autocast(self, layer_number: int): + """Return the appropriate TE autocast context manager for a given layer. + + The context interacts with the outer FP8 autocast in the training script: + - FP8 layer: nullcontext() -- lets the outer FP8 autocast take effect. + - FP4 layer: te.pytorch.autocast(enabled=True, recipe=fp4_recipe) -- overrides to FP4. + - BF16 layer: te.pytorch.autocast(enabled=False) -- disables quantized compute. + + Args: + layer_number: The 0-indexed layer number. + + Returns: + A context manager for the layer's quantization mode. 
+ """ + precision = self._layer_precision.get(layer_number) if self._layer_precision is not None else None + if precision == "fp8": + return nullcontext() + elif precision == "fp4": + return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp4_recipe) + else: + return transformer_engine.pytorch.autocast(enabled=False) + def forward( self, hidden_states: torch.Tensor, @@ -198,24 +268,30 @@ def forward( te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings) te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True) - for layer_module in self.layers: + # Per-layer quantization context (FP8, FP4, or BF16) is determined by get_layer_autocast(). + for layer_number, layer_module in enumerate(self.layers): if kwargs.get("output_hidden_states", False): all_hidden_states = (*all_hidden_states, hidden_states) - hidden_states = layer_module( - hidden_states, - attention_mask, - rotary_pos_emb=te_rope_emb, - cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), - cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), - cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), - cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), - max_seqlen_q=kwargs.get("max_length_q", None), - max_seqlen_kv=kwargs.get("max_length_k", None), - pad_between_seqs=kwargs.get("pad_between_seqs", None), - ) + nvtx.range_push(f"encoder_layer_{layer_number}") + with self.get_layer_autocast(layer_number): + hidden_states = layer_module( + hidden_states, + attention_mask, + rotary_pos_emb=te_rope_emb, + cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), + cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), + cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), + cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), + max_seqlen_q=kwargs.get("max_length_q", None), + max_seqlen_kv=kwargs.get("max_length_k", None), + pad_between_seqs=kwargs.get("pad_between_seqs", None), + ) + nvtx.range_pop() # encoder_layer_N + 
nvtx.range_push("emb_layer_norm_after") hidden_states = self.emb_layer_norm_after(hidden_states) + nvtx.range_pop() # emb_layer_norm_after if kwargs.get("output_hidden_states", False): all_hidden_states = (*all_hidden_states, hidden_states) diff --git a/bionemo-recipes/recipes/esm2_native_te/.dockerignore b/bionemo-recipes/recipes/esm2_native_te/.dockerignore index e67ca715ce..ff0577a466 100644 --- a/bionemo-recipes/recipes/esm2_native_te/.dockerignore +++ b/bionemo-recipes/recipes/esm2_native_te/.dockerignore @@ -1,10 +1,34 @@ +# Docker Dockerfile +Dockerfile.* +.dockerignore + +# Docs README.md -checkpoint_export/ -outputs/ -.ruff_cache + +# Python caches __pycache__ .pytest_cache -.ruff.toml -.dockerignore +.ruff_cache .venv/ + +# Linting +.ruff.toml + +# Profiling & debugging artifacts +memory_snapshots/ +nsight_profiling/ +*.nsys-rep +*.sqlite +logs/ +wandb/ + +# Hydra / training outputs +outputs/ +checkpoints/ + +# Checkpoint export +checkpoint_export/ + +# Temp / scratch +j/ diff --git a/bionemo-recipes/recipes/esm2_native_te/README.md b/bionemo-recipes/recipes/esm2_native_te/README.md index bb93a2ecd1..eb2e8163fb 100644 --- a/bionemo-recipes/recipes/esm2_native_te/README.md +++ b/bionemo-recipes/recipes/esm2_native_te/README.md @@ -1,7 +1,8 @@ # TransformerEngine-accelerated ESM-2 training with native PyTorch training loop This folder demonstrates how to train TE-accelerated ESM-2 with a native PyTorch training loop, including sequence -packing and FP8 precision, using fully sharded data parallel (FSDP) for distributed training. +packing, FP8/MXFP8/NVFP4 precision with layer-wise control, using fully sharded data parallel (FSDP) for distributed +training. ## How to use this recipe @@ -15,10 +16,10 @@ bionemo-framework repository. 
You can download a zipped directory of this folder ## Supported Models and Training Features -| Model | BF16 | FP8[1] | THD Input Format | FP8 with THD Input Format | MXFP8[2] | Context Parallelism | -| ----------------------------------------- | ---- | ----------------- | ---------------- | ------------------------- | ------------------- | ------------------- | -| [ESM-2](../../models/esm2/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [AMPLIFY](../../models/amplify/README.md) | ✅ | ❌ | 🚧 | ❌ | ❌ | 🚧 | +| Model | BF16 | FP8[1] | MXFP8[2] | NVFP4[3] | THD Input Format | Context Parallelism | +| ----------------------------------------- | ---- | ----------------- | ------------------- | ------------------- | ---------------- | ------------------- | +| [ESM-2](../../models/esm2/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [AMPLIFY](../../models/amplify/README.md) | ✅ | ❌ | ❌ | ❌ | 🚧 | 🚧 | ✅: Supported
🚧: Under development
@@ -26,6 +27,7 @@ bionemo-framework repository. You can download a zipped directory of this folder \[1\]: Requires [compute capability](https://developer.nvidia.com/cuda-gpus) 9.0 and above (Hopper+)
\[2\]: Requires [compute capability](https://developer.nvidia.com/cuda-gpus) 10.0 and 10.3 (Blackwell), 12.0 support pending
+\[3\]: Requires [compute capability](https://developer.nvidia.com/cuda-gpus) 10.0 and above (Blackwell+)
### Installing Dependencies @@ -72,6 +74,37 @@ Recently, we measured 2800 tokens/second/GPU training speed on H100 with Hugging of THD sequence packing, however we have not been able to make this configuration work on Blackwell and this work is still in progress. +### Low precision performance benchmarks + +![Performance Benchmarks Low Precision](../../../docs/docs/assets/images/esm2/esm2_low_precision/esm2_8gpu_tflops.png) +In the above plot, we can see that as we increase the scale of our models, the benefits of low precision training are apparent. +This is because at smaller model sizes (such as 650M and 3B), the cost to quantize activations from high precision to low +precision outweighs the benefits of performing matrix multiplication with low precision. However, as our models scale up in +parameter count, the fixed quantization cost is lower than our GEMM operational savings. + +Note: these plots were generated using our [fsdp2](./train_fsdp2.py) script. + +### Convergence results for low precision training + +#### MXFP8 + +![Convergence Benchmarks MXFP8](../../../docs/docs/assets/images/esm2/esm2_low_precision/esm2-15b-b300-mxfp8-10node-conv.svg) +In the above plot, for our ESM2-15B model that was trained with FSDP2 using 80 B300 GPU nodes for 10 hours, we can clearly see that +our MXFP8 run and our BF16 baseline run have the same results. A perfect match in convergence. + +#### NVFP4 + +![Convergence Benchmarks NVFP4](../../../docs/docs/assets/images/esm2/esm2_low_precision/esm2-15b-b300-nvfp4-10node-conv.svg) +In the above plot, for our ESM2-15B model, we show several lines. Each experiment shows a unique configuration using a custom +number of `fp4_layers` per run (more info on how to enable this below). Moreover, the experiments can be read as +`esm2-15b-b300-mxfp8-fp4_layer_start-fp4_layer_end-N-10-mbs-26-b300` which denotes at which point we start and end the fp4 layers. + +We see that as we add more and more layers, our E2E training results get worse. 
This is a tradeoff between performance and +accuracy. If we look at the performance chart above, we have increased performance dramatically, but our accuracy suffers. +It's important to note that in all NVFP4 experiments we are also utilizing stochastic rounding and random hadamard transformations. + +For more information regarding NVFP4 training please see [paper](https://arxiv.org/pdf/2509.25149) and [here](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html) + ### Distributed Training This recipe supports distributed training using DDP, FSDP2, and Megatron-FSDP, shown in three separate training @@ -97,7 +130,7 @@ torchrun --nproc_per_node=2 train_fsdp2.py # or train_mfsdp.py / train_ddp.py Multi-Node training is supported with all three strategies, refer to [`slurm.sh`](slurm.sh) for an example SLURM script. -### FP8 Training +### Quantized Training (FP8 / MXFP8 / NVFP4) To run training with FP8, enable it by overriding the `fp8_config.enabled=true` configuration parameter. Additional FP8 configuration parameters, including switching to `MXFP8BlockScaling`, can be set using the hydra configuration. @@ -106,26 +139,64 @@ configuration parameters, including switching to `MXFP8BlockScaling`, can be set python train_fsdp2.py --config-name L0_sanity fp8_config.enabled=true ``` -#### FP8 Debugging +Similarly, to train with NVFP4 quantization: -We also provide a mechanism to receive tensor data related to FP8 layers during training which may include activations, weights and gradients. +```bash +python train_fsdp2.py --config-name L0_sanity fp4_config.enabled=true +``` -To enable this please select the following config options. +Note: This feature is available for the `train_ddp` and the `train_fsdp2` scripts. It is not yet available for +`train_mfsdp`. NVFP4 stats logging is not yet supported and will be enabled in a future TransformerEngine release; +FP8/MXFP8 stats logging works today. 
-```python +Additional recipe parameters (e.g., switching to `MXFP8BlockScaling`) can be set via the hydra configuration. + +#### Layer-Wise Precision + +You can control which transformer layers use FP8 or FP4 by specifying 1-indexed layer numbers via `fp8_layers` and +`fp4_layers`. Layers not assigned to either format will run in BF16. + +For example, to run layers 1-3 in FP8, layers 4-6 in FP4, and the rest in BF16 on a model with more than 6 layers: + +```bash +python train_fsdp2.py --config-name L0_sanity \ + fp8_config.enabled=true \ + fp4_config.enabled=true \ + 'fp8_layers=[1,2,3]' \ + 'fp4_layers=[4,5,6]' +``` + +When both `fp8_config` and `fp4_config` are enabled but only one layer list is provided, the other format automatically +claims the remaining layers. For example, if `fp8_layers=[1,2,3]` is set and `fp4_config.enabled=true` with no +`fp4_layers`, then layers 4 through N will default to FP4. + +#### Quantization Stats Debugging + +We provide a mechanism to log tensor statistics (activations, weights, gradients) for quantized layers during training. +When layer-wise precision is used, the stats config is automatically updated so that only the relevant layers are +tracked. + +To enable stats logging: + +```bash python train_fsdp2.py \ -fp8_stats_config.enabled=True # whether to log stats or not -fp8_stats_config.fp8_log_dir=./logs/fp8_stats_logs_dummy # where to store the logs -fp8_stats_config.fp8_stats_file=./fp8_debugging_stats.yaml # specifies what stats you want to run. Currently this is saved in this yaml file. -fp8_config.enabled=True # set this to use FP8 otherwise stats logging won't work + quant_stats_config.enabled=true \ + quant_stats_config.quant_log_dir=./logs/quant_stats \ + quant_stats_config.quant_stats_file=./fp8_debugging_stats.yaml \ + fp8_config.enabled=true ``` -Note: This feature is available for the `train_ddp` and the `train_fsdp2` scripts. It is not yet available for `train_mfsdp`. 
+Note: This feature is available for the `train_ddp` and the `train_fsdp2` scripts. It is not yet available for +`train_mfsdp`. NVFP4 stats logging is not yet supported and will be enabled in a future TransformerEngine release; +FP8/MXFP8 stats logging works today. -The config file structure [fp8_debugging_stats.yaml](fp8_debugging_stats.yaml) is explained in the [NVIDIA Transformer Engine config file documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/debug/2_config_file_structure.html) in more detail. Below we will cover some very basic elements of the file structure. +The config file structure [fp8_debugging_stats.yaml](fp8_debugging_stats.yaml) is explained in the +[NVIDIA Transformer Engine config file documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/debug/2_config_file_structure.html) +in more detail. -This comes as a performance cost that is dependent on the `freq` parameter mentioned above. `freq=1` collects stats on every step which in our -experiments caused a ~29% decrease in throughput (executed on a single RTX 5090). We recommend using `freq>=10` to reduce this performance hit. +Stats collection has a performance cost dependent on the `freq` parameter in the config file. `freq=1` collects stats +on every step which in our experiments caused a ~29% decrease in throughput (executed on a single RTX 5090). We +recommend using `freq>=10` to reduce this performance hit. ### Sequence Packing (THD input format) diff --git a/bionemo-recipes/recipes/esm2_native_te/dataset.py b/bionemo-recipes/recipes/esm2_native_te/dataset.py index fca23f3518..c853640aef 100644 --- a/bionemo-recipes/recipes/esm2_native_te/dataset.py +++ b/bionemo-recipes/recipes/esm2_native_te/dataset.py @@ -109,7 +109,6 @@ def create_bshd_dataloader( non-streaming datasets.Dataset. Defaults to True. use_stateful_dataloader: Whether to use the StatefulDataLoader to enable checkpointing the dataloader state. 
mlm_probability: The probability of masking tokens for MLM (default 0.15). Set to 0 for no masking. - **kwargs: Unused, here to enable kwargs to match the signature of create_thd_dataloader. Returns: A dataloader that can be used for training. diff --git a/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py index be4049df06..adf9921e12 100644 --- a/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py +++ b/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py @@ -22,11 +22,14 @@ Adapted from `modeling_esm.py` in huggingface/transformers. """ +from contextlib import nullcontext from typing import ClassVar, Literal, Optional, Unpack # TODO: put import guard around transformer_engine here, with an informative error message around # installation and the nvidia docker container. import torch +import torch.cuda.nvtx as nvtx +import transformer_engine.common.recipe import transformer_engine.pytorch from torch import nn from torch.nn import CrossEntropyLoss @@ -54,6 +57,15 @@ "AutoModelForTokenClassification": "esm_nv.NVEsmForTokenClassification", } +# From https://github.com/NVIDIA/TransformerEngine/blob/3ceb248e01a2c0dc1215fe0f46ebc235f289ba0d/transformer_engine/common/recipe/__init__.py#L86 +FP8_RECIPES = ( + transformer_engine.common.recipe.MXFP8BlockScaling, + transformer_engine.common.recipe.DelayedScaling, + transformer_engine.common.recipe.Float8CurrentScaling, + transformer_engine.common.recipe.Float8BlockScaling, +) +FP4_RECIPES = transformer_engine.common.recipe.NVFP4BlockScaling + class NVEsmConfig(EsmConfig): """NVEsmConfig is a configuration for the NVEsm model.""" @@ -164,6 +176,9 @@ def _init_method(x): for i in range(config.num_hidden_layers) ] ) + self._fp8_recipe: object | None = None + self._fp4_recipe: object | None = None + self._layer_precision: dict[int, str | None] | None = None self.emb_layer_norm_after = 
transformer_engine.pytorch.LayerNorm( config.hidden_size, eps=config.layer_norm_eps, @@ -173,6 +188,61 @@ def _init_method(x): if config.position_embedding_type == "rotary": self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) + def initialize_quantization( + self, + fp8_layers: list[int] | None, + fp4_layers: list[int] | None, + fp8_recipe: object | None = None, + fp4_recipe: object | None = None, + ) -> None: + """Build the per-layer quantization precision map. + + Must be called after model creation and sharding (FSDP/DDP/mFSDP) but before training. + Each layer is tagged as ``"fp8"``, ``"fp4"``, or ``None`` (BF16 fallback). The recipe + objects are stored once on the encoder rather than duplicated per-layer, ensuring the + map is trivially pickleable. + + Args: + fp8_layers: 0-indexed layer numbers to run in FP8, or None. + fp4_layers: 0-indexed layer numbers to run in FP4, or None. + fp8_recipe: The FP8 recipe instance (e.g., MXFP8BlockScaling), or None. + fp4_recipe: The FP4 recipe instance (e.g., NVFP4BlockScaling), or None. + """ + self._fp8_recipe = fp8_recipe + self._fp4_recipe = fp4_recipe + fp8_layers_set = set(fp8_layers) if fp8_layers else set() + fp4_layers_set = set(fp4_layers) if fp4_layers else set() + self._layer_precision = {} + for layer_number in range(len(self.layers)): + if layer_number in fp8_layers_set: + self._layer_precision[layer_number] = "fp8" + elif layer_number in fp4_layers_set: + self._layer_precision[layer_number] = "fp4" + else: + self._layer_precision[layer_number] = None + + def get_layer_autocast(self, layer_number: int): + """Return the appropriate TE autocast context manager for a given layer. + + The context interacts with the outer FP8 autocast in the training script: + - FP8 layer: nullcontext() -- lets the outer FP8 autocast take effect. + - FP4 layer: te.pytorch.autocast(enabled=True, recipe=fp4_recipe) -- overrides to FP4. 
+ - BF16 layer: te.pytorch.autocast(enabled=False) -- disables quantized compute. + + Args: + layer_number: The 0-indexed layer number. + + Returns: + A context manager for the layer's quantization mode. + """ + precision = self._layer_precision.get(layer_number) if self._layer_precision is not None else None + if precision == "fp8": + return nullcontext() + elif precision == "fp4": + return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp4_recipe) + else: + return transformer_engine.pytorch.autocast(enabled=False) + def forward( self, hidden_states: torch.Tensor, @@ -198,24 +268,30 @@ def forward( te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings) te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True) - for layer_module in self.layers: + # Per-layer quantization context (FP8, FP4, or BF16) is determined by get_layer_autocast(). + for layer_number, layer_module in enumerate(self.layers): if kwargs.get("output_hidden_states", False): all_hidden_states = (*all_hidden_states, hidden_states) - hidden_states = layer_module( - hidden_states, - attention_mask, - rotary_pos_emb=te_rope_emb, - cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), - cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), - cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), - cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), - max_seqlen_q=kwargs.get("max_length_q", None), - max_seqlen_kv=kwargs.get("max_length_k", None), - pad_between_seqs=kwargs.get("pad_between_seqs", None), - ) + nvtx.range_push(f"encoder_layer_{layer_number}") + with self.get_layer_autocast(layer_number): + hidden_states = layer_module( + hidden_states, + attention_mask, + rotary_pos_emb=te_rope_emb, + cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), + cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), + cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), + cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), + 
max_seqlen_q=kwargs.get("max_length_q", None), + max_seqlen_kv=kwargs.get("max_length_k", None), + pad_between_seqs=kwargs.get("pad_between_seqs", None), + ) + nvtx.range_pop() # encoder_layer_N + nvtx.range_push("emb_layer_norm_after") hidden_states = self.emb_layer_norm_after(hidden_states) + nvtx.range_pop() # emb_layer_norm_after if kwargs.get("output_hidden_states", False): all_hidden_states = (*all_hidden_states, hidden_states) diff --git a/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml new file mode 100644 index 0000000000..d56739a6a6 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml @@ -0,0 +1,33 @@ +example_fp4_tensor_stat_collection: + enabled: True + layers: + # Use regex to select layers 0-4 (1-indexed as layers.1 through layers.5 in the naming) + # This matches: model.esm.encoder.layers.[1-5].*.(layernorm_qkv|proj|fc1|fc2) + layer_name_regex_pattern: 'model\.esm\.encoder\.layers\.[1-5]\..*(layernorm_qkv|proj|fc1|fc2)' + transformer_engine: + LogNvfp4TensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [underflows%, mse] + freq: 100 + - tensor: gradient + stats: [underflows%, mse] + freq: 100 + +example_fp8_tensor_stat_collection: + enabled: True + layers: + # Use regex to select layers 5-9 (1-indexed as layers.6 through layers.10 in the naming) + # This matches: model.esm.encoder.layers.([6-9]|10).*.(layernorm_qkv|proj|fc1|fc2) + layer_name_regex_pattern: 'model\.esm\.encoder\.layers\.([6-9]|10)\..*(layernorm_qkv|proj|fc1|fc2)' + transformer_engine: + LogFp8TensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse] + freq: 100 + - tensor: gradient + stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse] + freq: 100 diff --git a/bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml 
b/bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml index 7544bbedcf..ba640a6cbb 100644 --- a/bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml @@ -2,7 +2,7 @@ example_fp8_tensor_stat_collection: enabled: True layers: # Match the actual linear layers within attention that support FP8 stats - layer_types: [layernorm_qkv] + layer_types: [layernorm_qkv, proj, fc1, fc2] transformer_engine: LogFp8TensorStats: enabled: True @@ -16,3 +16,8 @@ example_fp8_tensor_stat_collection: - tensor: weight stats: [underflows%, scale_inv_min, scale_inv_max, mse] freq: 10 + LogTensorStats: + enabled: True + stats: [max, min, mean, std, l1_norm] + tensors: [dgrad, wgrad] + freq: 1 diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml index 9e31c18880..b1cd4843b2 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml @@ -51,6 +51,12 @@ fp8_config: quantized_model_init_kwargs: enabled: false # If this is set to true, fp8_config.enabled must also be set to true. +fp4_config: + enabled: false + fp4_recipe: transformer_engine.common.recipe.NVFP4BlockScaling + fp4_format: "E2M1" + fp4_recipe_kwargs: {} + # Optimizer config adamw_kwargs: lr: 4e-4 @@ -76,9 +82,13 @@ checkpoint: logger: frequency: 100 -fp8_stats_config: + +quant_stats_config: enabled: false - fp8_stats_file: ./fp8_debugging_stats.yaml - fp8_log_dir: ./log_fp8_stats + quant_stats_file: ./fp8_debugging_stats.yaml + quant_log_dir: ./log_quant_stats -use_fp32_master_weights: false +# Note: The layers are going to come in 1 indexed and we convert them to be 0 indexed at runtime. 
+fp8_layers: null +fp4_layers: null +use_fp32_master_weights: null diff --git a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py new file mode 100644 index 0000000000..adf9921e12 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py @@ -0,0 +1,790 @@ +# noqa: license-check +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. +# Copyright 2025 NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""TransformerEngine-optimized ESM model. + +Adapted from `modeling_esm.py` in huggingface/transformers. +""" + +from contextlib import nullcontext +from typing import ClassVar, Literal, Optional, Unpack + +# TODO: put import guard around transformer_engine here, with an informative error message around +# installation and the nvidia docker container. 
+import torch +import torch.cuda.nvtx as nvtx +import transformer_engine.common.recipe +import transformer_engine.pytorch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + MaskedLMOutput, + TokenClassifierOutput, +) +from transformers.models.esm.configuration_esm import EsmConfig +from transformers.models.esm.modeling_esm import EsmPooler, EsmPreTrainedModel +from transformers.utils import logging +from transformers.utils.generic import TransformersKwargs + + +logger = logging.get_logger(__name__) + +# Dictionary that gets inserted into config.json to map Auto** classes to our TE-optimized model classes defined below. +# These should be prefixed with esm_nv., since we name the file esm_nv.py in our exported checkpoints. +AUTO_MAP = { + "AutoConfig": "esm_nv.NVEsmConfig", + "AutoModel": "esm_nv.NVEsmModel", + "AutoModelForMaskedLM": "esm_nv.NVEsmForMaskedLM", + "AutoModelForTokenClassification": "esm_nv.NVEsmForTokenClassification", +} + +# From https://github.com/NVIDIA/TransformerEngine/blob/3ceb248e01a2c0dc1215fe0f46ebc235f289ba0d/transformer_engine/common/recipe/__init__.py#L86 +FP8_RECIPES = ( + transformer_engine.common.recipe.MXFP8BlockScaling, + transformer_engine.common.recipe.DelayedScaling, + transformer_engine.common.recipe.Float8CurrentScaling, + transformer_engine.common.recipe.Float8BlockScaling, +) +FP4_RECIPES = transformer_engine.common.recipe.NVFP4BlockScaling + + +class NVEsmConfig(EsmConfig): + """NVEsmConfig is a configuration for the NVEsm model.""" + + model_type: str = "nv_esm" + + def __init__( + self, + qkv_weight_interleaved: bool = True, + encoder_activation: str = "gelu", + attn_input_format: Literal["bshd", "thd"] = "bshd", + fuse_qkv_params: bool = True, + micro_batch_size: Optional[int] = None, + max_seq_length: Optional[int] = None, + 
padded_vocab_size: Optional[int] = 64, + attn_mask_type: str = "padding", + **kwargs, + ): + """Initialize the NVEsmConfig with additional TE-related config options. + + Args: + qkv_weight_interleaved: Whether to interleave the qkv weights. If set to `False`, the + QKV weight is interpreted as a concatenation of query, key, and value weights along + the `0th` dimension. The default interpretation is that the individual `q`, `k`, and + `v` weights for each attention head are interleaved. This parameter is set to `False` + when using :attr:`fuse_qkv_params=False`. + encoder_activation: The activation function to use in the encoder. + attn_input_format: The input format to use for the attention: + "bshd" = Batch, Sequence, Head, Dimension (standard padded format) + "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format) + Note that these formats are very closely related to the `qkv_format` in the + `MultiHeadAttention` and `DotProductAttention` modules. + fuse_qkv_params: Whether to fuse the qkv parameters. If set to `True`, + `TransformerLayer` module exposes a single fused parameter for query-key-value. + This enables optimizations such as QKV fusion without concatenations/splits and + also enables the argument `fuse_wgrad_accumulation`. + micro_batch_size: The micro batch size to use for the attention. This is needed for + JIT Warmup, a technique where jit fused functions are warmed up before training to + ensure same kernels are used for forward propagation and activation recompute phase. + max_seq_length: The maximum sequence length to use for the attention. This is needed for + JIT Warmup, a technique where jit fused functions are warmed up before training to + ensure same kernels are used for forward propagation and activation recompute phase. + padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults + to vocab_size. Must be greater than or equal to vocab_size. 
+ attn_mask_type: The type of attention mask to use. + **kwargs: Additional config options to pass to EsmConfig. + """ + super().__init__(**kwargs) + # Additional TE-related config options. + self.qkv_weight_interleaved = qkv_weight_interleaved + self.encoder_activation = encoder_activation + self.attn_input_format = attn_input_format + self.fuse_qkv_params = fuse_qkv_params + self.micro_batch_size = micro_batch_size + self.max_seq_length = max_seq_length + self.attn_mask_type = attn_mask_type + + # Set padded_vocab_size with default fallback to vocab_size + self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size + + # Ensure padded_vocab_size is at least as large as vocab_size + if self.padded_vocab_size is not None and self.vocab_size is not None: + assert self.padded_vocab_size >= self.vocab_size, ( + f"padded_vocab_size ({self.padded_vocab_size}) must be greater than or equal to vocab_size ({self.vocab_size})" + ) + + +class NVEsmEncoder(nn.Module): + """NVEsmEncoder is a TransformerEngine-optimized ESM encoder.""" + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmEncoder. + + Args: + config (NVEsmConfig): The configuration of the model. 
+ """ + super().__init__() + self.config = config + + def _init_method(x): + torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range) + + self.layers = nn.ModuleList( + [ + transformer_engine.pytorch.TransformerLayer( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + num_attention_heads=config.num_attention_heads, + layernorm_epsilon=config.layer_norm_eps, + hidden_dropout=config.hidden_dropout_prob, + attention_dropout=config.attention_probs_dropout_prob, + qkv_weight_interleaved=config.qkv_weight_interleaved, + layer_number=i + 1, + layer_type="encoder", + self_attn_mask_type=config.attn_mask_type, + activation=config.encoder_activation, + attn_input_format=config.attn_input_format, + seq_length=config.max_seq_length, + micro_batch_size=config.micro_batch_size, + num_gqa_groups=config.num_attention_heads, + fuse_qkv_params=config.fuse_qkv_params, + params_dtype=config.dtype, + window_size=(-1, -1), + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=_init_method, + output_layer_init_method=_init_method, + ) + for i in range(config.num_hidden_layers) + ] + ) + self._fp8_recipe: object | None = None + self._fp4_recipe: object | None = None + self._layer_precision: dict[int, str | None] | None = None + self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm( + config.hidden_size, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + ) + if config.position_embedding_type == "rotary": + self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) + + def initialize_quantization( + self, + fp8_layers: list[int] | None, + fp4_layers: list[int] | None, + fp8_recipe: object | None = None, + fp4_recipe: object | None = None, + ) -> None: + """Build the per-layer quantization precision map. 
+ + Must be called after model creation and sharding (FSDP/DDP/mFSDP) but before training. + Each layer is tagged as ``"fp8"``, ``"fp4"``, or ``None`` (BF16 fallback). The recipe + objects are stored once on the encoder rather than duplicated per-layer, ensuring the + map is trivially pickleable. + + Args: + fp8_layers: 0-indexed layer numbers to run in FP8, or None. + fp4_layers: 0-indexed layer numbers to run in FP4, or None. + fp8_recipe: The FP8 recipe instance (e.g., MXFP8BlockScaling), or None. + fp4_recipe: The FP4 recipe instance (e.g., NVFP4BlockScaling), or None. + """ + self._fp8_recipe = fp8_recipe + self._fp4_recipe = fp4_recipe + fp8_layers_set = set(fp8_layers) if fp8_layers else set() + fp4_layers_set = set(fp4_layers) if fp4_layers else set() + self._layer_precision = {} + for layer_number in range(len(self.layers)): + if layer_number in fp8_layers_set: + self._layer_precision[layer_number] = "fp8" + elif layer_number in fp4_layers_set: + self._layer_precision[layer_number] = "fp4" + else: + self._layer_precision[layer_number] = None + + def get_layer_autocast(self, layer_number: int): + """Return the appropriate TE autocast context manager for a given layer. + + The context interacts with the outer FP8 autocast in the training script: + - FP8 layer: nullcontext() -- lets the outer FP8 autocast take effect. + - FP4 layer: te.pytorch.autocast(enabled=True, recipe=fp4_recipe) -- overrides to FP4. + - BF16 layer: te.pytorch.autocast(enabled=False) -- disables quantized compute. + + Args: + layer_number: The 0-indexed layer number. + + Returns: + A context manager for the layer's quantization mode. 
+ """ + precision = self._layer_precision.get(layer_number) if self._layer_precision is not None else None + if precision == "fp8": + return nullcontext() + elif precision == "fp4": + return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp4_recipe) + else: + return transformer_engine.pytorch.autocast(enabled=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ): + """Forward pass of the NVEsmEncoder. + + Args: + hidden_states (torch.Tensor): The hidden states. + attention_mask (torch.Tensor): The attention mask. + **kwargs: Additional arguments, see TransformersKwargs for more details. + """ + all_hidden_states: tuple[torch.Tensor, ...] = () + + if self.config.attn_input_format == "thd" and hidden_states.dim() == 3 and hidden_states.size(0) == 1: + # For THD, the embedding output is a 3-dimensional tensor with shape [1, total_tokens, hidden_size], but TE + # expects a 2-dimensional tensor with shape [total_tokens, hidden_size]. + hidden_states = hidden_states.squeeze(0) + + # Ensure that rotary embeddings are computed with at a higher precision outside the torch autocast context. + with torch.autocast(device_type="cuda", enabled=False): + te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings) + te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True) + + # Per-layer quantization context (FP8, FP4, or BF16) is determined by get_layer_autocast(). 
+ for layer_number, layer_module in enumerate(self.layers): + if kwargs.get("output_hidden_states", False): + all_hidden_states = (*all_hidden_states, hidden_states) + + nvtx.range_push(f"encoder_layer_{layer_number}") + with self.get_layer_autocast(layer_number): + hidden_states = layer_module( + hidden_states, + attention_mask, + rotary_pos_emb=te_rope_emb, + cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), + cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), + cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), + cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), + max_seqlen_q=kwargs.get("max_length_q", None), + max_seqlen_kv=kwargs.get("max_length_k", None), + pad_between_seqs=kwargs.get("pad_between_seqs", None), + ) + nvtx.range_pop() # encoder_layer_N + + nvtx.range_push("emb_layer_norm_after") + hidden_states = self.emb_layer_norm_after(hidden_states) + nvtx.range_pop() # emb_layer_norm_after + + if kwargs.get("output_hidden_states", False): + all_hidden_states = (*all_hidden_states, hidden_states) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states if all_hidden_states else None, + ) + + +class NVEsmPreTrainedModel(EsmPreTrainedModel): + """An abstract class to handle weights initialization and pretrained model loading.""" + + config_class = NVEsmConfig + base_model_prefix = "esm" + supports_gradient_checkpointing = False + accepts_loss_kwargs = False + _no_split_modules = ( + "TransformerLayer", + "EsmEmbeddings", + ) + + def init_empty_weights(self): + """Handles moving the model from the meta device to the cuda device and initializing the weights.""" + # For TE layers, calling `reset_parameters` is sufficient to move them to the cuda device and apply the weight + # initialization we passed them during module creation. 
+ for module in self.modules(): + if hasattr(module, "reset_parameters"): + module.reset_parameters() + + # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use + # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard + # deviation. + self.esm.embeddings.word_embeddings.to_empty(device="cuda") + self.esm.embeddings.apply(self._init_weights) + + # Meta-device init seems to break weight tying, so we re-tie the weights here. + self.tie_weights() + + def _init_weights(self, module): + """Initialize module weights. + + We only use this method for standard pytorch modules, TE modules handle their own weight initialization through + `init_method` parameters and the `reset_parameters` method. + """ + if module.__module__.startswith("transformer_engine.pytorch"): + # Notably, we need to avoid calling the parent method for TE modules, since the default _init_weights will + # assume any class with `LayerNorm` in the name should have weights initialized to 1.0; breaking + # `LayerNormLinear` and `LayerNormMLP` modules that use `weight` for the linear layer and + # `layer_norm_weight` for the layer norm. Instead, we call `reset_parameters` if the module has it and the + # weights are not in fp8. We still need to figure out why this raises an error if we're using + # `quantized_model_init`. + if hasattr(module, "reset_parameters") and not getattr(module, "primary_weights_in_fp8", False): + module.reset_parameters() + return + + super()._init_weights(module) + + def state_dict(self, *args, **kwargs): + """Override state_dict to filter out TransformerEngine's _extra_state keys. + + TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading. + These are filtered out to ensure checkpoints can be loaded with from_pretrained(). 
+ """ + state_dict = super().state_dict(*args, **kwargs) + # Filter out _extra_state keys which are TransformerEngine-specific and not loadable + return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")} + + +class NVEsmModel(NVEsmPreTrainedModel): + """The ESM Encoder-only protein language model. + + This model uses NVDIA's TransformerEngine to optimize attention layer training and inference. + """ + + def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True): + """Initialize a NVEsmModel. + + Args: + config (NVEsmConfig): The configuration of the model. + add_pooling_layer (bool): Whether to add a pooling layer. + """ + super().__init__(config) + self.config = config + + # Ensure pad_token_id is set properly, defaulting to 0 if not specified + if not hasattr(config, "pad_token_id") or config.pad_token_id is None: + config.pad_token_id = 0 + self.embeddings = NVEsmEmbeddings(config) + self.encoder = NVEsmEncoder(config) + self.pooler = EsmPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + """Get the input embeddings of the model.""" + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value: torch.Tensor): + """Set the input embeddings of the model. + + Args: + value (torch.Tensor): The input embeddings. + """ + self.embeddings.word_embeddings = value + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + """Forward pass of the NVEsmModel. + + Args: + input_ids (torch.Tensor): The input ids. + attention_mask (torch.Tensor): The attention mask. + position_ids (torch.Tensor): The position ids. + inputs_embeds (torch.Tensor): The input embeddings. 
+ **kwargs: Additional arguments, see TransformersKwargs for more details. + + Returns: + BaseModelOutputWithPooling: The output of the model. + """ + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # TE expects a boolean attention mask, where 1s are masked and 0s are not masked + extended_attention_mask = extended_attention_mask < -1 + + embedding_output = self.embeddings( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + **kwargs, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +class NVEsmForMaskedLM(NVEsmPreTrainedModel): + """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling.""" + + _tied_weights_keys: ClassVar[dict[str, str]] = 
{"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"} + _do_not_quantize = ("lm_head.dense", "lm_head.decoder") # Flag for testing that these layers are not quantized. + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmForMaskedLM. + + Args: + config (NVEsmConfig): The configuration of the model. + """ + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.esm = NVEsmModel(config, add_pooling_layer=False) + self.lm_head = NVEsmLMHead(config) + + self.post_init() + + def get_output_embeddings(self): + """Get the output embeddings of the model.""" + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + """Set the output embeddings of the model.""" + self.lm_head.decoder = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MaskedLMOutput: + """Forward pass of the NVEsmForMaskedLM. + + Args: + input_ids (torch.LongTensor): The input ids. + attention_mask (torch.Tensor): The attention mask. + position_ids (torch.LongTensor): The position ids. + inputs_embeds (torch.FloatTensor): The input embeddings. + labels (torch.LongTensor): The labels. + **kwargs: Additional arguments, see TransformersKwargs for more details. + + Returns: + MaskedLMOutput: The output of the model. 
+ """ + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + sequence_output = outputs[0] + with transformer_engine.pytorch.autocast(enabled=False): + prediction_scores = self.lm_head(sequence_output) + + # Truncate logits back to original vocab_size if padding was used + if self.config.padded_vocab_size != self.config.vocab_size: + prediction_scores = prediction_scores[..., : self.config.vocab_size] + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + labels.to(prediction_scores.device).view(-1), + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + ) + + +class NVEsmLMHead(nn.Module): + """ESM Head for masked language modeling using TransformerEngine.""" + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmLMHead. + + Args: + config (NVEsmConfig): The configuration of the model. + """ + super().__init__() + with transformer_engine.pytorch.quantized_model_init(enabled=False): + self.dense = transformer_engine.pytorch.Linear( + config.hidden_size, + config.hidden_size, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + self.decoder = transformer_engine.pytorch.LayerNormLinear( + config.hidden_size, + config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size, + bias=True, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + def forward(self, features, **kwargs): + """Forward pass of the NVEsmLMHead. 
+ + Args: + features (torch.Tensor): The features. + **kwargs: Additional arguments. + """ + # Keep the last layers of the network in higher precision to avoid numerical instability. + # Please see recipes/fp8_analysis/README.md for more details. + with transformer_engine.pytorch.autocast(enabled=False): + x = self.dense(features) + x = torch.nn.functional.gelu(x) + x = self.decoder(x) + return x + + +class NVEsmEmbeddings(nn.Module): + """Modified version of EsmEmbeddings to support THD inputs.""" + + def __init__(self, config): + """Initialize a NVEsmEmbeddings.""" + super().__init__() + self.word_embeddings = nn.Embedding( + config.padded_vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id, + dtype=config.dtype, + ) + + self.layer_norm = ( + transformer_engine.pytorch.LayerNorm( + config.hidden_size, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + ) + if config.emb_layer_norm_before + else None + ) + + if config.position_embedding_type != "rotary": + raise ValueError( + "The TE-accelerated ESM-2 model only supports rotary position embeddings, received " + f"{config.position_embedding_type}" + ) + + self.padding_idx = config.pad_token_id + self.token_dropout = config.token_dropout + self.mask_token_id = config.mask_token_id + + def _apply_token_dropout_bshd(self, embeddings, input_ids, attention_mask): + """Apply token dropout scaling for BSHD-format inputs. + + Compensates for masked tokens by scaling unmasked embeddings based on the + observed mask ratio per sequence. + + Args: + embeddings: Token embeddings with masked positions already zeroed out. + input_ids: Original input token IDs. + attention_mask: Attention mask indicating valid tokens. + + Returns: + Scaled embeddings tensor. 
+ """ + mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs + src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1] + n_masked_per_seq = (input_ids == self.mask_token_id).sum(-1).float() + mask_ratio_observed = n_masked_per_seq / src_lengths + scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) + return (embeddings * scale_factor[:, None, None]).to(embeddings.dtype) + + def _apply_token_dropout_thd(self, embeddings, input_ids, kwargs): + """Apply token dropout scaling for THD-format (packed sequence) inputs. + + Uses cumulative sequence lengths to compute per-sequence mask ratios and + scales embeddings accordingly using repeat_interleave. + + Args: + embeddings: Token embeddings with masked positions already zeroed out. + input_ids: Original input token IDs. + kwargs: Additional keyword arguments containing cu_seq_lens_q and optionally cu_seq_lens_q_padded. + + Returns: + Scaled embeddings tensor. + """ + mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs + src_lengths = torch.diff(kwargs["cu_seq_lens_q"]) + if "cu_seq_lens_q_padded" in kwargs: + src_lengths_padded = torch.diff(kwargs["cu_seq_lens_q_padded"]) + else: + src_lengths_padded = src_lengths + # We need to find the number of masked tokens in each sequence in the padded batch. 
+ is_masked = (input_ids == self.mask_token_id).squeeze(0) + n_masked_per_seq = torch.nested.nested_tensor_from_jagged(is_masked, offsets=kwargs["cu_seq_lens_q"]).sum(1) + mask_ratio_observed = n_masked_per_seq.float() / src_lengths + scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) + reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths_padded, dim=0) + return (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype) + + def forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + **kwargs: Unpack[TransformersKwargs], + ): + """Forward pass of the NVEsmEmbeddings.""" + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an + # embedding_scale factor here. + embeddings = inputs_embeds + + if ( + kwargs.get("cu_seq_lens_q") is not None + and kwargs.get("cu_seq_lens_k") is not None + and kwargs.get("max_length_q") is not None + and kwargs.get("max_length_k") is not None + ): + using_thd = True + attention_mask = None + else: + using_thd = False + + # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout + # flag is False then it is handled in the same way as BERT/RoBERTa. If it is set to True, however, + # masked tokens are treated as if they were selected for input dropout and zeroed out. + # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by + # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample). + # This is analogous to the way that dropout layers scale down outputs during evaluation when not + # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training). 
+ if self.token_dropout and input_ids is not None: + embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0) + if using_thd: + embeddings = self._apply_token_dropout_thd(embeddings, input_ids, kwargs) + else: + embeddings = self._apply_token_dropout_bshd(embeddings, input_ids, attention_mask) + + if self.layer_norm is not None: + embeddings = self.layer_norm(embeddings) + + if attention_mask is not None: + embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype) + + return embeddings + + +class NVEsmForTokenClassification(NVEsmPreTrainedModel): + """Adds a token classification head to the model. + + Adapted from EsmForTokenClassification in Hugging Face Transformers `modeling_esm.py`. + """ + + def __init__(self, config): + """Initialize NVEsmForTokenClassification.""" + super().__init__(config) + self.num_labels = config.num_labels + + self.esm = NVEsmModel(config, add_pooling_layer=False) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = transformer_engine.pytorch.Linear( + config.hidden_size, + config.num_labels, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> TokenClassifierOutput: + """Forward pass for the token classification head. + + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
+ """ + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + + labels = labels.to(logits.device) + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py index f7a71b3e6e..2e67b3aaa5 100644 --- a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py @@ -77,7 +77,7 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig): self._progress_bar = tqdm(total=args.num_train_steps, desc="Training") # Whether to step debug_api.step() after each step - self.fp8_stats_enabled = args.fp8_stats_config.enabled + self.quant_stats_config = args.quant_stats_config.enabled def log_step( self, @@ -101,7 +101,7 @@ def log_step( if isinstance(grad_norm, DTensor): grad_norm = grad_norm.to_local() - if self.fp8_stats_enabled: + if self.quant_stats_config: debug_api.step() if step % self.logging_frequency == 0 and step > 0: @@ -152,11 +152,11 @@ def log_step( def finish(self): """Finish the logger and close the progress bar.""" + if self.quant_stats_config: + debug_api.end_debug() + if not self._dist_config.is_main_process(): return wandb.finish() self._progress_bar.close() - - if self.fp8_stats_enabled: - debug_api.end_debug() diff --git a/bionemo-recipes/recipes/esm2_native_te/quantization.py b/bionemo-recipes/recipes/esm2_native_te/quantization.py new file mode 100644 index 0000000000..0a8fe93a2f --- /dev/null +++ 
b/bionemo-recipes/recipes/esm2_native_te/quantization.py @@ -0,0 +1,242 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for layer-wise quantization configuration (FP8/FP4).""" + +import logging +import tempfile +from pathlib import Path + +import yaml + + +logger = logging.getLogger(__name__) + + +def generate_layer_regex(layer_numbers: list[int] | None) -> str: + """Generate a regex pattern to match specific layer numbers (1-indexed). + + The debug API (nvdlfw_inspect) uses 1-indexed layer names after ``infer_and_assign_layer_names``. + + Args: + layer_numbers: List of layer numbers (1-indexed, as shown in debug logs). + If empty or None, returns a pattern that matches nothing. + + Returns: + Regex pattern string for matching those layers' linear sublayers. + """ + if not layer_numbers: + return r"model\.esm\.encoder\.layers\.DISABLED_NO_LAYERS_SPECIFIED" + layer_pattern = "|".join(str(n) for n in sorted(layer_numbers)) + return rf"model\.esm\.encoder\.layers\.({layer_pattern})\..*(layernorm_qkv|proj|fc1|fc2)" + + +def update_quant_stats_config( + config_file: str, + fp4_layers: list[int] | None, + fp8_layers: list[int] | None, +) -> str: + """Update the quant stats YAML config with layer-specific regex patterns. + + Args: + config_file: Path to the original YAML config file. 
+ fp4_layers: List of layer numbers for FP4 (1-indexed). + fp8_layers: List of layer numbers for FP8 (1-indexed). + + Returns: + Path to the updated config file (a temp file). + """ + with open(config_file, "r") as f: + config = yaml.safe_load(f) + + if "example_fp4_tensor_stat_collection" in config: + # TODO: Remove this block and replace with FP8-style regex update once a TransformerEngine + # release with LogNvfp4TensorStats support is available. At that point, this becomes: + # fp4_regex = generate_layer_regex(fp4_layers) + # config["example_fp4_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp4_regex + config["example_fp4_tensor_stat_collection"]["enabled"] = False + if fp4_layers: + logger.warning( + "NVFP4 quant stats logging is not yet supported (requires a future TransformerEngine release). " + f"Disabling FP4 stats collection for layers {fp4_layers}. FP8 stats will still be collected." + ) + else: + logger.info("FP4 stats section disabled (no FP4 layers and feature not yet supported)") + + if "example_fp8_tensor_stat_collection" in config: + fp8_regex = generate_layer_regex(fp8_layers) + config["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp8_regex + if fp8_layers: + logger.info(f"Updated FP8 layer regex to match layers: {fp8_layers}") + else: + logger.info("FP8 layers empty - regex set to match nothing") + + temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) + yaml.dump(config, temp_file, default_flow_style=False) + temp_file.close() + + config_str = yaml.dump(config, default_flow_style=False) + logger.info(f"Created updated quant stats config at: {temp_file.name}") + logger.info(f"Updated quant stats config contents:\n{config_str}") + + return temp_file.name + + +def initialize_quant_stats_logging( + quant_stats_file: str, + quant_log_dir: str, + rank: int, + quant_layers: "QuantizationLayers", +) -> None: + """Set up quantization stats logging via nvdlfw_inspect. 
+ + Updates the quant stats YAML config with resolved layer regex patterns, creates + the per-rank log directory, and initializes the debug API. + + Args: + quant_stats_file: Path to the base quant stats YAML config file. + quant_log_dir: Base directory for quant stats logs (a rank subdirectory will be created). + rank: The global rank of this process. + quant_layers: Resolved quantization layer assignments. + """ + import nvdlfw_inspect.api as debug_api + import transformer_engine + + updated_config = update_quant_stats_config( + config_file=quant_stats_file, + fp4_layers=quant_layers.fp4_layers_1indexed, + fp8_layers=quant_layers.fp8_layers_1indexed, + ) + + rank_log_dir = Path(quant_log_dir) / f"rank_{rank}" + rank_log_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Logging quant stats to {rank_log_dir}") + + te_features_dir = str(Path(transformer_engine.__file__).parent / "debug" / "features") + debug_api.initialize( + config_file=updated_config, + feature_dirs=[te_features_dir], + log_dir=rank_log_dir, + default_logging_enabled=True, + ) + + +class QuantizationLayers: + """Resolved layer-wise quantization assignments. + + Attributes: + fp8_layers_0indexed: 0-indexed FP8 layer numbers (for model internals), or None. + fp4_layers_0indexed: 0-indexed FP4 layer numbers (for model internals), or None. + fp8_layers_1indexed: 1-indexed FP8 layer numbers (for user-facing logs / quant stats), or None. + fp4_layers_1indexed: 1-indexed FP4 layer numbers (for user-facing logs / quant stats), or None. 
+ """ + + def __init__( + self, + fp8_layers_0indexed: list[int] | None, + fp4_layers_0indexed: list[int] | None, + fp8_layers_1indexed: list[int] | None, + fp4_layers_1indexed: list[int] | None, + ): + """Initialize QuantizationLayers with the resolved layer assignments.""" + self.fp8_layers_0indexed = fp8_layers_0indexed + self.fp4_layers_0indexed = fp4_layers_0indexed + self.fp8_layers_1indexed = fp8_layers_1indexed + self.fp4_layers_1indexed = fp4_layers_1indexed + + +def resolve_quantization_layers( + num_layers: int, + fp8_enabled: bool, + fp4_enabled: bool, + fp8_layers: list[int] | None, + fp4_layers: list[int] | None, +) -> QuantizationLayers: + """Resolve layer-wise quantization assignments from user config. + + Takes 1-indexed layer lists (as specified by the user) and returns both 0-indexed lists + (for model internals) and 1-indexed lists (for quant stats / debug logging). When a quantization + format is enabled but no layer list is provided, all layers default to that format. When one format + has explicit layers and the other is enabled without a layer list, the unspecified format defaults + to the remaining (unclaimed) layers. + + Args: + num_layers: Total number of transformer layers in the model. + fp8_enabled: Whether FP8 quantization is enabled. + fp4_enabled: Whether FP4 quantization is enabled. + fp8_layers: 1-indexed list of layers for FP8, or None if not specified. + fp4_layers: 1-indexed list of layers for FP4, or None if not specified. + + Returns: + QuantizationLayers with both 0-indexed and 1-indexed layer lists. + + Raises: + ValueError: If both formats are enabled with no layer lists, or if layer lists overlap. + """ + all_layers = set(range(1, num_layers + 1)) + + if fp8_enabled and fp4_enabled and fp8_layers is None and fp4_layers is None: + raise ValueError( + "Both fp8_config and fp4_config are enabled but neither fp8_layers nor fp4_layers is specified. 
" + "When both are enabled, you must explicitly provide layer lists to indicate which layers use which format." + ) + + # When one format has explicit layers and the other defaults, fill in the remaining layers. + if fp8_enabled and fp8_layers is None: + claimed_by_fp4 = set(fp4_layers) if fp4_layers is not None else set() + fp8_layers = sorted(all_layers - claimed_by_fp4) + if claimed_by_fp4: + logger.warning( + f"fp8_config.enabled=True with no fp8_layers specified, but fp4_layers={sorted(claimed_by_fp4)} " + f"are already claimed by FP4. Defaulting FP8 to the remaining layers: {fp8_layers}" + ) + else: + logger.info( + f"fp8_config.enabled=True with no fp8_layers specified, defaulting all {num_layers} layers to FP8" + ) + + if fp4_enabled and fp4_layers is None: + claimed_by_fp8 = set(fp8_layers) if fp8_layers is not None else set() + fp4_layers = sorted(all_layers - claimed_by_fp8) + if claimed_by_fp8: + logger.warning( + f"fp4_config.enabled=True with no fp4_layers specified, but fp8_layers={sorted(claimed_by_fp8)} " + f"are already claimed by FP8. Defaulting FP4 to the remaining layers: {fp4_layers}" + ) + else: + logger.info( + f"fp4_config.enabled=True with no fp4_layers specified, defaulting all {num_layers} layers to FP4" + ) + + # Disable layer lists when corresponding config is not enabled. + if not fp8_enabled: + fp8_layers = None + if not fp4_enabled: + fp4_layers = None + + # Validate no overlap between FP8 and FP4 layer assignments. + if fp8_layers is not None and fp4_layers is not None: + overlap = set(fp8_layers) & set(fp4_layers) + if overlap: + raise ValueError( + f"fp8_layers and fp4_layers cannot have overlapping layer numbers. 
Found overlap: {sorted(overlap)}" + ) + + return QuantizationLayers( + fp8_layers_0indexed=[layer - 1 for layer in fp8_layers] if fp8_layers is not None else None, + fp4_layers_0indexed=[layer - 1 for layer in fp4_layers] if fp4_layers is not None else None, + fp8_layers_1indexed=fp8_layers, + fp4_layers_1indexed=fp4_layers, + ) diff --git a/bionemo-recipes/recipes/esm2_native_te/requirements.txt b/bionemo-recipes/recipes/esm2_native_te/requirements.txt index b18607fd7a..96753ffb2a 100644 --- a/bionemo-recipes/recipes/esm2_native_te/requirements.txt +++ b/bionemo-recipes/recipes/esm2_native_te/requirements.txt @@ -8,6 +8,6 @@ torchdata torchmetrics tqdm transformer_engine[pytorch] -transformers +transformers>=5.0.0 wandb nvdlfw_inspect @ git+https://github.com/NVIDIA/nvidia-dlfw-inspect diff --git a/bionemo-recipes/recipes/esm2_native_te/tests/test_distributed_checkpointing.py b/bionemo-recipes/recipes/esm2_native_te/tests/test_distributed_checkpointing.py index 6d931474cd..a9919055ff 100644 --- a/bionemo-recipes/recipes/esm2_native_te/tests/test_distributed_checkpointing.py +++ b/bionemo-recipes/recipes/esm2_native_te/tests/test_distributed_checkpointing.py @@ -866,7 +866,6 @@ def test_final_model_save_ddp(recipe_path, tmp_path): Validates that DDP saves the final model correctly with: - model.safetensors containing weights - config.json with model configuration - - esm_nv.py for custom model code """ temp_dir = str(tmp_path / "test_final_ddp") @@ -888,7 +887,7 @@ def test_final_model_save_ddp(recipe_path, tmp_path): assert os.path.exists(final_model_dir), "Final model directory not created" # Check required files - required_files = ["model.safetensors", "config.json", "esm_nv.py"] + required_files = ["model.safetensors", "config.json"] for file in required_files: file_path = os.path.join(final_model_dir, file) assert os.path.exists(file_path), f"Missing required file: {file}" @@ -901,7 +900,6 @@ def test_final_model_save_mfsdp(recipe_path, tmp_path): Validates 
that mFSDP gathers parameters and saves the final model with: - model.safetensors containing gathered weights - config.json with model configuration - - esm_nv.py for custom model code """ temp_dir = str(tmp_path / "test_final_mfsdp") @@ -924,7 +922,7 @@ def test_final_model_save_mfsdp(recipe_path, tmp_path): assert os.path.exists(final_model_dir), "Final model directory not created" # Check required files - required_files = ["model.safetensors", "config.json", "esm_nv.py"] + required_files = ["model.safetensors", "config.json"] for file in required_files: file_path = os.path.join(final_model_dir, file) assert os.path.exists(file_path), f"Missing required file: {file}" diff --git a/bionemo-recipes/recipes/esm2_native_te/tests/test_quantization.py b/bionemo-recipes/recipes/esm2_native_te/tests/test_quantization.py new file mode 100644 index 0000000000..6d1b09e6ff --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/tests/test_quantization.py @@ -0,0 +1,330 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import re +import sys +from pathlib import Path + +import pytest +import yaml + + +sys.path.append(Path(__file__).parent.parent.as_posix()) + +from quantization import generate_layer_regex, resolve_quantization_layers, update_quant_stats_config + + +class TestResolveQuantizationLayers: + """Tests for resolve_quantization_layers().""" + + def test_fp8_enabled_no_layers_defaults_all(self): + """When fp8 is enabled with no explicit layers, all layers should default to FP8.""" + result = resolve_quantization_layers( + num_layers=6, fp8_enabled=True, fp4_enabled=False, fp8_layers=None, fp4_layers=None + ) + assert result.fp8_layers_0indexed == [0, 1, 2, 3, 4, 5] + assert result.fp8_layers_1indexed == [1, 2, 3, 4, 5, 6] + assert result.fp4_layers_0indexed is None + assert result.fp4_layers_1indexed is None + + def test_fp4_enabled_no_layers_defaults_all(self): + """When fp4 is enabled with no explicit layers, all layers should default to FP4.""" + result = resolve_quantization_layers( + num_layers=6, fp8_enabled=False, fp4_enabled=True, fp8_layers=None, fp4_layers=None + ) + assert result.fp8_layers_0indexed is None + assert result.fp4_layers_0indexed == [0, 1, 2, 3, 4, 5] + assert result.fp4_layers_1indexed == [1, 2, 3, 4, 5, 6] + + def test_fp8_explicit_layers(self): + """Explicit 1-indexed fp8_layers should be converted to 0-indexed.""" + result = resolve_quantization_layers( + num_layers=6, fp8_enabled=True, fp4_enabled=False, fp8_layers=[1, 3, 5], fp4_layers=None + ) + assert result.fp8_layers_0indexed == [0, 2, 4] + assert result.fp8_layers_1indexed == [1, 3, 5] + assert result.fp4_layers_0indexed is None + + def test_fp4_explicit_layers(self): + """Explicit 1-indexed fp4_layers should be converted to 0-indexed.""" + result = resolve_quantization_layers( + num_layers=6, fp8_enabled=False, fp4_enabled=True, fp8_layers=None, fp4_layers=[2, 4, 6] + ) + assert result.fp8_layers_0indexed is None + assert result.fp4_layers_0indexed == [1, 3, 5] + assert 
result.fp4_layers_1indexed == [2, 4, 6] + + def test_mixed_fp8_fp4_explicit(self): + """Both enabled with explicit non-overlapping layers should work correctly.""" + result = resolve_quantization_layers( + num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 3, 4], fp4_layers=[2, 5] + ) + assert result.fp8_layers_0indexed == [0, 2, 3] + assert result.fp8_layers_1indexed == [1, 3, 4] + assert result.fp4_layers_0indexed == [1, 4] + assert result.fp4_layers_1indexed == [2, 5] + + def test_both_enabled_no_layers_raises(self): + """Both enabled with no layer lists should raise ValueError.""" + with pytest.raises(ValueError, match="Both fp8_config and fp4_config are enabled"): + resolve_quantization_layers( + num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=None, fp4_layers=None + ) + + def test_overlapping_layers_raises(self): + """Overlapping layer assignments should raise ValueError.""" + with pytest.raises(ValueError, match="fp8_layers and fp4_layers cannot have overlapping"): + resolve_quantization_layers( + num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 2, 3], fp4_layers=[3, 4, 5] + ) + + def test_disabled_ignores_layers(self): + """When a format is disabled, its layers should be None even if provided.""" + result = resolve_quantization_layers( + num_layers=6, fp8_enabled=False, fp4_enabled=False, fp8_layers=[1, 2, 3], fp4_layers=[4, 5, 6] + ) + assert result.fp8_layers_0indexed is None + assert result.fp8_layers_1indexed is None + assert result.fp4_layers_0indexed is None + assert result.fp4_layers_1indexed is None + + def test_both_disabled(self): + """Both disabled with no layers should return all None.""" + result = resolve_quantization_layers( + num_layers=6, fp8_enabled=False, fp4_enabled=False, fp8_layers=None, fp4_layers=None + ) + assert result.fp8_layers_0indexed is None + assert result.fp4_layers_0indexed is None + + def test_large_model_defaults_all(self): + """Auto-population should work correctly for larger 
models (e.g. 36 layers).""" + result = resolve_quantization_layers( + num_layers=36, fp8_enabled=True, fp4_enabled=False, fp8_layers=None, fp4_layers=None + ) + assert result.fp8_layers_0indexed == list(range(36)) + assert result.fp8_layers_1indexed == list(range(1, 37)) + + def test_fp8_enabled_empty_list(self): + """An explicit empty list should remain empty (not default to all).""" + result = resolve_quantization_layers( + num_layers=6, fp8_enabled=True, fp4_enabled=False, fp8_layers=[], fp4_layers=None + ) + assert result.fp8_layers_0indexed == [] + assert result.fp8_layers_1indexed == [] + + def test_both_enabled_fp8_specified_fp4_defaults_to_remaining(self): + """When both enabled, FP8 has explicit layers, FP4 should default to the remaining layers.""" + result = resolve_quantization_layers( + num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 2, 3], fp4_layers=None + ) + assert result.fp8_layers_0indexed == [0, 1, 2] + assert result.fp8_layers_1indexed == [1, 2, 3] + assert result.fp4_layers_0indexed == [3, 4, 5] + assert result.fp4_layers_1indexed == [4, 5, 6] + + def test_both_enabled_fp4_specified_fp8_defaults_to_remaining(self): + """When both enabled, FP4 has explicit layers, FP8 should default to the remaining layers.""" + result = resolve_quantization_layers( + num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=None, fp4_layers=[4, 5, 6] + ) + assert result.fp8_layers_0indexed == [0, 1, 2] + assert result.fp8_layers_1indexed == [1, 2, 3] + assert result.fp4_layers_0indexed == [3, 4, 5] + assert result.fp4_layers_1indexed == [4, 5, 6] + + +class TestGenerateLayerRegex: + """Tests for generate_layer_regex().""" + + def test_single_layer(self): + """Single layer should produce a simple regex.""" + regex = generate_layer_regex([3]) + assert re.search(regex, "model.esm.encoder.layers.3.self_attention.layernorm_qkv") + assert not re.search(regex, "model.esm.encoder.layers.2.self_attention.layernorm_qkv") + + def 
test_multiple_layers(self): + """Multiple layers should match any of them.""" + regex = generate_layer_regex([1, 2, 3]) + assert re.search(regex, "model.esm.encoder.layers.1.self_attention.layernorm_qkv") + assert re.search(regex, "model.esm.encoder.layers.2.layernorm_mlp.fc1") + assert re.search(regex, "model.esm.encoder.layers.3.layernorm_mlp.fc2") + assert not re.search(regex, "model.esm.encoder.layers.4.self_attention.proj") + + def test_matches_correct_sublayers(self): + """Regex should only match layernorm_qkv, proj, fc1, fc2.""" + regex = generate_layer_regex([1]) + assert re.search(regex, "model.esm.encoder.layers.1.self_attention.layernorm_qkv_something") + assert re.search(regex, "model.esm.encoder.layers.1.self_attention.proj_something") + assert re.search(regex, "model.esm.encoder.layers.1.layernorm_mlp.fc1_something") + assert re.search(regex, "model.esm.encoder.layers.1.layernorm_mlp.fc2_something") + # Should not match unrelated sublayer names + assert not re.search(regex, "model.esm.encoder.layers.1.self_attention.some_other_thing") + + def test_none_returns_disabled_pattern(self): + """None should return a pattern that matches nothing.""" + regex = generate_layer_regex(None) + assert "DISABLED" in regex + assert not re.search(regex, "model.esm.encoder.layers.1.self_attention.layernorm_qkv") + + def test_empty_list_returns_disabled_pattern(self): + """Empty list should return a pattern that matches nothing.""" + regex = generate_layer_regex([]) + assert "DISABLED" in regex + + def test_1indexed_layer_names(self): + """Regex should use 1-indexed layer numbers (matching debug API naming).""" + regex = generate_layer_regex([1]) + # Should match layers.1 (1-indexed first layer) + assert re.search(regex, "model.esm.encoder.layers.1.self_attention.layernorm_qkv") + # Should NOT match layers.0 (0-indexed first layer) + assert not re.search(regex, "model.esm.encoder.layers.0.self_attention.layernorm_qkv") + + +class TestUpdateQuantStatsConfig: + """Tests 
for update_quant_stats_config().""" + + @pytest.fixture + def fp8_only_config(self, tmp_path): + """Create an FP8-only stats config file.""" + config = { + "example_fp8_tensor_stat_collection": { + "enabled": True, + "layers": { + "layer_name_regex_pattern": "PLACEHOLDER", + }, + "transformer_engine": { + "LogFp8TensorStats": { + "enabled": True, + "tensors_struct": [{"tensor": "activation", "stats": ["underflows%"], "freq": 10}], + } + }, + } + } + config_path = tmp_path / "fp8_stats.yaml" + with open(config_path, "w") as f: + yaml.dump(config, f) + return str(config_path) + + @pytest.fixture + def fp4_fp8_config(self, tmp_path): + """Create a combined FP4+FP8 stats config file.""" + config = { + "example_fp4_tensor_stat_collection": { + "enabled": True, + "layers": { + "layer_name_regex_pattern": "PLACEHOLDER", + }, + "transformer_engine": { + "LogNvfp4TensorStats": {"enabled": True}, + }, + }, + "example_fp8_tensor_stat_collection": { + "enabled": True, + "layers": { + "layer_name_regex_pattern": "PLACEHOLDER", + }, + "transformer_engine": { + "LogFp8TensorStats": {"enabled": True}, + }, + }, + } + config_path = tmp_path / "fp4_fp8_stats.yaml" + with open(config_path, "w") as f: + yaml.dump(config, f) + return str(config_path) + + def test_fp8_layers_updates_regex(self, fp8_only_config): + """FP8 layer list should update the regex in the output config.""" + output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=[1, 2, 3]) + with open(output_path) as f: + result = yaml.safe_load(f) + regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] + assert re.search(regex, "model.esm.encoder.layers.1.self_attention.layernorm_qkv") + assert re.search(regex, "model.esm.encoder.layers.3.layernorm_mlp.fc2") + assert not re.search(regex, "model.esm.encoder.layers.4.self_attention.proj") + + def test_none_layers_disables_matching(self, fp8_only_config): + """None layers should set regex to match 
nothing.""" + output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=None) + with open(output_path) as f: + result = yaml.safe_load(f) + regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] + assert "DISABLED" in regex + + def test_fp4_section_disabled_fp8_still_updated(self, fp4_fp8_config): + """FP4 stats section should be disabled (not yet supported), FP8 should still be updated.""" + output_path = update_quant_stats_config(config_file=fp4_fp8_config, fp4_layers=[1, 2, 3], fp8_layers=[4, 5, 6]) + with open(output_path) as f: + result = yaml.safe_load(f) + + # FP4 section should be disabled + assert result["example_fp4_tensor_stat_collection"]["enabled"] is False + + # FP8 regex should still match layers 4-6 + fp8_regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] + assert re.search(fp8_regex, "model.esm.encoder.layers.5.self_attention.proj") + assert not re.search(fp8_regex, "model.esm.encoder.layers.2.self_attention.proj") + + def test_original_file_not_modified(self, fp8_only_config): + """update_quant_stats_config should write to a temp file, not modify the original.""" + with open(fp8_only_config) as f: + original_content = f.read() + + output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=[1, 2]) + + assert output_path != fp8_only_config + with open(fp8_only_config) as f: + assert f.read() == original_content + + def test_preserves_other_config_fields(self, fp8_only_config): + """Non-layer fields in the config should be preserved.""" + output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=[1]) + with open(output_path) as f: + result = yaml.safe_load(f) + # The transformer_engine section should still be there + assert ( + result["example_fp8_tensor_stat_collection"]["transformer_engine"]["LogFp8TensorStats"]["enabled"] is True + ) + + def 
test_missing_section_is_skipped(self, fp8_only_config): + """If fp4 section doesn't exist in config, it should be silently skipped.""" + # fp8_only_config has no fp4 section — passing fp4_layers should not error + output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=[1, 2], fp8_layers=[3, 4]) + with open(output_path) as f: + result = yaml.safe_load(f) + # Only FP8 section should exist and be updated + assert "example_fp4_tensor_stat_collection" not in result + regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] + assert re.search(regex, "model.esm.encoder.layers.3.self_attention.layernorm_qkv") + + def test_with_real_fp4_config(self): + """Test with the actual fp4_debugging_stats.yaml file.""" + config_path = Path(__file__).parent.parent / "fp4_debugging_stats.yaml" + if not config_path.exists(): + pytest.skip("fp4_debugging_stats.yaml not found") + + output_path = update_quant_stats_config( + config_file=str(config_path), fp4_layers=[1, 2, 3], fp8_layers=[4, 5, 6] + ) + with open(output_path) as f: + result = yaml.safe_load(f) + + # FP4 section should be disabled (not yet supported in current TE release) + assert result["example_fp4_tensor_stat_collection"]["enabled"] is False + + # FP8 section should still be updated and working + fp8_regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] + assert re.search(fp8_regex, "model.esm.encoder.layers.5.self_attention.proj") + assert not re.search(fp8_regex, "model.esm.encoder.layers.2.self_attention.proj") diff --git a/bionemo-recipes/recipes/esm2_native_te/tests/test_train.py b/bionemo-recipes/recipes/esm2_native_te/tests/test_train.py index ec86ed86f6..298bf26858 100644 --- a/bionemo-recipes/recipes/esm2_native_te/tests/test_train.py +++ b/bionemo-recipes/recipes/esm2_native_te/tests/test_train.py @@ -154,8 +154,8 @@ def test_sanity_ddp_fp8_stats_logging(tmp_path, recipe_path): f"+wandb_init_args.dir={tmp_path}", 
f"checkpoint.ckpt_dir={tmp_path}", "fp8_config.enabled=true", - "fp8_stats_config.enabled=true", - f"fp8_stats_config.fp8_log_dir={fp8_log_dir}", + "quant_stats_config.enabled=true", + f"quant_stats_config.quant_log_dir={fp8_log_dir}", "num_train_steps=4", ], ) @@ -211,8 +211,8 @@ def test_sanity_fsdp2_fp8_stats_logging(tmp_path, recipe_path): f"+wandb_init_args.dir={tmp_path}", f"checkpoint.ckpt_dir={tmp_path}", "fp8_config.enabled=true", - "fp8_stats_config.enabled=true", - f"fp8_stats_config.fp8_log_dir={fp8_log_dir}", + "quant_stats_config.enabled=true", + f"quant_stats_config.quant_log_dir={fp8_log_dir}", "num_train_steps=4", ], ) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py index 168d25b57a..de43047d63 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py @@ -21,17 +21,17 @@ import torch import transformer_engine import transformer_engine.pytorch -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from torch.distributed.device_mesh import init_device_mesh from torch.optim import AdamW from transformer_engine.common.recipe import Format -from transformers import AutoConfig, AutoModelForMaskedLM from checkpoint import load_checkpoint_ddp, save_checkpoint_ddp, save_final_model_ddp, should_save_checkpoint from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig -from fp8_debugging import initialize_fp8_debugging +from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM from perf_logger import PerfLogger +from quantization import initialize_quant_stats_logging, resolve_quantization_layers from scheduler import get_linear_schedule_with_warmup @@ -53,21 +53,46 @@ def main(args: DictConfig) -> float | None: torch.distributed.init_process_group(backend="nccl", device_id=device) torch.cuda.set_device(dist_config.local_rank) - # TE Debug feature 
logging - MUST be done BEFORE FSDP wrapping - if args.fp8_stats_config.enabled: - initialize_fp8_debugging(dist_config, **args.fp8_stats_config, fp8_enabled=args.fp8_config.enabled) + if args.use_fp32_master_weights: + raise ValueError("FP32 master weights are not supported with DDP. Use train_fsdp2.py instead.") + + # Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D". + config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.bfloat16) + num_layers = config.num_hidden_layers + + # Resolve layer-wise quantization assignments. + quant_layers = resolve_quantization_layers( + num_layers=num_layers, + fp8_enabled=args.fp8_config.enabled, + fp4_enabled=args.fp4_config.enabled, + fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None, + fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None, + ) + + if args.quant_stats_config.enabled: + initialize_quant_stats_logging( + quant_stats_file=args.quant_stats_config.quant_stats_file, + quant_log_dir=args.quant_stats_config.quant_log_dir, + rank=dist_config.rank, + quant_layers=quant_layers, + ) # Create a device mesh for DDP. While this isn't strictly necessary, it mirrors the device mesh we create for FSDP2 # and MFSDP. device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("ddp",)) - # Create an FP8 recipe -- this is only used if FP8 is enabled in the config. - fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( - fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs - ) + # Create quantization recipes -- these are only used if FP8/FP4 is enabled in the config. 
+ fp8_recipe = None + fp4_recipe = None + if args.fp8_config.enabled: + fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( + fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs + ) + if args.fp4_config.enabled: + fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)( + fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs + ) - # Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D". - config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True, dtype=torch.bfloat16) # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument. if args.use_sequence_packing: config.attn_input_format = "thd" @@ -78,10 +103,18 @@ def main(args: DictConfig) -> float | None: with transformer_engine.pytorch.quantized_model_init( recipe=fp8_recipe, **args.fp8_config.quantized_model_init_kwargs ): - model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True) + model = NVEsmForMaskedLM(config) logger.info("Initialized Model:\n%s", model) + # Initialize per-layer quantization on the encoder. + model.esm.encoder.initialize_quantization( + fp8_layers=quant_layers.fp8_layers_0indexed, + fp4_layers=quant_layers.fp4_layers_0indexed, + fp8_recipe=fp8_recipe, + fp4_recipe=fp4_recipe, + ) + # The huggingface model has a contact head that we don't use in masked language pre-training, so we delete it to # avoid errors with unused parameters. 
try: @@ -93,7 +126,7 @@ def main(args: DictConfig) -> float | None: optimizer = AdamW(model.parameters(), **args.adamw_kwargs) scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) - if args.fp8_stats_config.enabled: + if args.quant_stats_config.enabled: debug_api.infer_and_assign_layer_names(model) model = model.to(device=device) @@ -139,7 +172,9 @@ def main(args: DictConfig) -> float | None: batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa PLW2901 # Forward pass with mixed precision. - with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe): + with transformer_engine.pytorch.autocast( + enabled=args.fp8_config.enabled, recipe=fp8_recipe if args.fp8_config.enabled else None + ): outputs = model(**batch) # Backward pass. diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py index c5a8dad34d..0f3a74e959 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py @@ -19,16 +19,17 @@ import hydra import torch import transformer_engine.pytorch -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from torch.distributed.device_mesh import init_device_mesh from torch.optim import AdamW from transformer_engine.common.recipe import Format -from transformers import AutoConfig, AutoModelForMaskedLM from checkpoint import load_checkpoint_ddp, save_checkpoint_ddp, save_final_model_ddp, should_save_checkpoint from dataset import create_cp_dataloader from distributed_config import DistributedConfig +from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM from perf_logger import PerfLogger +from quantization import resolve_quantization_layers from scheduler import get_linear_schedule_with_warmup @@ -74,15 +75,31 @@ def main(args: DictConfig) -> float | None: mesh_dim_names=("ddp", "cp"), ) - # Create an 
FP8 recipe -- this is only used if FP8 is enabled in the config. + # Create quantization recipes -- these are only used if FP8/FP4 is enabled in the config. fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs ) + fp4_recipe = None + if args.fp4_config.enabled: + fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)( + fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs + ) + + if args.use_fp32_master_weights: + raise ValueError("FP32 master weights are not supported with DDP+CP. Use train_fsdp2_cp.py instead.") # Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D". # Note: token_dropout is set to False because it's not compatible with context parallelism. - config = AutoConfig.from_pretrained( - args.model_tag, trust_remote_code=True, token_dropout=False, dtype=torch.bfloat16 + config = NVEsmConfig.from_pretrained(args.model_tag, token_dropout=False, dtype=torch.bfloat16) + num_layers = config.num_hidden_layers + + # Resolve layer-wise quantization assignments. + quant_layers = resolve_quantization_layers( + num_layers=num_layers, + fp8_enabled=args.fp8_config.enabled, + fp4_enabled=args.fp4_config.enabled, + fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None, + fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None, ) # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument. 
if args.use_sequence_packing: @@ -94,10 +111,18 @@ def main(args: DictConfig) -> float | None: with transformer_engine.pytorch.quantized_model_init( recipe=fp8_recipe, **args.fp8_config.quantized_model_init_kwargs ): - model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True) + model = NVEsmForMaskedLM(config) logger.info("Initialized Model:\n%s", model) + # Initialize per-layer quantization on the encoder. + model.esm.encoder.initialize_quantization( + fp8_layers=quant_layers.fp8_layers_0indexed, + fp4_layers=quant_layers.fp4_layers_0indexed, + fp8_recipe=fp8_recipe, + fp4_recipe=fp4_recipe, + ) + # Create optimizer. optimizer = AdamW(model.parameters(), **args.adamw_kwargs) scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index ff72d5fce7..66963ed2c5 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -14,6 +14,7 @@ # limitations under the License. 
import logging +import os from contextlib import nullcontext from pathlib import Path @@ -27,16 +28,13 @@ from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard from torch.optim import AdamW from transformer_engine.common.recipe import Format -from transformers import AutoConfig, AutoModelForMaskedLM - -# This import seems to be needed with meta device init and AutoModel.from_config -from transformers.models.esm.modeling_esm import EsmForMaskedLM # noqa: F401 from checkpoint import load_checkpoint_fsdp2, save_checkpoint_fsdp2, save_final_model_fsdp2, should_save_checkpoint from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig -from fp8_debugging import initialize_fp8_debugging +from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM from perf_logger import PerfLogger +from quantization import initialize_quant_stats_logging, resolve_quantization_layers from scheduler import get_linear_schedule_with_warmup @@ -51,6 +49,9 @@ def main(args: DictConfig) -> float | None: Returns: float: The loss value for the final batch. """ + os.environ["HF_HUB_TRUST_REMOTE_CODE"] = "1" + logging.getLogger("httpx").setLevel(logging.WARNING) + # Initialize the distributed configuration, including creating the distributed process group. dist_config = DistributedConfig() logger.info("Initializing distributed training: %s", dist_config) @@ -58,9 +59,27 @@ def main(args: DictConfig) -> float | None: torch.distributed.init_process_group(backend="nccl", device_id=device) torch.cuda.set_device(dist_config.local_rank) - # TE Debug feature logging - MUST be done BEFORE FSDP wrapping - if args.fp8_stats_config.enabled: - initialize_fp8_debugging(dist_config, **args.fp8_stats_config, fp8_enabled=args.fp8_config.enabled) + # Load model config early so we know the number of layers for auto-populating layer lists. 
+ config = NVEsmConfig.from_pretrained( + args.model_tag, dtype=torch.float32 if args.use_fp32_master_weights else torch.bfloat16 + ) + num_layers = config.num_hidden_layers + + # Resolve layer-wise quantization assignments. + quant_layers = resolve_quantization_layers( + num_layers=num_layers, + fp8_enabled=args.fp8_config.enabled, + fp4_enabled=args.fp4_config.enabled, + fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None, + fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None, + ) + if args.quant_stats_config.enabled: + initialize_quant_stats_logging( + quant_stats_file=args.quant_stats_config.quant_stats_file, + quant_log_dir=args.quant_stats_config.quant_log_dir, + rank=dist_config.rank, + quant_layers=quant_layers, + ) # Create a device mesh for FSDP. device_mesh = init_device_mesh( @@ -69,15 +88,18 @@ def main(args: DictConfig) -> float | None: mesh_dim_names=("dp",), ) - # Create an FP8 recipe -- this is only used if FP8 is enabled in the config. - fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( - fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs - ) + # Create quantization recipes -- these are only used if FP8/FP4 is enabled in the config. + fp8_recipe = None + fp4_recipe = None + if args.fp8_config.enabled: + fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( + fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs + ) + if args.fp4_config.enabled: + fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)( + fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs + ) - # Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D". 
- config = AutoConfig.from_pretrained( - args.model_tag, trust_remote_code=True, dtype=torch.float32 if args.use_fp32_master_weights else torch.bfloat16 - ) # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument. if args.use_sequence_packing: config.attn_input_format = "thd" @@ -91,7 +113,7 @@ def main(args: DictConfig) -> float | None: recipe=fp8_recipe, **args.fp8_config.quantized_model_init_kwargs ), ): - model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True) + model = NVEsmForMaskedLM(config) logger.info("Initialized Model:\n%s", model) @@ -103,6 +125,7 @@ def main(args: DictConfig) -> float | None: param_dtype=torch.bfloat16, # Cast params to BF16 for forward/backward reduce_dtype=torch.float32, # Gradient reductions in FP32 output_dtype=torch.bfloat16, # Forward output dtype + cast_forward_inputs=False, ) else: mp_policy = MixedPrecisionPolicy() @@ -110,6 +133,14 @@ def main(args: DictConfig) -> float | None: fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy) fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy) + # Initialize per-layer quantization on the encoder. + model.esm.encoder.initialize_quantization( + fp8_layers=quant_layers.fp8_layers_0indexed, + fp4_layers=quant_layers.fp4_layers_0indexed, + fp8_recipe=fp8_recipe, + fp4_recipe=fp4_recipe, + ) + # If we're using meta device, we need to move sharded weights to the cuda device and initialize the parameters. # Note, this should happen before we create the optimizer. if args.use_meta_device: @@ -121,11 +152,12 @@ def main(args: DictConfig) -> float | None: model.apply(model._init_weights) # Assign names to layers so debug API can identify them - if args.fp8_stats_config.enabled: + if args.quant_stats_config.enabled: debug_api.infer_and_assign_layer_names(model) # Create optimizer. Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873). 
optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore + # Note: Got an error about mixed torch.Tensor and DTensor here, so using AdamW instead. scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) # If we're using sequence packing, create a THD dataloader, otherwise create a BSHD dataloader. @@ -162,20 +194,23 @@ def main(args: DictConfig) -> float | None: for batch in train_dataloader: batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 - # Forward pass with mixed precision. - with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe): + # --- Forward pass --- + with transformer_engine.pytorch.autocast( + enabled=args.fp8_config.enabled, recipe=fp8_recipe if args.fp8_config.enabled else None + ): outputs = model(**batch) - # Backward pass. + # --- Backward pass --- loss = outputs.loss loss.backward() - # Compute and clip gradient norms. - total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + # --- Grad clip --- + total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item() - # Step optimizer. 
+ # --- Optimizer step --- optimizer.step() scheduler.step() + optimizer.zero_grad() perf_logger.log_step( diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py index 6a824cc9fa..bd3d2a2337 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py @@ -25,15 +25,13 @@ from torch.distributed.fsdp import fully_shard from torch.optim import AdamW from transformer_engine.common.recipe import Format -from transformers import AutoConfig, AutoModelForMaskedLM - -# This import seems to be needed with meta device init and AutoModel.from_config -from transformers.models.esm.modeling_esm import EsmForMaskedLM # noqa: F401 from checkpoint import load_checkpoint_fsdp2, save_checkpoint_fsdp2, save_final_model_fsdp2, should_save_checkpoint from dataset import create_cp_dataloader from distributed_config import DistributedConfig +from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM from perf_logger import PerfLogger +from quantization import resolve_quantization_layers from scheduler import get_linear_schedule_with_warmup @@ -86,14 +84,30 @@ def main(args: DictConfig) -> float | None: f"Creating device mesh: world_size={dist_config.world_size}, dp_size={dp_size}, cp_size={args.cp_size}" ) - # Create an FP8 recipe -- this is only used if FP8 is enabled in the config. + # Create quantization recipes -- these are only used if FP8/FP4 is enabled in the config. fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs ) + fp4_recipe = None + if args.fp4_config.enabled: + fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)( + fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs + ) + + if args.use_fp32_master_weights: + raise ValueError("FP32 master weights are not supported with FSDP2+CP. 
Use train_fsdp2.py instead.") # Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D". - config = AutoConfig.from_pretrained( - args.model_tag, trust_remote_code=True, token_dropout=False, dtype=torch.bfloat16 + config = NVEsmConfig.from_pretrained(args.model_tag, token_dropout=False, dtype=torch.bfloat16) + num_layers = config.num_hidden_layers + + # Resolve layer-wise quantization assignments. + quant_layers = resolve_quantization_layers( + num_layers=num_layers, + fp8_enabled=args.fp8_config.enabled, + fp4_enabled=args.fp4_config.enabled, + fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None, + fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None, ) # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument. if args.use_sequence_packing: @@ -108,7 +122,7 @@ def main(args: DictConfig) -> float | None: recipe=fp8_recipe, **args.fp8_config.quantized_model_init_kwargs ), ): - model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True) + model = NVEsmForMaskedLM(config) logger.info("Initialized Model:\n%s", model) @@ -128,6 +142,14 @@ def main(args: DictConfig) -> float | None: ) fully_shard(model, mesh=cp_dp_mesh) + # Initialize per-layer quantization on the encoder. + model.esm.encoder.initialize_quantization( + fp8_layers=quant_layers.fp8_layers_0indexed, + fp4_layers=quant_layers.fp4_layers_0indexed, + fp8_recipe=fp8_recipe, + fp4_recipe=fp4_recipe, + ) + # If we're using meta device, we need to move sharded weights to the cuda device and initialize the parameters. # Note, this should happen before we create the optimizer. 
if args.use_meta_device: diff --git a/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py b/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py index ad32b6171b..178b5840c7 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py @@ -25,12 +25,13 @@ from torch.distributed.device_mesh import init_device_mesh from torch.optim import AdamW from transformer_engine.common.recipe import Format -from transformers import AutoConfig, AutoModelForMaskedLM from checkpoint import load_checkpoint_mfsdp, save_checkpoint_mfsdp, save_final_model_mfsdp, should_save_checkpoint from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig +from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM from perf_logger import PerfLogger +from quantization import resolve_quantization_layers from scheduler import get_linear_schedule_with_warmup @@ -65,13 +66,32 @@ def main(args: DictConfig) -> float | None: mesh_dim_names=("dp", "tp"), ) - # Create an FP8 recipe -- this is only used if FP8 is enabled in the config. + # Create quantization recipes -- these are only used if FP8/FP4 is enabled in the config. fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs ) + fp4_recipe = None + if args.fp4_config.enabled: + fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)( + fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs + ) + + if args.use_fp32_master_weights: + raise ValueError("FP32 master weights are not supported with mFSDP. Use train_fsdp2.py instead.") # Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D". 
- config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True, dtype=torch.bfloat16) + config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.bfloat16) + num_layers = config.num_hidden_layers + + # Resolve layer-wise quantization assignments. + quant_layers = resolve_quantization_layers( + num_layers=num_layers, + fp8_enabled=args.fp8_config.enabled, + fp4_enabled=args.fp4_config.enabled, + fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None, + fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None, + ) + # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument. if args.use_sequence_packing: config.attn_input_format = "thd" @@ -82,10 +102,18 @@ def main(args: DictConfig) -> float | None: with transformer_engine.pytorch.quantized_model_init( recipe=fp8_recipe, **args.fp8_config.quantized_model_init_kwargs ): - model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True) + model = NVEsmForMaskedLM(config) logger.info("Initialized Model:\n%s", model) + # Initialize per-layer quantization on the encoder. + model.esm.encoder.initialize_quantization( + fp8_layers=quant_layers.fp8_layers_0indexed, + fp4_layers=quant_layers.fp4_layers_0indexed, + fp8_recipe=fp8_recipe, + fp4_recipe=fp4_recipe, + ) + # Create optimizer. Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873). 
optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore diff --git a/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py index be4049df06..adf9921e12 100644 --- a/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py +++ b/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py @@ -22,11 +22,14 @@ Adapted from `modeling_esm.py` in huggingface/transformers. """ +from contextlib import nullcontext from typing import ClassVar, Literal, Optional, Unpack # TODO: put import guard around transformer_engine here, with an informative error message around # installation and the nvidia docker container. import torch +import torch.cuda.nvtx as nvtx +import transformer_engine.common.recipe import transformer_engine.pytorch from torch import nn from torch.nn import CrossEntropyLoss @@ -54,6 +57,15 @@ "AutoModelForTokenClassification": "esm_nv.NVEsmForTokenClassification", } +# From https://github.com/NVIDIA/TransformerEngine/blob/3ceb248e01a2c0dc1215fe0f46ebc235f289ba0d/transformer_engine/common/recipe/__init__.py#L86 +FP8_RECIPES = ( + transformer_engine.common.recipe.MXFP8BlockScaling, + transformer_engine.common.recipe.DelayedScaling, + transformer_engine.common.recipe.Float8CurrentScaling, + transformer_engine.common.recipe.Float8BlockScaling, +) +FP4_RECIPES = transformer_engine.common.recipe.NVFP4BlockScaling + class NVEsmConfig(EsmConfig): """NVEsmConfig is a configuration for the NVEsm model.""" @@ -164,6 +176,9 @@ def _init_method(x): for i in range(config.num_hidden_layers) ] ) + self._fp8_recipe: object | None = None + self._fp4_recipe: object | None = None + self._layer_precision: dict[int, str | None] | None = None self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm( config.hidden_size, eps=config.layer_norm_eps, @@ -173,6 +188,61 @@ def _init_method(x): if 
config.position_embedding_type == "rotary": self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) + def initialize_quantization( + self, + fp8_layers: list[int] | None, + fp4_layers: list[int] | None, + fp8_recipe: object | None = None, + fp4_recipe: object | None = None, + ) -> None: + """Build the per-layer quantization precision map. + + Must be called after model creation and sharding (FSDP/DDP/mFSDP) but before training. + Each layer is tagged as ``"fp8"``, ``"fp4"``, or ``None`` (BF16 fallback). The recipe + objects are stored once on the encoder rather than duplicated per-layer, ensuring the + map is trivially pickleable. + + Args: + fp8_layers: 0-indexed layer numbers to run in FP8, or None. + fp4_layers: 0-indexed layer numbers to run in FP4, or None. + fp8_recipe: The FP8 recipe instance (e.g., MXFP8BlockScaling), or None. + fp4_recipe: The FP4 recipe instance (e.g., NVFP4BlockScaling), or None. + """ + self._fp8_recipe = fp8_recipe + self._fp4_recipe = fp4_recipe + fp8_layers_set = set(fp8_layers) if fp8_layers else set() + fp4_layers_set = set(fp4_layers) if fp4_layers else set() + self._layer_precision = {} + for layer_number in range(len(self.layers)): + if layer_number in fp8_layers_set: + self._layer_precision[layer_number] = "fp8" + elif layer_number in fp4_layers_set: + self._layer_precision[layer_number] = "fp4" + else: + self._layer_precision[layer_number] = None + + def get_layer_autocast(self, layer_number: int): + """Return the appropriate TE autocast context manager for a given layer. + + The context interacts with the outer FP8 autocast in the training script: + - FP8 layer: nullcontext() -- lets the outer FP8 autocast take effect. + - FP4 layer: te.pytorch.autocast(enabled=True, recipe=fp4_recipe) -- overrides to FP4. + - BF16 layer: te.pytorch.autocast(enabled=False) -- disables quantized compute. + + Args: + layer_number: The 0-indexed layer number. 
+ + Returns: + A context manager for the layer's quantization mode. + """ + precision = self._layer_precision.get(layer_number) if self._layer_precision is not None else None + if precision == "fp8": + return nullcontext() + elif precision == "fp4": + return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp4_recipe) + else: + return transformer_engine.pytorch.autocast(enabled=False) + def forward( self, hidden_states: torch.Tensor, @@ -198,24 +268,30 @@ def forward( te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings) te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True) - for layer_module in self.layers: + # Per-layer quantization context (FP8, FP4, or BF16) is determined by get_layer_autocast(). + for layer_number, layer_module in enumerate(self.layers): if kwargs.get("output_hidden_states", False): all_hidden_states = (*all_hidden_states, hidden_states) - hidden_states = layer_module( - hidden_states, - attention_mask, - rotary_pos_emb=te_rope_emb, - cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), - cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), - cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), - cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), - max_seqlen_q=kwargs.get("max_length_q", None), - max_seqlen_kv=kwargs.get("max_length_k", None), - pad_between_seqs=kwargs.get("pad_between_seqs", None), - ) + nvtx.range_push(f"encoder_layer_{layer_number}") + with self.get_layer_autocast(layer_number): + hidden_states = layer_module( + hidden_states, + attention_mask, + rotary_pos_emb=te_rope_emb, + cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), + cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), + cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), + cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), + max_seqlen_q=kwargs.get("max_length_q", None), + max_seqlen_kv=kwargs.get("max_length_k", None), + pad_between_seqs=kwargs.get("pad_between_seqs", 
None), + ) + nvtx.range_pop() # encoder_layer_N + nvtx.range_push("emb_layer_norm_after") hidden_states = self.emb_layer_norm_after(hidden_states) + nvtx.range_pop() # emb_layer_norm_after if kwargs.get("output_hidden_states", False): all_hidden_states = (*all_hidden_states, hidden_states) diff --git a/docs/docs/assets/images/esm2/esm2_low_precision/esm2-15b-b300-mxfp8-10node-conv.svg b/docs/docs/assets/images/esm2/esm2_low_precision/esm2-15b-b300-mxfp8-10node-conv.svg new file mode 100644 index 0000000000..613f343af9 --- /dev/null +++ b/docs/docs/assets/images/esm2/esm2_low_precision/esm2-15b-b300-mxfp8-10node-conv.svg @@ -0,0 +1,118 @@ +
5001k1.5k2kStep1313.51414.51515.516
diff --git a/docs/docs/assets/images/esm2/esm2_low_precision/esm2-15b-b300-nvfp4-10node-conv.svg b/docs/docs/assets/images/esm2/esm2_low_precision/esm2-15b-b300-nvfp4-10node-conv.svg new file mode 100644 index 0000000000..8f0f49a386 --- /dev/null +++ b/docs/docs/assets/images/esm2/esm2_low_precision/esm2-15b-b300-nvfp4-10node-conv.svg @@ -0,0 +1,118 @@ +
5001k1.5k2kStep1313.51414.51515.516
diff --git a/docs/docs/assets/images/esm2/esm2_low_precision/esm2-3b-b200-mxfp8-6node-conv.png b/docs/docs/assets/images/esm2/esm2_low_precision/esm2-3b-b200-mxfp8-6node-conv.png new file mode 100644 index 0000000000..2a71d80f98 Binary files /dev/null and b/docs/docs/assets/images/esm2/esm2_low_precision/esm2-3b-b200-mxfp8-6node-conv.png differ diff --git a/docs/docs/assets/images/esm2/esm2_low_precision/esm2-3b-b200-nvfp4-6node-conv.png b/docs/docs/assets/images/esm2/esm2_low_precision/esm2-3b-b200-nvfp4-6node-conv.png new file mode 100644 index 0000000000..4766891471 Binary files /dev/null and b/docs/docs/assets/images/esm2/esm2_low_precision/esm2-3b-b200-nvfp4-6node-conv.png differ diff --git a/docs/docs/assets/images/esm2/esm2_low_precision/esm2_8gpu_tflops.png b/docs/docs/assets/images/esm2/esm2_low_precision/esm2_8gpu_tflops.png new file mode 100644 index 0000000000..d89a4e6158 Binary files /dev/null and b/docs/docs/assets/images/esm2/esm2_low_precision/esm2_8gpu_tflops.png differ