|
| 1 | +"""Unit tests for CodeGenArchitectureAdapter. |
| 2 | +
|
| 3 | +Tests cover: |
| 4 | +- Config attribute validation (all required attributes are set correctly) |
| 5 | +- Component mapping structure (correct bridge types, no ln2) |
| 6 | +- Weight conversion keys and structure |
| 7 | +- split_qkv_matrix correctness (numerical test with known weights) |
| 8 | +- Factory registration (CodeGenForCausalLM maps to the right adapter) |
| 9 | +""" |
| 10 | + |
| 11 | +from types import SimpleNamespace |
| 12 | +from typing import Any |
| 13 | + |
| 14 | +import pytest |
| 15 | +import torch |
| 16 | +import torch.nn as nn |
| 17 | + |
| 18 | +from transformer_lens.config import TransformerBridgeConfig |
| 19 | +from transformer_lens.model_bridge.generalized_components import ( |
| 20 | + BlockBridge, |
| 21 | + CodeGenAttentionBridge, |
| 22 | + EmbeddingBridge, |
| 23 | + MLPBridge, |
| 24 | + NormalizationBridge, |
| 25 | + UnembeddingBridge, |
| 26 | +) |
| 27 | +from transformer_lens.model_bridge.supported_architectures.codegen import ( |
| 28 | + CodeGenArchitectureAdapter, |
| 29 | +) |
| 30 | + |
| 31 | +# --------------------------------------------------------------------------- |
| 32 | +# Fixtures |
| 33 | +# --------------------------------------------------------------------------- |
| 34 | + |
| 35 | + |
def _make_cfg(
    n_heads: int = 4,
    d_model: int = 64,
    n_layers: int = 2,
    d_mlp: int = 256,
    d_vocab: int = 1000,
    n_ctx: int = 512,
) -> TransformerBridgeConfig:
    """Build a minimal TransformerBridgeConfig suitable for CodeGen adapter tests.

    All sizes are deliberately tiny so tests stay fast; the head dimension is
    derived from ``d_model`` and ``n_heads``.
    """
    head_dim = d_model // n_heads
    return TransformerBridgeConfig(
        architecture="CodeGenForCausalLM",
        d_model=d_model,
        d_head=head_dim,
        n_heads=n_heads,
        n_layers=n_layers,
        n_ctx=n_ctx,
        d_vocab=d_vocab,
        d_mlp=d_mlp,
        default_prepend_bos=True,
    )
| 56 | + |
| 57 | + |
@pytest.fixture
def cfg() -> TransformerBridgeConfig:
    """Default tiny CodeGen config shared by the adapter tests."""
    return _make_cfg()
| 61 | + |
| 62 | + |
@pytest.fixture
def adapter(cfg: TransformerBridgeConfig) -> CodeGenArchitectureAdapter:
    """CodeGen architecture adapter built from the default test config."""
    return CodeGenArchitectureAdapter(cfg)
| 66 | + |
| 67 | + |
| 68 | +# --------------------------------------------------------------------------- |
| 69 | +# Config attribute tests |
| 70 | +# --------------------------------------------------------------------------- |
| 71 | + |
| 72 | + |
class TestCodeGenAdapterConfig:
    """Verify the adapter populates every required config attribute."""

    def test_normalization_type_is_ln(self, adapter: CodeGenArchitectureAdapter) -> None:
        """CodeGen uses standard LayerNorm."""
        cfg = adapter.cfg
        assert cfg.normalization_type == "LN"

    def test_positional_embedding_type_is_rotary(self, adapter: CodeGenArchitectureAdapter) -> None:
        """CodeGen applies rotary positional embeddings."""
        cfg = adapter.cfg
        assert cfg.positional_embedding_type == "rotary"

    def test_final_rms_is_false(self, adapter: CodeGenArchitectureAdapter) -> None:
        """The final norm is a full LayerNorm, not RMS-only."""
        cfg = adapter.cfg
        assert cfg.final_rms is False

    def test_gated_mlp_is_false(self, adapter: CodeGenArchitectureAdapter) -> None:
        """CodeGen's MLP has no gating branch."""
        cfg = adapter.cfg
        assert cfg.gated_mlp is False

    def test_attn_only_is_false(self, adapter: CodeGenArchitectureAdapter) -> None:
        """Blocks contain both attention and MLP."""
        cfg = adapter.cfg
        assert cfg.attn_only is False

    def test_parallel_attn_mlp_is_true(self, adapter: CodeGenArchitectureAdapter) -> None:
        """Attention and MLP run in parallel off the same pre-norm."""
        cfg = adapter.cfg
        assert cfg.parallel_attn_mlp is True
| 93 | + |
| 94 | + |
| 95 | +# --------------------------------------------------------------------------- |
| 96 | +# Component mapping structure tests |
| 97 | +# --------------------------------------------------------------------------- |
| 98 | + |
| 99 | + |
class TestCodeGenAdapterComponentMapping:
    """Verify component_mapping exposes the correct bridge types and HF names."""

    def test_embed_is_embedding_bridge(self, adapter: CodeGenArchitectureAdapter) -> None:
        """The token embedding must be wrapped in an EmbeddingBridge."""
        embed = adapter.component_mapping["embed"]
        assert isinstance(embed, EmbeddingBridge)

    def test_embed_name(self, adapter: CodeGenArchitectureAdapter) -> None:
        """Embedding maps to the HF module path transformer.wte."""
        embed = adapter.component_mapping["embed"]
        assert embed.name == "transformer.wte"

    def test_blocks_is_block_bridge(self, adapter: CodeGenArchitectureAdapter) -> None:
        """The transformer layers must be wrapped in a BlockBridge."""
        blocks = adapter.component_mapping["blocks"]
        assert isinstance(blocks, BlockBridge)

    def test_blocks_name(self, adapter: CodeGenArchitectureAdapter) -> None:
        """Blocks map to the HF module path transformer.h."""
        blocks = adapter.component_mapping["blocks"]
        assert blocks.name == "transformer.h"

    def test_ln_final_is_normalization_bridge(self, adapter: CodeGenArchitectureAdapter) -> None:
        """The final norm must be wrapped in a NormalizationBridge."""
        ln_final = adapter.component_mapping["ln_final"]
        assert isinstance(ln_final, NormalizationBridge)

    def test_ln_final_name(self, adapter: CodeGenArchitectureAdapter) -> None:
        """Final norm maps to the HF module path transformer.ln_f."""
        ln_final = adapter.component_mapping["ln_final"]
        assert ln_final.name == "transformer.ln_f"

    def test_unembed_is_unembedding_bridge(self, adapter: CodeGenArchitectureAdapter) -> None:
        """The output projection must be wrapped in an UnembeddingBridge."""
        unembed = adapter.component_mapping["unembed"]
        assert isinstance(unembed, UnembeddingBridge)

    def test_unembed_name(self, adapter: CodeGenArchitectureAdapter) -> None:
        """Unembedding maps to the HF module path lm_head."""
        unembed = adapter.component_mapping["unembed"]
        assert unembed.name == "lm_head"

    def test_blocks_ln1_is_normalization_bridge(self, adapter: CodeGenArchitectureAdapter) -> None:
        """Each block's pre-norm must be a NormalizationBridge."""
        submodules = adapter.component_mapping["blocks"].submodules
        assert isinstance(submodules["ln1"], NormalizationBridge)

    def test_blocks_ln1_name(self, adapter: CodeGenArchitectureAdapter) -> None:
        """The pre-norm maps to the HF submodule name ln_1."""
        submodules = adapter.component_mapping["blocks"].submodules
        assert submodules["ln1"].name == "ln_1"

    def test_no_ln2_in_blocks(self, adapter: CodeGenArchitectureAdapter) -> None:
        """CodeGen uses parallel attn+MLP sharing ln_1 — there must be no ln2."""
        submodules = adapter.component_mapping["blocks"].submodules
        assert "ln2" not in submodules, "CodeGen parallel block must not have ln2"

    def test_attn_is_codegen_attention_bridge(self, adapter: CodeGenArchitectureAdapter) -> None:
        """Attention must use the CodeGen-specific bridge (fused qkv layout)."""
        submodules = adapter.component_mapping["blocks"].submodules
        assert isinstance(submodules["attn"], CodeGenAttentionBridge)

    def test_attn_name(self, adapter: CodeGenArchitectureAdapter) -> None:
        """Attention maps to the HF submodule name attn."""
        submodules = adapter.component_mapping["blocks"].submodules
        assert submodules["attn"].name == "attn"

    def test_mlp_is_mlp_bridge(self, adapter: CodeGenArchitectureAdapter) -> None:
        """The feed-forward must be wrapped in an MLPBridge."""
        submodules = adapter.component_mapping["blocks"].submodules
        assert isinstance(submodules["mlp"], MLPBridge)

    def test_mlp_name(self, adapter: CodeGenArchitectureAdapter) -> None:
        """The MLP maps to the HF submodule name mlp."""
        submodules = adapter.component_mapping["blocks"].submodules
        assert submodules["mlp"].name == "mlp"

    def test_mlp_in_name(self, adapter: CodeGenArchitectureAdapter) -> None:
        """The MLP input projection maps to fc_in."""
        mlp = adapter.component_mapping["blocks"].submodules["mlp"]
        assert mlp.submodules["in"].name == "fc_in"

    def test_mlp_out_name(self, adapter: CodeGenArchitectureAdapter) -> None:
        """The MLP output projection maps to fc_out."""
        mlp = adapter.component_mapping["blocks"].submodules["mlp"]
        assert mlp.submodules["out"].name == "fc_out"
| 163 | + |
| 164 | + |
| 165 | +# --------------------------------------------------------------------------- |
| 166 | +# Weight processing conversion tests |
| 167 | +# --------------------------------------------------------------------------- |
| 168 | + |
| 169 | + |
class TestCodeGenAdapterWeightConversions:
    """Verify weight_processing_conversions contains exactly the attention keys."""

    def test_q_weight_key_present(self, adapter: CodeGenArchitectureAdapter) -> None:
        """The query projection key must be registered."""
        conversions = adapter.weight_processing_conversions
        assert "blocks.{i}.attn.q.weight" in conversions

    def test_k_weight_key_present(self, adapter: CodeGenArchitectureAdapter) -> None:
        """The key projection key must be registered."""
        conversions = adapter.weight_processing_conversions
        assert "blocks.{i}.attn.k.weight" in conversions

    def test_v_weight_key_present(self, adapter: CodeGenArchitectureAdapter) -> None:
        """The value projection key must be registered."""
        conversions = adapter.weight_processing_conversions
        assert "blocks.{i}.attn.v.weight" in conversions

    def test_o_weight_key_present(self, adapter: CodeGenArchitectureAdapter) -> None:
        """The output projection key must be registered."""
        conversions = adapter.weight_processing_conversions
        assert "blocks.{i}.attn.o.weight" in conversions

    def test_exactly_four_conversion_keys(self, adapter: CodeGenArchitectureAdapter) -> None:
        """Q, K, V, O — and nothing else — should need conversion."""
        conversions = adapter.weight_processing_conversions
        assert len(conversions) == 4
| 187 | + |
| 188 | + |
| 189 | +# --------------------------------------------------------------------------- |
| 190 | +# split_qkv_matrix numerical correctness tests |
| 191 | +# --------------------------------------------------------------------------- |
| 192 | + |
| 193 | + |
class TestCodeGenSplitQKVMatrix:
    """Numerical tests verifying the mp_num=4 QKV split logic.

    CodeGen stores Q, K, V fused in a single ``qkv_proj`` weight laid out as
    mp_num=4 partitions, each partition ordered [Q_part, V_part, K_part].
    These tests check that ``split_qkv_matrix`` undoes that layout correctly.
    """

    def _make_adapter_with_dmodel(self, d_model: int, n_heads: int) -> CodeGenArchitectureAdapter:
        """Build an adapter whose config uses the given model width and head count."""
        cfg = _make_cfg(d_model=d_model, n_heads=n_heads)
        return CodeGenArchitectureAdapter(cfg)

    def _make_attn_component(self, d_model: int) -> SimpleNamespace:
        """Create a minimal attn component with a fused (bias-free) qkv_proj linear."""
        attn = SimpleNamespace()
        attn.qkv_proj = nn.Linear(d_model, d_model * 3, bias=False)
        return attn

    def test_returns_three_linear_modules(self) -> None:
        """split_qkv_matrix must return exactly three nn.Linear modules."""
        adapter = self._make_adapter_with_dmodel(64, 4)
        attn = self._make_attn_component(64)
        q, k, v = adapter.split_qkv_matrix(attn)
        assert isinstance(q, nn.Linear)
        assert isinstance(k, nn.Linear)
        assert isinstance(v, nn.Linear)

    def test_output_shapes_are_correct(self) -> None:
        """Each of Q, K, V must have weight shape [n_embd, n_embd]."""
        d_model = 64
        adapter = self._make_adapter_with_dmodel(d_model, 4)
        attn = self._make_attn_component(d_model)
        q, k, v = adapter.split_qkv_matrix(attn)
        assert q.weight.shape == (d_model, d_model)
        assert k.weight.shape == (d_model, d_model)
        assert v.weight.shape == (d_model, d_model)

    def test_no_bias_on_outputs(self) -> None:
        """The split linears must have no bias, matching qkv_proj."""
        adapter = self._make_adapter_with_dmodel(64, 4)
        attn = self._make_attn_component(64)
        q, k, v = adapter.split_qkv_matrix(attn)
        assert q.bias is None
        assert k.bias is None
        assert v.bias is None

    def test_q_k_v_are_distinct(self) -> None:
        """With a non-trivial weight, Q, K, V must differ from each other."""
        # Seed the RNG so the random weights (and hence the test outcome)
        # are reproducible across runs.
        torch.manual_seed(0)
        adapter = self._make_adapter_with_dmodel(64, 4)
        attn = self._make_attn_component(64)
        # Fill qkv_proj with distinct values per row
        nn.init.normal_(attn.qkv_proj.weight)
        q, k, v = adapter.split_qkv_matrix(attn)
        # All three must differ
        assert not torch.allclose(q.weight, k.weight), "Q and K weights must differ"
        assert not torch.allclose(q.weight, v.weight), "Q and V weights must differ"
        assert not torch.allclose(k.weight, v.weight), "K and V weights must differ"

    def test_known_partition_ordering(self) -> None:
        """Verify the mp_num=4 partition layout: within each partition [Q_part, V_part, K_part].

        We construct a weight where partition index and slot index are embedded
        in the values, then verify that Q, K, V extract the correct slices.
        """
        mp_num = 4
        d_model = 64
        n_heads = 4
        local_dim = d_model // mp_num  # 16

        adapter = self._make_adapter_with_dmodel(d_model, n_heads)
        attn = self._make_attn_component(d_model)

        # Build a structured weight: rows are indexed 0..3*d_model-1.
        # Reshape as [mp_num=4, 3, local_dim=16, d_model=64], set each slice
        # to a unique constant so we can track which slot goes where.
        w = torch.zeros(mp_num, 3, local_dim, d_model)
        # slot 0 = Q_part → fill with 1.0
        w[:, 0, :, :] = 1.0
        # slot 1 = V_part → fill with 2.0
        w[:, 1, :, :] = 2.0
        # slot 2 = K_part → fill with 3.0
        w[:, 2, :, :] = 3.0

        # Flatten back to [3*d_model, d_model] as qkv_proj expects
        attn.qkv_proj.weight = nn.Parameter(w.reshape(3 * d_model, d_model))

        q, k, v = adapter.split_qkv_matrix(attn)

        assert torch.all(q.weight == 1.0), "Q should come from slot 0 (Q_part)"
        assert torch.all(k.weight == 3.0), "K should come from slot 2 (K_part)"
        assert torch.all(v.weight == 2.0), "V should come from slot 1 (V_part)"

    def test_forward_output_shape_with_split(self) -> None:
        """After split, Q/K/V linears should produce correct output shapes."""
        d_model = 64
        adapter = self._make_adapter_with_dmodel(d_model, 4)
        attn = self._make_attn_component(d_model)
        q_lin, k_lin, v_lin = adapter.split_qkv_matrix(attn)

        batch, seq = 2, 10
        x = torch.randn(batch, seq, d_model)
        assert q_lin(x).shape == (batch, seq, d_model)
        assert k_lin(x).shape == (batch, seq, d_model)
        assert v_lin(x).shape == (batch, seq, d_model)
| 293 | + |
| 294 | + |
| 295 | +# --------------------------------------------------------------------------- |
| 296 | +# Factory registration test |
| 297 | +# --------------------------------------------------------------------------- |
| 298 | + |
| 299 | + |
class TestCodeGenFactoryRegistration:
    """Tests that the factory resolves CodeGenForCausalLM to the right adapter.

    Note: Phase D (registration) is required for these tests to pass. They
    are included here so that registration is verified as part of the Phase D
    commit rather than needing a separate test file.
    """

    def test_factory_returns_codegen_adapter(self) -> None:
        """ArchitectureAdapterFactory must return a CodeGenArchitectureAdapter."""
        from transformer_lens.factories.architecture_adapter_factory import (
            ArchitectureAdapterFactory,
        )

        resolved = ArchitectureAdapterFactory.select_architecture_adapter(_make_cfg())
        message = f"Expected CodeGenArchitectureAdapter, got {type(resolved).__name__}"
        assert isinstance(resolved, CodeGenArchitectureAdapter), message

    def test_factory_key_is_codegen_for_causal_lm(self) -> None:
        """SUPPORTED_ARCHITECTURES must have a 'CodeGenForCausalLM' key."""
        from transformer_lens.factories.architecture_adapter_factory import (
            SUPPORTED_ARCHITECTURES,
        )

        registered = "CodeGenForCausalLM" in SUPPORTED_ARCHITECTURES
        assert registered, "CodeGenForCausalLM must be registered in SUPPORTED_ARCHITECTURES"
0 commit comments