-
Notifications
You must be signed in to change notification settings - Fork 123
ESM2 changes to work with vLLM #1473
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
e415bb9
3f9f00e
99e8087
7383a7c
af0c43c
e2b3fcd
c34c09b
a67df14
8e8a87c
36cdbb2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -71,7 +71,11 @@ def export_hf_checkpoint(tag: str, export_path: Path): | |
| model_hf_masked_lm = AutoModelForMaskedLM.from_pretrained(f"facebook/{tag}") | ||
| model_hf = AutoModel.from_pretrained(f"facebook/{tag}") | ||
| model_hf_masked_lm.esm.pooler = model_hf.pooler | ||
| model_te = convert_esm_hf_to_te(model_hf_masked_lm) | ||
|
|
||
| # Export without vocab padding so the checkpoint stores embeddings at the real | ||
| # vocab_size. This avoids shape-mismatch errors in vLLM's VocabParallelEmbedding, | ||
| # which expects vocab_size-shaped weights. | ||
| model_te = convert_esm_hf_to_te(model_hf_masked_lm, padded_vocab_size=None) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [not blocking]: Okay made some changes to convert_esm_hf_to_te |
||
| model_te.save_pretrained(export_path / tag) | ||
|
|
||
| tokenizer = AutoTokenizer.from_pretrained("esm_fast_tokenizer") # Use our PreTrainedTokenizerFast implementation. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -70,6 +70,7 @@ def __init__( | |
| max_seq_length: Optional[int] = None, | ||
| padded_vocab_size: Optional[int] = 64, | ||
| attn_mask_type: str = "padding", | ||
| add_pooling_layer: bool = False, | ||
| **kwargs, | ||
| ): | ||
| """Initialize the NVEsmConfig with additional TE-related config options. | ||
|
|
@@ -100,6 +101,9 @@ def __init__( | |
| padded_vocab_size: The padded vocabulary size to support FP8. If ``None``, falls back | ||
| to vocab_size. Must be greater than or equal to vocab_size. | ||
| attn_mask_type: The type of attention mask to use. | ||
| add_pooling_layer: Whether the base model should include a pooling layer. | ||
| Defaults to ``False`` because exported checkpoints do not contain pooler | ||
| weights. Set to ``True`` only if you have a checkpoint with pooler weights. | ||
| **kwargs: Additional config options to pass to EsmConfig. | ||
| """ | ||
| super().__init__(**kwargs) | ||
|
|
@@ -111,6 +115,7 @@ def __init__( | |
| self.micro_batch_size = micro_batch_size | ||
| self.max_seq_length = max_seq_length | ||
| self.attn_mask_type = attn_mask_type | ||
| self.add_pooling_layer = add_pooling_layer | ||
|
|
||
| # Set padded_vocab_size with default fallback to vocab_size | ||
| self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size | ||
|
|
@@ -231,7 +236,7 @@ class NVEsmPreTrainedModel(EsmPreTrainedModel): | |
| """An abstract class to handle weights initialization and pretrained model loading.""" | ||
|
|
||
| config_class = NVEsmConfig | ||
| base_model_prefix = "esm" | ||
| base_model_prefix = "model" | ||
| supports_gradient_checkpointing = False | ||
| accepts_loss_kwargs = False | ||
| _no_split_modules = ( | ||
|
|
@@ -247,11 +252,11 @@ def init_empty_weights(self): | |
| if hasattr(module, "reset_parameters"): | ||
| module.reset_parameters() | ||
|
|
||
| # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use | ||
| # The embeddings layer is the only non-TE layer in this model we need to deal with. We use | ||
| # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard | ||
| # deviation. | ||
| self.esm.embeddings.word_embeddings.to_empty(device="cuda") | ||
| self.esm.embeddings.apply(self._init_weights) | ||
| # deviation. self.base_model resolves to self.model for wrapper classes or self for NVEsmModel. | ||
| self.base_model.embeddings.word_embeddings.to_empty(device="cuda") | ||
| self.base_model.embeddings.apply(self._init_weights) | ||
|
|
||
| # Meta-device init seems to break weight tying, so we re-tie the weights here. | ||
| self.tie_weights() | ||
|
|
@@ -276,14 +281,16 @@ def _init_weights(self, module): | |
| super()._init_weights(module) | ||
|
|
||
| def state_dict(self, *args, **kwargs): | ||
| """Override state_dict to filter out TransformerEngine's _extra_state keys. | ||
| """Override state_dict to filter out non-loadable keys. | ||
|
|
||
| TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading. | ||
| These are filtered out to ensure checkpoints can be loaded with from_pretrained(). | ||
| Filters out: | ||
| - ``_extra_state`` keys: TransformerEngine-specific, not loadable by HuggingFace v5. | ||
| - ``.inv_freq`` buffers: Computed at init time by RotaryPositionEmbedding, not needed | ||
| in the checkpoint and not loadable by vLLM's AutoWeightsLoader (which only iterates | ||
| over ``named_parameters``, not ``named_buffers``). | ||
| """ | ||
| state_dict = super().state_dict(*args, **kwargs) | ||
| # Filter out _extra_state keys which are TransformerEngine-specific and not loadable | ||
| return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")} | ||
| return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state") and not k.endswith(".inv_freq")} | ||
|
|
||
|
|
||
| class NVEsmModel(NVEsmPreTrainedModel): | ||
|
|
@@ -292,16 +299,20 @@ class NVEsmModel(NVEsmPreTrainedModel): | |
| This model uses NVIDIA's TransformerEngine to optimize attention layer training and inference. | ||
| """ | ||
|
|
||
| def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True): | ||
| def __init__(self, config: NVEsmConfig, add_pooling_layer: Optional[bool] = None): | ||
| """Initialize a NVEsmModel. | ||
|
|
||
| Args: | ||
| config (NVEsmConfig): The configuration of the model. | ||
| add_pooling_layer (bool): Whether to add a pooling layer. | ||
| add_pooling_layer (Optional[bool]): Whether to add a pooling layer. If ``None``, | ||
| reads ``config.add_pooling_layer``, falling back to ``True`` only if the config lacks that attribute. | ||
| """ | ||
| super().__init__(config) | ||
| self.config = config | ||
|
|
||
| if add_pooling_layer is None: | ||
| add_pooling_layer = getattr(config, "add_pooling_layer", True) | ||
|
|
||
| # Ensure pad_token_id is set properly, defaulting to 0 if not specified | ||
| if not hasattr(config, "pad_token_id") or config.pad_token_id is None: | ||
| config.pad_token_id = 0 | ||
|
|
@@ -391,8 +402,10 @@ def forward( | |
| class NVEsmForMaskedLM(NVEsmPreTrainedModel): | ||
| """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling.""" | ||
|
|
||
| _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"} | ||
| _do_not_quantize = ("lm_head.dense", "lm_head.decoder") # Flag for testing that these layers are not quantized. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you're deleting |
||
| _tied_weights_keys: ClassVar[dict[str, str]] = { | ||
| "lm_head.decoder.weight": "model.embeddings.word_embeddings.weight" | ||
| } | ||
| _do_not_quantize = ("lm_head.dense", "lm_head.decoder") | ||
|
|
||
| def __init__(self, config: NVEsmConfig): | ||
| """Initialize a NVEsmForMaskedLM. | ||
|
|
@@ -408,7 +421,7 @@ def __init__(self, config: NVEsmConfig): | |
| "bi-directional self-attention." | ||
| ) | ||
|
|
||
| self.esm = NVEsmModel(config, add_pooling_layer=False) | ||
| self.model = NVEsmModel(config, add_pooling_layer=False) | ||
| self.lm_head = NVEsmLMHead(config) | ||
|
|
||
| self.post_init() | ||
|
|
@@ -443,7 +456,7 @@ def forward( | |
| Returns: | ||
| MaskedLMOutput: The output of the model. | ||
| """ | ||
| outputs = self.esm( | ||
| outputs = self.model( | ||
| input_ids, | ||
| attention_mask=attention_mask, | ||
| position_ids=position_ids, | ||
|
|
@@ -633,7 +646,7 @@ def __init__(self, config): | |
| super().__init__(config) | ||
| self.num_labels = config.num_labels | ||
|
|
||
| self.esm = NVEsmModel(config, add_pooling_layer=False) | ||
| self.model = NVEsmModel(config, add_pooling_layer=False) | ||
| self.dropout = nn.Dropout(config.hidden_dropout_prob) | ||
| self.classifier = transformer_engine.pytorch.Linear( | ||
| config.hidden_size, | ||
|
|
@@ -659,7 +672,7 @@ def forward( | |
| labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | ||
| Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. | ||
| """ | ||
| outputs = self.esm( | ||
| outputs = self.model( | ||
| input_ids, | ||
| attention_mask=attention_mask, | ||
| position_ids=position_ids, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,3 +9,4 @@ torch | |
| torchao!=0.14.0 | ||
| transformer_engine[pytorch] | ||
| transformers | ||
| vllm | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [not blocking]: Better to pin, this dep may update
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, I think it is a recipes design choice to avoid pinning unless absolutely necessary to maximize compatibility when installing into external environments. @pstjohn thoughts?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, please don't pin. When we get uv working, we can have a |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -209,8 +209,8 @@ def test_context_parallel_equivalence_2process(): | |
|
|
||
| # Sample gradients from a few layers for comparison | ||
| sample_layers = [ | ||
| model.esm.encoder.layers[0].self_attention.core_attention, | ||
| model.esm.encoder.layers[0].self_attention.layernorm_qkv, | ||
| model.model.encoder.layers[0].self_attention.core_attention, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why the rename here from |
||
| model.model.encoder.layers[0].self_attention.layernorm_qkv, | ||
| ] | ||
|
|
||
| # Now grab the gradients from the sample layers | ||
|
|
@@ -262,7 +262,7 @@ def test_context_parallel_equivalence_2process(): | |
| cp_world_size = torch.distributed.get_world_size(group=cp_group) | ||
|
|
||
| # Set up context parallelism for each layer | ||
| for i, transformer_layer in enumerate(model.module.esm.encoder.layers): | ||
| for i, transformer_layer in enumerate(model.module.model.encoder.layers): | ||
| transformer_layer.set_context_parallel_group( | ||
| cp_group, torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()), torch.cuda.Stream() | ||
| ) | ||
|
|
@@ -347,8 +347,8 @@ def test_context_parallel_equivalence_2process(): | |
| # Capture gradients from the same layers in the CP model | ||
| # Note: DDP wraps the model with 'module.' prefix | ||
| sample_layers_cp = [ | ||
| model.module.esm.encoder.layers[0].self_attention.core_attention, | ||
| model.module.esm.encoder.layers[0].self_attention.layernorm_qkv, | ||
| model.module.model.encoder.layers[0].self_attention.core_attention, | ||
| model.module.model.encoder.layers[0].self_attention.layernorm_qkv, | ||
| ] | ||
|
|
||
| gradients_cp = {} | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
README.md
- The top-level Dockerfile's WORKDIR is /workspace/bionemo2; recommend using the same WORKDIR in this Dockerfile for consistency.