From 760919c108cb15c9160f8ee9022362f636fef4af Mon Sep 17 00:00:00 2001
From: remi-or
Date: Wed, 20 Aug 2025 07:04:40 -0500
Subject: [PATCH 1/9] Relaxed assumptions on cache_config

---
 src/transformers/integrations/executorch.py | 31 +++++++++++++--------
 1 file changed, 19 insertions(+), 12 deletions(-)

diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py
index 49b666912246..c6ae2435f803 100644
--- a/src/transformers/integrations/executorch.py
+++ b/src/transformers/integrations/executorch.py
@@ -659,16 +659,25 @@ def __init__(
                 "The model must have a generation config to be exported with static caching. "
                 "Please set `generation_config` in `model`."
             )
-        if "batch_size" not in generation_config.cache_config:
-            raise ValueError(
-                "The model's generation config must specify a batch_size in its cache_config. "
-                'Try GenerationConfig( ... cache_config={"batch_size": 1, ...} ...)'
-            )
-        if "max_cache_len" not in generation_config.cache_config:
-            raise ValueError(
-                "The model's generation config must specify a max_cache_len in its cache_config. "
-                'Try GenerationConfig( ... cache_config={"max_cache_len": 4096, ...} ...)'
-            )
+        if generation_config.cache_config is None:
+            logging.warning("The model has no cache_config, using default cache_config.")
+            max_batch_size = 1
+            max_cache_len = 4096
+            device = self.model.device
+        else:
+            max_batch_size = generation_config.cache_config.get("batch_size", None)
+            if max_batch_size is None:
+                logging.warning("The model's cache_config has no attribute batch_size, using 1 as default.")
+                max_batch_size = 1
+            max_cache_len = generation_config.cache_config.get("max_cache_len", None)
+            if max_cache_len is None:
+                logging.warning("The model's cache_config has no attribute max_cache_len, using 4096 as default.")
+                max_cache_len = 4096
+            device = generation_config.cache_config.get("device", None)
+            if device is None:
+                logging.warning("The model's cache_config has no attribute device, using model.device as default.")
+                device = self.model.device
+

         if not config.use_cache:
             raise AssertionError("Model must have caching enabled.")
@@ -676,8 +685,6 @@ def __init__(
         self.cache = StaticCache(config=config, max_cache_len=generation_config.cache_config.get("max_cache_len"))
         head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
-        max_batch_size = generation_config.cache_config.get("batch_size")
-        device = generation_config.cache_config.get("device")
         dtype = self.model.dtype
         # We need this call to initialize all the layers (otherwise it's done lazily, which is not exportable)
         self.cache.early_initialization(max_batch_size, num_heads, head_dim, dtype, device)

From fb0b6804e202e0840eae80de8382a8160eda423c Mon Sep 17 00:00:00 2001
From: remi-or
Date: Wed, 20 Aug 2025 10:47:19 -0500
Subject: [PATCH 2/9] Review compliance

---
 src/transformers/integrations/executorch.py | 45 ++++++---------------
 1 file changed, 13 insertions(+), 32 deletions(-)

diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py
index c6ae2435f803..b350353be525 100644
--- a/src/transformers/integrations/executorch.py
+++ b/src/transformers/integrations/executorch.py
@@ -494,16 +494,6 @@ def __init__(
                 "The model must have a generation config to be exported with static caching. "
                 "Please set `generation_config` in `model`."
             )
-        if "batch_size" not in generation_config.cache_config:
-            raise ValueError(
-                "The model's generation config must specify a batch_size in its cache_config. "
-                'Try GenerationConfig( ... cache_config={"batch_size": 1, ...} ...)'
-            )
-        if "max_cache_len" not in generation_config.cache_config:
-            raise ValueError(
-                "The model's generation config must specify a max_cache_len in its cache_config. "
-                'Try GenerationConfig( ... cache_config={"max_cache_len": 4096, ...} ...)'
-            )
         if not generation_config.use_cache:
             raise AssertionError(
                 "The model must have caching enabled to be exported with static caching. "
@@ -515,15 +505,20 @@ def __init__(
                 "Please set `generation_config.cache_implementation='static'`."
             )
+
+        cache_config = {} if generation_config.cache_config is None else generation_config.cache_config
+        batch_size = cache_config.get("batch_size", 1)
+        max_cache_len = cache_config.get("max_cache_len", 4096)
+        device = cache_config.get("device", self.model.device)
+

         self.model = model
         self.static_cache = StaticCache(
-            max_cache_len=generation_config.cache_config.get("max_cache_len"),
+            max_cache_len=max_cache_len,
             config=config,
         )
-        batch_size = generation_config.cache_config.get("batch_size")
         head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
-        device = generation_config.cache_config.get("device")
+        device = device
         dtype = self.model.dtype
         # We need this call to initialize all the layers (otherwise it's done lazily, which is not exportable)
         self.static_cache.early_initialization(batch_size, num_heads, head_dim, dtype, device)
@@ -659,24 +654,10 @@ def __init__(
                 "The model must have a generation config to be exported with static caching. "
                 "Please set `generation_config` in `model`."
             )
-        if generation_config.cache_config is None:
-            logging.warning("The model has no cache_config, using default cache_config.")
-            max_batch_size = 1
-            max_cache_len = 4096
-            device = self.model.device
-        else:
-            max_batch_size = generation_config.cache_config.get("batch_size", None)
-            if max_batch_size is None:
-                logging.warning("The model's cache_config has no attribute batch_size, using 1 as default.")
-                max_batch_size = 1
-            max_cache_len = generation_config.cache_config.get("max_cache_len", None)
-            if max_cache_len is None:
-                logging.warning("The model's cache_config has no attribute max_cache_len, using 4096 as default.")
-                max_cache_len = 4096
-            device = generation_config.cache_config.get("device", None)
-            if device is None:
-                logging.warning("The model's cache_config has no attribute device, using model.device as default.")
-                device = self.model.device
+        cache_config = {} if generation_config.cache_config is None else generation_config.cache_config
+        batch_size = cache_config.get("batch_size", 1)
+        max_cache_len = cache_config.get("max_cache_len", 4096)
+        device = cache_config.get("device", self.model.device)

         if not config.use_cache:
             raise AssertionError("Model must have caching enabled.")
@@ -687,7 +668,7 @@ def __init__(
         num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
         dtype = self.model.dtype
         # We need this call to initialize all the layers (otherwise it's done lazily, which is not exportable)
-        self.cache.early_initialization(max_batch_size, num_heads, head_dim, dtype, device)
+        self.cache.early_initialization(batch_size, num_heads, head_dim, dtype, device)

         # Register all key and value cache tensors as buffers
         for i in range(len(self.cache)):

From 80632ff8683e0d20ee560331d03a5a89207ac57d Mon Sep 17 00:00:00 2001
From: remi-or
Date: Wed, 20 Aug 2025 10:49:07 -0500
Subject: [PATCH 3/9] Style

---
 src/transformers/integrations/executorch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py
index b350353be525..710383b08fad 100644
--- a/src/transformers/integrations/executorch.py
+++ b/src/transformers/integrations/executorch.py
@@ -505,7 +505,7 @@ def __init__(
                 "Please set `generation_config.cache_implementation='static'`."
             )
-
+
         cache_config = {} if generation_config.cache_config is None else generation_config.cache_config
         batch_size = cache_config.get("batch_size", 1)
         max_cache_len = cache_config.get("max_cache_len", 4096)
         device = cache_config.get("device", self.model.device)

From 66b5cc365d67e5dcfb1845179fb0eeacc8239f5a Mon Sep 17 00:00:00 2001
From: remi-or
Date: Wed, 20 Aug 2025 10:56:16 -0500
Subject: [PATCH 4/9] Styyyle

---
 src/transformers/integrations/executorch.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py
index 710383b08fad..09ca55f800c2 100644
--- a/src/transformers/integrations/executorch.py
+++ b/src/transformers/integrations/executorch.py
@@ -505,7 +505,6 @@ def __init__(
                 "Please set `generation_config.cache_implementation='static'`."
             )
-
         cache_config = {} if generation_config.cache_config is None else generation_config.cache_config
         batch_size = cache_config.get("batch_size", 1)
         max_cache_len = cache_config.get("max_cache_len", 4096)
         device = cache_config.get("device", self.model.device)

From d5d2f66bd39c6115dbe877d1fcb79eab2f583bd7 Mon Sep 17 00:00:00 2001
From: remi-or
Date: Tue, 26 Aug 2025 03:38:11 -0500
Subject: [PATCH 5/9] Removed default and added args

---
 src/transformers/integrations/executorch.py | 70 ++++++++++++++++-----
 1 file changed, 53 insertions(+), 17 deletions(-)

diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py
index 09ca55f800c2..222cc947709e 100644
--- a/src/transformers/integrations/executorch.py
+++ b/src/transformers/integrations/executorch.py
@@ -471,17 +471,27 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
     def __init__(
         self,
         model: PreTrainedModel,
-    ):
+        batch_size: Optional[int] = None,
+        max_cache_len: Optional[int] = None,
+        device: Optional[torch.device] = None,
+    ) -> None:
         """
         Initializes the wrapper module with the pretrained model.

         Args:
             model (`PreTrainedModel`): The pretrained model to wrap. The model must have caching
                 enabled and use a 'static' caching implementation.
+            batch_size (`Optional[int]`): The batch size of the model. If not provided, we check if a value can be found
+                in `generation_config.cache_config` and otherwise we raise a ValueError.
+            max_cache_len (`Optional[int]`): The maximum cache length for generation. Same mechanism as `batch_size` if
+                not provided.
+            device (`Optional[torch.device]`): The device to use. If not provided, we check if a value can be found
+                in `generation_config.cache_config` and otherwise we use `model.device` (no error is raised).

         Raises:
             AssertionError: If the pretrained model does not have caching enabled or if it does
                 not use a 'static' caching implementation in `model.generation_config`.
+            ValueError: If `batch_size` or `max_cache_len` is not provided, either as an argument or in `cache_config`.
         """
         super().__init__()
@@ -506,18 +516,25 @@ def __init__(
             )
         cache_config = {} if generation_config.cache_config is None else generation_config.cache_config
-        batch_size = cache_config.get("batch_size", 1)
-        max_cache_len = cache_config.get("max_cache_len", 4096)
-        device = cache_config.get("device", self.model.device)
+        # Ensure batch_size and max_cache_len are set
+        if batch_size is None:
+            batch_size = cache_config.get("batch_size", None)
+            if batch_size is None:
+                raise ValueError("batch_size must be provided, either as an argument or in cache_config.")
+        if max_cache_len is None:
+            max_cache_len = cache_config.get("max_cache_len", None)
+            if max_cache_len is None:
+                raise ValueError("max_cache_len must be provided, either as an argument or in cache_config.")
+        # Infer device if not provided
+        if device is None:
+            device = cache_config.get("device", model.device)
+
+        # Initialize the static cache
         self.model = model
-        self.static_cache = StaticCache(
-            max_cache_len=max_cache_len,
-            config=config,
-        )
+        self.static_cache = StaticCache(max_cache_len=max_cache_len, config=config)
         head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
-        device = device
         dtype = self.model.dtype
         # We need this call to initialize all the layers (otherwise it's done lazily, which is not exportable)
         self.static_cache.early_initialization(batch_size, num_heads, head_dim, dtype, device)
@@ -633,34 +650,53 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):
     def __init__(
         self,
         model: PreTrainedModel,
-    ):
+        batch_size: Optional[int] = None,
+        max_cache_len: Optional[int] = None,
+        device: Optional[torch.device] = None,
+    ) -> None:
         """
         Initializes the exportable module.

         Args:
             model (`PreTrainedModel`): The pretrained model to wrap.
-
+            batch_size (`Optional[int]`): The batch size of the model. If not provided, we check if a value can be found
+                in `generation_config.cache_config` and otherwise we raise a ValueError.
+            max_cache_len (`Optional[int]`): The maximum cache length for generation. Same mechanism as `batch_size` if
+                not provided.
+            device (`Optional[torch.device]`): The device to use. If not provided, we check if a value can be found
+                in `generation_config.cache_config` and otherwise we use `model.device` (no error is raised).
         Raises:
-            AssertionError: If the model doesn't have the expected configuration for an hybrid StaticCache.
+            AssertionError: If the model doesn't have the expected configuration for hybrid StaticCache.
+            ValueError: If `batch_size` or `max_cache_len` is not provided, either as an argument or in `cache_config`.
         """
         super().__init__()
         self.model = model
         config = model.config.get_text_config()
         generation_config = model.generation_config
+        # Sanity checks
         if generation_config is None:
             raise AssertionError(
                 "The model must have a generation config to be exported with static caching. "
                 "Please set `generation_config` in `model`."
             )
-        cache_config = {} if generation_config.cache_config is None else generation_config.cache_config
-        batch_size = cache_config.get("batch_size", 1)
-        max_cache_len = cache_config.get("max_cache_len", 4096)
-        device = cache_config.get("device", self.model.device)
-
         if not config.use_cache:
             raise AssertionError("Model must have caching enabled.")

+        cache_config = {} if generation_config.cache_config is None else generation_config.cache_config
+        # Ensure batch_size and max_cache_len are set
+        if batch_size is None:
+            batch_size = cache_config.get("batch_size", None)
+            if batch_size is None:
+                raise ValueError("batch_size must be provided, either as an argument or in cache_config.")
+        if max_cache_len is None:
+            max_cache_len = cache_config.get("max_cache_len", None)
+            if max_cache_len is None:
+                raise ValueError("max_cache_len must be provided, either as an argument or in cache_config.")
+        # Infer device if not provided
+        if device is None:
+            device = cache_config.get("device", model.device)
+        # Initialize the cache
         self.cache = StaticCache(config=config, max_cache_len=generation_config.cache_config.get("max_cache_len"))
         head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)

From 7ca87b9296c3221c8649aecb468e506b6aa02091 Mon Sep 17 00:00:00 2001
From: remi-or
Date: Tue, 26 Aug 2025 03:47:28 -0500
Subject: [PATCH 6/9] Rebase mishapfix

---
 src/transformers/integrations/executorch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py
index 222cc947709e..62533fc24a98 100644
--- a/src/transformers/integrations/executorch.py
+++ b/src/transformers/integrations/executorch.py
@@ -698,7 +698,7 @@ def __init__(
             device = cache_config.get("device", model.device)
         # Initialize the cache
-        self.cache = StaticCache(config=config, max_cache_len=generation_config.cache_config.get("max_cache_len"))
+        self.cache = StaticCache(config=config, max_cache_len=max_cache_len)
         head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
         dtype = self.model.dtype

From 652762185e4a775154dbd14c4c0c6d40ea4178b6 Mon Sep 17 00:00:00 2001
From: remi-or
Date: Tue, 26 Aug 2025 03:57:12 -0500
Subject: [PATCH 7/9] Propagate args to TorchExportableModuleForDecoderOnlyLM

---
 src/transformers/integrations/executorch.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py
index 62533fc24a98..cd9b12847f3a 100644
--- a/src/transformers/integrations/executorch.py
+++ b/src/transformers/integrations/executorch.py
@@ -201,7 +201,10 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
     def __init__(
         self,
         model: PreTrainedModel,
-    ):
+        batch_size: Optional[int] = None,
+        max_cache_len: Optional[int] = None,
+        device: Optional[torch.device] = None,
+    ) -> None:
         """
         Initializes the exportable module.

@@ -214,20 +217,19 @@ def __init__(
         super().__init__()

         config = model.config.get_text_config()
-        _generation_config = model.generation_config

         if not hasattr(config, "use_cache") or config.use_cache is False:
             raise ValueError("The model must have caching enabled to be performant.")

         if hasattr(config, "layer_types") and getattr(config, "sliding_window", None) is not None:
-            self.model = TorchExportableModuleWithHybridCache(model)
+            self.model = TorchExportableModuleWithHybridCache(model, batch_size, max_cache_len, device)
         else:
             # If `layer_types` is not specified explicitly in the config or `sliding_window` is null,
             # there is only 1 type of layers, so export will use `StaticCache` by default.
             logging.info(
                 "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
             )
-            self.model = TorchExportableModuleWithStaticCache(model)
+            self.model = TorchExportableModuleWithStaticCache(model, batch_size, max_cache_len, device)
         # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
         ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
         ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])

From 8debe26e33568a0954379ebdf2c01635028cab51 Mon Sep 17 00:00:00 2001
From: remi-or
Date: Tue, 26 Aug 2025 03:57:52 -0500
Subject: [PATCH 8/9] Fix the test I wanted fixed in this PR

---
 tests/models/gemma3/test_modeling_gemma3.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py
index f1e1a5a95fdd..99ea97561f78 100644
--- a/tests/models/gemma3/test_modeling_gemma3.py
+++ b/tests/models/gemma3/test_modeling_gemma3.py
@@ -733,7 +733,7 @@ def test_export_text_only_with_hybrid_cache(self):

         # Export + hybrid cache
         model.eval()
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model, batch_size=1, max_cache_len=1024)
         exported_program = exportable_module.export(
             input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
             cache_position=torch.tensor([0], dtype=torch.long, device=model.device),

From 5277eaddf2555255898c2509202b4df433d31564 Mon Sep 17 00:00:00 2001
From: remi-or
Date: Tue, 26 Aug 2025 04:45:55 -0500
Subject: [PATCH 9/9] Added some AMD expectation related to cache tests

---
 tests/models/gemma/test_modeling_gemma.py   | 3 +++
 tests/models/gemma2/test_modeling_gemma2.py | 5 ++++-
 tests/models/qwen2/test_modeling_qwen2.py   | 2 +-
 3 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py
index 284cd4c19909..097c82a0e5a0 100644
--- a/tests/models/gemma/test_modeling_gemma.py
+++ b/tests/models/gemma/test_modeling_gemma.py
@@ -416,6 +416,9 @@ def test_export_static_cache(self):
                 ("cuda", 8): [
                     "Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have been looking on the internet and I have"
                 ],
+                ("rocm", (9, 5)): [
+                    "Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have been looking on the internet and I have"
+                ],
             }
         )
         EXPECTED_TEXT_COMPLETION = expectations.get_expectation()
diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py
index 59921594d691..423640cb31b1 100644
--- a/tests/models/gemma2/test_modeling_gemma2.py
+++ b/tests/models/gemma2/test_modeling_gemma2.py
@@ -255,6 +255,9 @@ def test_export_static_cache(self):
                 ("cuda", 8): [
                     "Hello I am doing a project for my class and I am having trouble with the code. I am trying to make a"
                 ],
+                ("rocm", (9, 5)): [
+                    "Hello I am doing a project for my school and I need to know how to make a program that will take a number"
+                ],
             }
         )
         EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation()
@@ -320,7 +323,7 @@ def test_export_hybrid_cache(self):

         # Export + hybrid cache
         model.eval()
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model, batch_size=1, max_cache_len=1024)
         exported_program = exportable_module.export(
             input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
             cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py
index 74da981d092a..30c5082393ef 100644
--- a/tests/models/qwen2/test_modeling_qwen2.py
+++ b/tests/models/qwen2/test_modeling_qwen2.py
@@ -257,7 +257,7 @@ def test_export_static_cache(self):
                 "My favourite condiment is 100% natural, organic, gluten free, vegan, and vegetarian. I love to use"
             ],
             ("rocm", (9, 5)): [
-                "My favourite condiment is 100% natural, organic and vegan. I love to use it in my cooking, but"
+                "My favourite condiment is 100% natural, organic, gluten free, vegan, and vegetarian. I love to use"
            ]
        }) # fmt: off
        EXPECTED_TEXT_COMPLETION = expected_text_completions.get_expectation()
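
Usage sketch (illustrative, not part of the diffs above): after this series, batch_size and max_cache_len can be passed directly to TorchExportableModuleForDecoderOnlyLM instead of having to live in generation_config.cache_config, and device falls back to model.device when omitted. The checkpoint name below is an assumption made for the example; the call pattern mirrors the updated gemma2/gemma3 tests.

    import torch
    from transformers import AutoModelForCausalLM
    from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

    # Any decoder-only checkpoint with caching enabled works here; the model name is illustrative.
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
    model.eval()

    # batch_size and max_cache_len are now explicit arguments; if neither the argument nor
    # generation_config.cache_config provides them, the wrapper raises a ValueError.
    # device is optional and defaults to model.device.
    exportable_module = TorchExportableModuleForDecoderOnlyLM(model, batch_size=1, max_cache_len=1024)
    exported_program = exportable_module.export(
        input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
        cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
    )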