From a36e7c92a1951c2084143ee038c8e9f581182f18 Mon Sep 17 00:00:00 2001 From: Brendan Long Date: Thu, 2 Apr 2026 22:37:52 -0700 Subject: [PATCH] Fix HookedTransformerConfig rotary_base types rotary_base is frequently set to floats in the code but was typed as an int, causing beartype errors if the configs get loaded in a test: https://github.com/TransformerLensOrg/TransformerLens/blob/9c5a2a81674d5bcefa641c816b66e9827ccdf637/transformer_lens/loading_from_pretrained.py#L1984 HF configs allegedly always have rope_theta as a float: https://github.com/huggingface/transformers/blob/c38b2fb78eaedd4261a0e446f7976345cd1c7f1b/src/transformers/modeling_rope_utils.py#L645 But sometimes they're actually ints, and beartype doesn't consider int to be a subtype of float: https://github.com/beartype/beartype/issues/66 This updates the type to Union[float, int] to be accurate while keeping beartype happy. --- transformer_lens/HookedTransformerConfig.py | 6 +++--- transformer_lens/components/abstract_attention.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py index cde061e52..817f2a36f 100644 --- a/transformer_lens/HookedTransformerConfig.py +++ b/transformer_lens/HookedTransformerConfig.py @@ -194,7 +194,7 @@ class HookedTransformerConfig: Defaults to 8.0. use_qk_norm (bool): Whether to apply RMSNorm to the query and key projections before computing attention scores. Used by Gemma 3 models. Defaults to False. - rotary_base_local (int, *optional*): The base for rotary positional embeddings in local + rotary_base_local (float, *optional*): The base for rotary positional embeddings in local attention layers. Used by models with hybrid local/global attention (e.g., Gemma 3) which use different RoPE bases for local (10k) and global (1M) attention. Defaults to None, which means the standard rotary_base is used for all layers. 
@@ -252,9 +252,9 @@ class HookedTransformerConfig: tokenizer_prepends_bos: Optional[bool] = None n_key_value_heads: Optional[int] = None post_embedding_ln: bool = False - rotary_base: int = 10000 + rotary_base: Union[float, int] = 10000 rotary_base_local: Optional[ - int + Union[float, int] ] = None # For models with different RoPE bases per attention type (e.g., Gemma 3) trust_remote_code: bool = False rotary_adjacent_pairs: bool = False diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index f9af85637..4f0ad5610 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -532,7 +532,7 @@ def calculate_sin_cos_rotary( self, rotary_dim: int, n_ctx: int, - base: int = 10000, + base: Union[float, int] = 10000, dtype: torch.dtype = torch.float32, ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]: """