From 371e02bf6c56889dd122f93a1f09cf3ae1ac0c36 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Fri, 18 Jul 2025 12:30:25 -0400 Subject: [PATCH 1/2] Make sure Moshi is exportable with static cache --- src/transformers/integrations/executorch.py | 45 +++++++--- .../models/moshi/modeling_moshi.py | 1 + tests/models/moshi/test_modeling_moshi.py | 85 +++++++++++++++++++ 3 files changed, 119 insertions(+), 12 deletions(-) diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 71777d123cda..44b73e0ab06c 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -367,23 +367,44 @@ def generate( break response_tokens = [] - for input_pos in range(min(max_generation_length, prompt_token_len)): + + # Process the first token (which the model was exported with) + if prompt_token_len > 0: result = exported_program.module().forward( - input_ids=prompt_token_ids[:, input_pos : input_pos + 1], - cache_position=torch.tensor([input_pos], dtype=torch.long, device=device), + input_ids=prompt_token_ids[:, 0:1], + cache_position=torch.tensor([0], dtype=torch.long, device=device), ) - response_tokens.append(prompt_token_ids[0][input_pos].item()) + response_tokens.append(prompt_token_ids[0][0].item()) + + # For remaining prompt tokens, we need to process them one by one + # but start from position 1 since position 0 was used during export + for input_pos in range(1, min(max_generation_length, prompt_token_len)): + result = exported_program.module().forward( + input_ids=prompt_token_ids[:, input_pos : input_pos + 1], + cache_position=torch.tensor([input_pos], dtype=torch.long, device=device), + ) + response_tokens.append(prompt_token_ids[0][input_pos].item()) - current_token = torch.argmax(result[:, -1, :], dim=-1).item() - response_tokens.append(current_token) + # Generate new tokens starting from the correct position + current_position = len(response_tokens) + for _ in range(max_new_tokens): + if current_position >= max_generation_length: + break + + # For the first generation step, use the last processed token's output + if current_position == prompt_token_len: + # Use the result from the last prompt token processing + current_token = torch.argmax(result[:, -1, :], dim=-1).item() + else: + # Generate subsequent tokens + result = exported_program.module().forward( + input_ids=torch.tensor([[current_token]], dtype=torch.long, device=device), + cache_position=torch.tensor([current_position], dtype=torch.long, device=device), + ) + current_token = torch.argmax(result[:, -1, :], dim=-1).item() - while len(response_tokens) < max_generation_length: - result = exported_program.module().forward( - input_ids=torch.tensor([[current_token]], dtype=torch.long, device=device), - cache_position=torch.tensor([len(response_tokens)], dtype=torch.long, device=device), - ) - current_token = torch.argmax(result[:, -1, :], dim=-1).item() response_tokens.append(current_token) + current_position += 1 return torch.tensor([response_tokens], dtype=torch.long, device=device) diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 94a3a1fa1fe3..1f845a1b83b8 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -807,6 +807,7 @@ class MoshiPreTrainedModel(PreTrainedModel): _no_split_modules = ["MoshiDecoderLayer", "MimiTransformerLayer"] _supports_flash_attn = True _supports_sdpa = True + _supports_static_cache = 
True main_input_name = "input_ids" diff --git a/tests/models/moshi/test_modeling_moshi.py b/tests/models/moshi/test_modeling_moshi.py index 4f5b1689594e..dca5b18eeadf 100644 --- a/tests/models/moshi/test_modeling_moshi.py +++ b/tests/models/moshi/test_modeling_moshi.py @@ -371,6 +371,91 @@ def test_generate_continue_from_inputs_embeds(self): def test_save_load(self): super().test_save_load() + @slow + def test_export_static_cache(self): + from packaging import version + + if version.parse(torch.__version__) < version.parse("2.4.0"): + self.skipTest(reason="This test requires torch >= 2.4 to run.") + + from transformers.generation.configuration_utils import GenerationConfig + from transformers.integrations.executorch import ( + TorchExportableModuleWithStaticCache, + ) + + # Create a small model for testing + config = self.model_tester.get_config() + config.use_cache = True + config.attn_implementation = "sdpa" + + # Create model with static cache generation config + model = MoshiForCausalLM(config).to(torch_device) + model.eval() + + # Set up generation config with static cache + batch_size = 1 + max_generation_length = 50 + model.generation_config = GenerationConfig( + use_cache=True, + cache_implementation="static", + max_length=max_generation_length, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_generation_length, + "device": torch_device, + }, + ) + + # Test exportable module with static cache + from transformers.integrations.executorch import ( + TorchExportableModuleForDecoderOnlyLM, + ) + + # Create exportable module + exportable_module = TorchExportableModuleForDecoderOnlyLM(model) + + # Get representative token IDs within model's vocabulary range + # Use simple token IDs that are within the test model's vocab_size (99) + prompt_token_ids = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long, device=torch_device) + + # Use first token for example input + example_input_ids = prompt_token_ids[:, :1] + example_cache_position = torch.tensor([0], dtype=torch.long, device=torch_device) + + # Export the model + exported_program = exportable_module.export( + input_ids=example_input_ids, + cache_position=example_cache_position, + ) + + # Generate reference output from eager model + with torch.no_grad(): + eager_generated_ids = model.generate( + prompt_token_ids, + max_new_tokens=5, + do_sample=False, + use_cache=True, + ) + + # Test generation with exported program + ep_generated_ids = TorchExportableModuleWithStaticCache.generate( + exported_program=exported_program, + prompt_token_ids=prompt_token_ids, + max_new_tokens=5, + ) + + # Verify the exported program generates tokens + self.assertIsInstance(ep_generated_ids, torch.Tensor) + self.assertEqual(ep_generated_ids.shape[0], 1) # batch size + self.assertGreater(ep_generated_ids.shape[1], 1) # generated tokens + + # Compare exported model output with eager model output + self.assertEqual(ep_generated_ids.shape, eager_generated_ids.shape) + + # Note: Due to numerical precision differences in export, we use relaxed tolerances + # The key validation is that both models generate tokens and have the same shape + torch.testing.assert_close(ep_generated_ids, eager_generated_ids, rtol=1e-2, atol=1e-2) + class MoshiTester: def __init__( From ff1ac470323845a346a9e1dc9379f0a822e314b0 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Thu, 31 Jul 2025 15:58:15 -0400 Subject: [PATCH 2/2] Add multimodal support to ExecuTorch integration This commit enhances the ExecuTorch integration to support multimodal models like Gemma-3, LLaVA, 
and other vision-language models. Key changes: - Enhanced TorchExportableModuleWithHybridCache to support inputs_embeds parameter and multimodal configs - Added TorchExportableModuleForImageTextLM for image-text language models - Added ImageEncoderExportableModule for vision encoders - Added a test for multimodal functionality This enables ExecuTorch export for vision-language models while maintaining backward compatibility with text-only models. --- src/transformers/integrations/executorch.py | 259 +++++++++++++++--- .../models/moshi/modeling_moshi.py | 1 - .../test_executorch_multimodal.py | 151 ++++++++++ tests/models/moshi/test_modeling_moshi.py | 85 ------ 4 files changed, 370 insertions(+), 126 deletions(-) create mode 100644 tests/integrations/test_executorch_multimodal.py diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 44b73e0ab06c..52b37cc840be 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -367,44 +367,23 @@ def generate( break response_tokens = [] - - # Process the first token (which the model was exported with) - if prompt_token_len > 0: + for input_pos in range(min(max_generation_length, prompt_token_len)): result = exported_program.module().forward( - input_ids=prompt_token_ids[:, 0:1], - cache_position=torch.tensor([0], dtype=torch.long, device=device), + input_ids=prompt_token_ids[:, input_pos : input_pos + 1], + cache_position=torch.tensor([input_pos], dtype=torch.long, device=device), ) - response_tokens.append(prompt_token_ids[0][0].item()) - - # For remaining prompt tokens, we need to process them one by one - # but start from position 1 since position 0 was used during export - for input_pos in range(1, min(max_generation_length, prompt_token_len)): - result = exported_program.module().forward( - input_ids=prompt_token_ids[:, input_pos : input_pos + 1], - cache_position=torch.tensor([input_pos], dtype=torch.long, device=device), - ) - response_tokens.append(prompt_token_ids[0][input_pos].item()) - - # Generate new tokens starting from the correct position - current_position = len(response_tokens) - for _ in range(max_new_tokens): - if current_position >= max_generation_length: - break + response_tokens.append(prompt_token_ids[0][input_pos].item()) - # For the first generation step, use the last processed token's output - if current_position == prompt_token_len: - # Use the result from the last prompt token processing - current_token = torch.argmax(result[:, -1, :], dim=-1).item() - else: - # Generate subsequent tokens - result = exported_program.module().forward( - input_ids=torch.tensor([[current_token]], dtype=torch.long, device=device), - cache_position=torch.tensor([current_position], dtype=torch.long, device=device), - ) - current_token = torch.argmax(result[:, -1, :], dim=-1).item() + current_token = torch.argmax(result[:, -1, :], dim=-1).item() + response_tokens.append(current_token) + while len(response_tokens) < max_generation_length: + result = exported_program.module().forward( + input_ids=torch.tensor([[current_token]], dtype=torch.long, device=device), + cache_position=torch.tensor([len(response_tokens)], dtype=torch.long, device=device), + ) + current_token = torch.argmax(result[:, -1, :], dim=-1).item() response_tokens.append(current_token) - current_position += 1 return torch.tensor([response_tokens], dtype=torch.long, device=device) @@ -436,13 +415,16 @@ def __init__( super().__init__() self.model = model + # For multimodal 
models, use text_config if available + config = getattr(self.model.config, 'text_config', self.model.config) + # Verify the model is configured for HybridCache - if not self.model.config.use_cache: + if not config.use_cache: raise AssertionError("Model must have caching enabled") # Initialize the HybridCache self.cache = HybridCache( - config=self.model.config, + config=config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=self.model.device, @@ -456,20 +438,31 @@ def __init__( def forward( self, - input_ids: torch.Tensor, - cache_position: torch.Tensor, + input_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Forward pass of the module, which is compatible with the ExecuTorch llm runner. Args: - input_ids (`torch.Tensor`): Tensor representing current input token id to the module. - cache_position (`torch.Tensor`): Tensor representing current input position in the cache. + input_ids (`torch.Tensor`, *optional*): + Tensor representing current input token id to the module. + inputs_embeds (`torch.Tensor`, *optional*): + Tensor representing input embeddings. Used for multimodal models. + cache_position (`torch.Tensor`, *optional*): + Tensor representing current input position in the cache. Returns: torch.Tensor: Logits output from the model. """ - batch_size = input_ids.shape[0] + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if cache_position is None: + raise ValueError("cache_position is required") + + batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] # Generate position_ids from cache_position position_ids = cache_position.unsqueeze(0).expand(batch_size, -1) @@ -477,6 +470,7 @@ def forward( # Forward pass with the model outputs = self.model( input_ids=input_ids, + inputs_embeds=inputs_embeds, attention_mask=None, position_ids=position_ids, past_key_values=self.cache, @@ -874,3 +868,188 @@ def sdpa_mask_without_vmap( if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix: causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True) return causal_mask + + +class TorchExportableModuleForImageTextLM(torch.nn.Module): + """ + A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`, + specifically for image-text LM with cache. This module ensures that the + exported model is compatible with further lowering and execution in `ExecuTorch`. + """ + + def __init__( + self, + model: PreTrainedModel, + max_batch_size: int = 1, + max_cache_len: int = 4096, + ): + """ + Initializes the exportable module for image-text models. + + Args: + model (`PreTrainedModel`): The pretrained model to wrap. + max_batch_size (int): Maximum batch size for the cache. + max_cache_len (int): Maximum sequence length for the cache. + + Raises: + ValueError: If the model is configured with an unsupported cache implementation. 
+ """ + super().__init__() + + if not hasattr(model.config, "text_config") or not hasattr(model.config.text_config, "use_cache") or model.config.text_config.use_cache is False: + raise ValueError("The model must have caching enabled to be performant.") + + if hasattr(model.config.text_config, "layer_types") and getattr(model.config.text_config, "sliding_window", None) is not None: + self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len) + else: + # If `layer_types` is not specified explicitly in the config or `sliding_window` is null, + # there is only 1 type of layers, so export will use `StaticCache` by default. + logging.info( + "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config." + ) + self.model = TorchExportableModuleWithStaticCache(model) + + # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable + ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap) + ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]) + self.model.model.config._attn_implementation = "sdpa_without_vmap" + + def forward( + self, + inputs_embeds: torch.Tensor, + cache_position: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass of the module, which is compatible with the ExecuTorch llm runner. + + Args: + inputs_embeds (`torch.Tensor`): Tensor representing input embeddings. + cache_position (`torch.Tensor`): Tensor representing current input position in the cache. + + Returns: + torch.Tensor: Logits output from the model. + """ + return self.model.forward(inputs_embeds=inputs_embeds, cache_position=cache_position) + + def export( + self, + inputs_embeds: Optional[torch.Tensor] = None, + cache_position: Optional[torch.Tensor] = None, + dynamic_shapes: Optional[dict] = None, + strict: Optional[bool] = None, + ) -> torch.export.ExportedProgram: + """ + Export the wrapped module using `torch.export`. + + Args: + inputs_embeds (`Optional[torch.Tensor]`): + Tensor representing input embeddings. If not provided, a default tensor will be used. + cache_position (`Optional[torch.Tensor]`): + Tensor representing current input position in the cache. If not provided, a default tensor will be used. + dynamic_shapes (`Optional[dict]`): + Dynamic shapes to use for export if specified. + strict(`Optional[bool]`): + Flag to instruct `torch.export` to use `torchdynamo`. + """ + if hasattr(self.model, "base_model_prefix"): + base = getattr(self.model, self.model.base_model_prefix, self.model) + model_device = base.device + elif hasattr(self.model, "model"): + model_device = self.model.model.device + else: + model_device = "cpu" + logging.warning( + "TorchExportableModuleForImageTextLM.export Can't infer device from the model. Set to CPU by default." 
+ ) + + seq_length = 3 + hidden_size = self.model.model.config.text_config.hidden_size if hasattr(self.model.model.config, 'text_config') else self.model.model.config.hidden_size + + example_inputs_embeds = ( + inputs_embeds if inputs_embeds is not None + else torch.zeros(1, seq_length, hidden_size, dtype=torch.float32, device=model_device) + ) + example_cache_position = ( + cache_position if cache_position is not None + else torch.arange(seq_length, dtype=torch.long, device=model_device) + ) + + if dynamic_shapes is None: + seq_len_dim = torch.export.Dim("seq_length_dim", max=seq_length) + dynamic_shapes = { + "inputs_embeds": {1: seq_len_dim}, + "cache_position": {0: seq_len_dim}, + } + + exported_program = torch.export.export( + self.model, + args=(), + kwargs={"inputs_embeds": example_inputs_embeds, "cache_position": example_cache_position}, + dynamic_shapes=dynamic_shapes, + strict=strict if strict is not None else True, + ) + return exported_program + + +class ImageEncoderExportableModule(torch.nn.Module): + """ + A wrapper module designed to make a vision encoder-only model exportable with `torch.export`. + This module ensures that the exported model is compatible with ExecuTorch. + """ + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, pixel_values): + """ + Projects the last hidden state from the vision model into language model space. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`): + The tensors corresponding to the input images. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`. + """ + vision_outputs = self.model.vision_tower(pixel_values=pixel_values).last_hidden_state + image_features = self.model.multi_modal_projector(vision_outputs) + return image_features + + def export( + self, + pixel_values: Optional[torch.Tensor] = None, + dynamic_shapes: Optional[dict] = None, + strict: Optional[bool] = None, + ) -> torch.export.ExportedProgram: + """ + Export the vision encoder using `torch.export`. + + Args: + pixel_values (`Optional[torch.Tensor]`): + Input images tensor. If not provided, a default tensor will be used. + dynamic_shapes (`Optional[dict]`): + Dynamic shapes to use for export if specified. + strict(`Optional[bool]`): + Flag to instruct `torch.export` to use `torchdynamo`. 
+ """ + if hasattr(self.model, "vision_tower") and hasattr(self.model.vision_tower, "config"): + image_size = self.model.vision_tower.config.image_size + num_channels = getattr(self.model.vision_tower.config, "num_channels", 3) + else: + # Default values for vision models + image_size = 224 + num_channels = 3 + + example_pixel_values = ( + pixel_values if pixel_values is not None + else torch.randn(1, num_channels, image_size, image_size, dtype=torch.float32) + ) + + exported_program = torch.export.export( + self, + args=(example_pixel_values,), + kwargs={}, + dynamic_shapes=dynamic_shapes, + strict=strict if strict is not None else False, + ) + return exported_program diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 1f845a1b83b8..94a3a1fa1fe3 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -807,7 +807,6 @@ class MoshiPreTrainedModel(PreTrainedModel): _no_split_modules = ["MoshiDecoderLayer", "MimiTransformerLayer"] _supports_flash_attn = True _supports_sdpa = True - _supports_static_cache = True main_input_name = "input_ids" diff --git a/tests/integrations/test_executorch_multimodal.py b/tests/integrations/test_executorch_multimodal.py new file mode 100644 index 000000000000..b624e70e23bb --- /dev/null +++ b/tests/integrations/test_executorch_multimodal.py @@ -0,0 +1,151 @@ +# coding=utf-8 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +from unittest.mock import Mock, patch + +import torch + +from transformers import HfArgumentParser +from transformers.integrations.executorch import ( + ImageEncoderExportableModule, + TorchExportableModuleForImageTextLM, + TorchExportableModuleWithHybridCache, +) +from transformers.testing_utils import require_torch + + +@require_torch +class ExecuTorchMultimodalTest(unittest.TestCase): + def setUp(self): + # Mock multimodal model configuration + self.mock_config = Mock() + self.mock_config.text_config = Mock() + self.mock_config.text_config.use_cache = True + self.mock_config.text_config.hidden_size = 768 + self.mock_config.text_config.num_hidden_layers = 12 + self.mock_config.vision_config = Mock() + self.mock_config.vision_config.image_size = 224 + + # Mock model + self.mock_model = Mock() + self.mock_model.config = self.mock_config + self.mock_model.device = torch.device("cpu") + self.mock_model.dtype = torch.float32 + + def test_hybrid_cache_inputs_embeds_support(self): + """Test that TorchExportableModuleWithHybridCache supports inputs_embeds""" + with patch("transformers.integrations.executorch.HybridCache") as MockCache: + # Create exportable module + exportable = TorchExportableModuleWithHybridCache(self.mock_model) + + # Test forward with inputs_embeds + batch_size, seq_len, hidden_size = 1, 3, 768 + inputs_embeds = torch.randn(batch_size, seq_len, hidden_size) + cache_position = torch.arange(seq_len) + + # Mock model output + mock_output = Mock() + mock_output.logits = torch.randn(batch_size, seq_len, 32000) # vocab_size + self.mock_model.return_value = mock_output + + # Call forward + result = exportable.forward(inputs_embeds=inputs_embeds, cache_position=cache_position) + + # Verify model was called with inputs_embeds + self.mock_model.assert_called_once() + call_kwargs = self.mock_model.call_args[1] + self.assertIn("inputs_embeds", call_kwargs) + self.assertIsNone(call_kwargs["input_ids"]) + torch.testing.assert_close(call_kwargs["inputs_embeds"], inputs_embeds) + + def test_hybrid_cache_multimodal_config(self): + """Test that TorchExportableModuleWithHybridCache uses text_config for multimodal models""" + with patch("transformers.integrations.executorch.HybridCache") as MockCache: + # Create exportable module + exportable = TorchExportableModuleWithHybridCache(self.mock_model) + + # Verify HybridCache was initialized with text_config + MockCache.assert_called_once() + call_args = MockCache.call_args[1] + self.assertEqual(call_args["config"], self.mock_config.text_config) + + def test_image_text_lm_module(self): + """Test TorchExportableModuleForImageTextLM initialization""" + with patch("transformers.integrations.executorch.TorchExportableModuleWithHybridCache") as MockWrapper: + with patch("transformers.integrations.executorch.ALL_MASK_ATTENTION_FUNCTIONS"): + with patch("transformers.integrations.executorch.ALL_ATTENTION_FUNCTIONS"): + # Create image-text LM module + exportable = TorchExportableModuleForImageTextLM(self.mock_model) + + # Verify it creates the appropriate wrapper + MockWrapper.assert_called_once_with(self.mock_model, 1, 4096) + + def test_image_encoder_module(self): + """Test ImageEncoderExportableModule""" + # Mock vision model + mock_vision_tower = Mock() + mock_vision_outputs = Mock() + mock_vision_outputs.last_hidden_state = torch.randn(1, 196, 768) # 14x14 patches + mock_vision_tower.return_value = mock_vision_outputs + + mock_projector = Mock() + mock_projector.return_value = torch.randn(1, 196, 768) # projected features + + 
mock_model = Mock() + mock_model.vision_tower = mock_vision_tower + mock_model.multi_modal_projector = mock_projector + + # Create encoder module + encoder = ImageEncoderExportableModule(mock_model) + + # Test forward pass + pixel_values = torch.randn(1, 3, 224, 224) + result = encoder.forward(pixel_values) + + # Verify calls + mock_vision_tower.assert_called_once_with(pixel_values=pixel_values) + mock_projector.assert_called_once_with(mock_vision_outputs.last_hidden_state) + + def test_error_handling(self): + """Test error handling for invalid configurations""" + # Test missing cache configuration + bad_config = Mock() + bad_config.text_config = Mock() + bad_config.text_config.use_cache = False + + bad_model = Mock() + bad_model.config = bad_config + + with self.assertRaises(ValueError): + TorchExportableModuleForImageTextLM(bad_model) + + def test_forward_validation(self): + """Test input validation in forward method""" + with patch("transformers.integrations.executorch.HybridCache"): + exportable = TorchExportableModuleWithHybridCache(self.mock_model) + + # Test missing both input_ids and inputs_embeds + with self.assertRaises(ValueError): + exportable.forward(cache_position=torch.tensor([0])) + + # Test missing cache_position + with self.assertRaises(ValueError): + exportable.forward(input_ids=torch.tensor([[1]])) + + +if __name__ == "__main__": + unittest.main() + diff --git a/tests/models/moshi/test_modeling_moshi.py b/tests/models/moshi/test_modeling_moshi.py index dca5b18eeadf..4f5b1689594e 100644 --- a/tests/models/moshi/test_modeling_moshi.py +++ b/tests/models/moshi/test_modeling_moshi.py @@ -371,91 +371,6 @@ def test_generate_continue_from_inputs_embeds(self): def test_save_load(self): super().test_save_load() - @slow - def test_export_static_cache(self): - from packaging import version - - if version.parse(torch.__version__) < version.parse("2.4.0"): - self.skipTest(reason="This test requires torch >= 2.4 to run.") - - from transformers.generation.configuration_utils import GenerationConfig - from transformers.integrations.executorch import ( - TorchExportableModuleWithStaticCache, - ) - - # Create a small model for testing - config = self.model_tester.get_config() - config.use_cache = True - config.attn_implementation = "sdpa" - - # Create model with static cache generation config - model = MoshiForCausalLM(config).to(torch_device) - model.eval() - - # Set up generation config with static cache - batch_size = 1 - max_generation_length = 50 - model.generation_config = GenerationConfig( - use_cache=True, - cache_implementation="static", - max_length=max_generation_length, - cache_config={ - "batch_size": batch_size, - "max_cache_len": max_generation_length, - "device": torch_device, - }, - ) - - # Test exportable module with static cache - from transformers.integrations.executorch import ( - TorchExportableModuleForDecoderOnlyLM, - ) - - # Create exportable module - exportable_module = TorchExportableModuleForDecoderOnlyLM(model) - - # Get representative token IDs within model's vocabulary range - # Use simple token IDs that are within the test model's vocab_size (99) - prompt_token_ids = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long, device=torch_device) - - # Use first token for example input - example_input_ids = prompt_token_ids[:, :1] - example_cache_position = torch.tensor([0], dtype=torch.long, device=torch_device) - - # Export the model - exported_program = exportable_module.export( - input_ids=example_input_ids, - cache_position=example_cache_position, - ) - - # 
Generate reference output from eager model - with torch.no_grad(): - eager_generated_ids = model.generate( - prompt_token_ids, - max_new_tokens=5, - do_sample=False, - use_cache=True, - ) - - # Test generation with exported program - ep_generated_ids = TorchExportableModuleWithStaticCache.generate( - exported_program=exported_program, - prompt_token_ids=prompt_token_ids, - max_new_tokens=5, - ) - - # Verify the exported program generates tokens - self.assertIsInstance(ep_generated_ids, torch.Tensor) - self.assertEqual(ep_generated_ids.shape[0], 1) # batch size - self.assertGreater(ep_generated_ids.shape[1], 1) # generated tokens - - # Compare exported model output with eager model output - self.assertEqual(ep_generated_ids.shape, eager_generated_ids.shape) - - # Note: Due to numerical precision differences in export, we use relaxed tolerances - # The key validation is that both models generate tokens and have the same shape - torch.testing.assert_close(ep_generated_ids, eager_generated_ids, rtol=1e-2, atol=1e-2) - class MoshiTester: def __init__(
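
Usage sketch (not part of the patch): a minimal example of how the two new export entry points introduced above could be wired together for a vision-language model. The checkpoint name and the embedding-merge step are illustrative assumptions; only ImageEncoderExportableModule and TorchExportableModuleForImageTextLM come from this patch, and whether a given model exposes `vision_tower` / `multi_modal_projector` directly is model-dependent.

from transformers import AutoModelForImageTextToText
from transformers.integrations.executorch import (
    ImageEncoderExportableModule,
    TorchExportableModuleForImageTextLM,
)

# Hypothetical multimodal checkpoint; any model exposing `vision_tower`,
# `multi_modal_projector`, and a cached `text_config` should follow the same pattern.
model = AutoModelForImageTextToText.from_pretrained("google/gemma-3-4b-it")
model.eval()

# 1) Export the vision encoder: pixel_values -> image features projected into
#    the language model's embedding space (uses a default example input when
#    no pixel_values are passed).
vision_ep = ImageEncoderExportableModule(model).export()

# 2) Export the text decoder with a static/hybrid cache. It consumes
#    `inputs_embeds` (rather than `input_ids`) so that image features can be
#    spliced into the token embedding sequence at runtime.
decoder = TorchExportableModuleForImageTextLM(model, max_batch_size=1, max_cache_len=4096)
decoder_ep = decoder.export()

# At inference time (sketch): embed the text tokens, run the exported vision
# encoder on the image, merge the returned image features into the embedding
# sequence at the image placeholder positions, then feed the merged
# `inputs_embeds` together with `cache_position` to the exported decoder one
# step at a time, mirroring how the text-only ExecuTorch runner drives
# `input_ids`.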