diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py
index cde061e52..817f2a36f 100644
--- a/transformer_lens/HookedTransformerConfig.py
+++ b/transformer_lens/HookedTransformerConfig.py
@@ -194,7 +194,7 @@ class HookedTransformerConfig:
             Defaults to 8.0.
         use_qk_norm (bool): Whether to apply RMSNorm to the query and key projections before
             computing attention scores. Used by Gemma 3 models. Defaults to False.
-        rotary_base_local (int, *optional*): The base for rotary positional embeddings in local
+        rotary_base_local (float or int, *optional*): The base for rotary positional embeddings in local
             attention layers. Used by models with hybrid local/global attention (e.g., Gemma 3)
             which use different RoPE bases for local (10k) and global (1M) attention. Defaults to
             None, which means the standard rotary_base is used for all layers.
@@ -252,9 +252,9 @@ class HookedTransformerConfig:
     tokenizer_prepends_bos: Optional[bool] = None
     n_key_value_heads: Optional[int] = None
     post_embedding_ln: bool = False
-    rotary_base: int = 10000
+    rotary_base: Union[float, int] = 10000
     rotary_base_local: Optional[
-        int
+        Union[float, int]
     ] = None  # For models with different RoPE bases per attention type (e.g., Gemma 3)
     trust_remote_code: bool = False
     rotary_adjacent_pairs: bool = False
diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py
index f9af85637..4f0ad5610 100644
--- a/transformer_lens/components/abstract_attention.py
+++ b/transformer_lens/components/abstract_attention.py
@@ -532,7 +532,7 @@ def calculate_sin_cos_rotary(
         self,
         rotary_dim: int,
         n_ctx: int,
-        base: int = 10000,
+        base: Union[float, int] = 10000,
         dtype: torch.dtype = torch.float32,
     ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]:
        """