Skip to content

Commit 29bb713

Browse files
authored
CodeGen Architecture Adapter (#1242)
- Initial CodeGen setup
- Mypy and check format
1 parent 93e5b4c commit 29bb713

File tree

13 files changed

+2518
-2131
lines changed

13 files changed

+2518
-2131
lines changed

tests/unit/model_bridge/generalized_components/test_codegen_attention_bridge.py

Lines changed: 545 additions & 0 deletions
Large diffs are not rendered by default.

tests/unit/model_bridge/supported_architectures/__init__.py

Whitespace-only changes.
Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
"""Unit tests for CodeGenArchitectureAdapter.
2+
3+
Tests cover:
4+
- Config attribute validation (all required attributes are set correctly)
5+
- Component mapping structure (correct bridge types, no ln2)
6+
- Weight conversion keys and structure
7+
- split_qkv_matrix correctness (numerical test with known weights)
8+
- Factory registration (CodeGenForCausalLM maps to the right adapter)
9+
"""
10+
11+
from types import SimpleNamespace
12+
from typing import Any
13+
14+
import pytest
15+
import torch
16+
import torch.nn as nn
17+
18+
from transformer_lens.config import TransformerBridgeConfig
19+
from transformer_lens.model_bridge.generalized_components import (
20+
BlockBridge,
21+
CodeGenAttentionBridge,
22+
EmbeddingBridge,
23+
MLPBridge,
24+
NormalizationBridge,
25+
UnembeddingBridge,
26+
)
27+
from transformer_lens.model_bridge.supported_architectures.codegen import (
28+
CodeGenArchitectureAdapter,
29+
)
30+
31+
# ---------------------------------------------------------------------------
32+
# Fixtures
33+
# ---------------------------------------------------------------------------
34+
35+
36+
def _make_cfg(
    n_heads: int = 4,
    d_model: int = 64,
    n_layers: int = 2,
    d_mlp: int = 256,
    d_vocab: int = 1000,
    n_ctx: int = 512,
) -> TransformerBridgeConfig:
    """Build a small TransformerBridgeConfig for CodeGen adapter tests."""
    # Derive the head dimension so that n_heads * d_head == d_model holds.
    head_dim = d_model // n_heads
    return TransformerBridgeConfig(
        d_model=d_model,
        d_head=head_dim,
        n_layers=n_layers,
        n_ctx=n_ctx,
        n_heads=n_heads,
        d_vocab=d_vocab,
        d_mlp=d_mlp,
        default_prepend_bos=True,
        architecture="CodeGenForCausalLM",
    )
56+
57+
58+
@pytest.fixture
def cfg() -> TransformerBridgeConfig:
    """Default small CodeGen config shared by the tests in this module."""
    config = _make_cfg()
    return config
61+
62+
63+
@pytest.fixture
def adapter(cfg: TransformerBridgeConfig) -> CodeGenArchitectureAdapter:
    """Adapter instance built from the default ``cfg`` fixture."""
    built = CodeGenArchitectureAdapter(cfg)
    return built
66+
67+
68+
# ---------------------------------------------------------------------------
69+
# Config attribute tests
70+
# ---------------------------------------------------------------------------
71+
72+
73+
class TestCodeGenAdapterConfig:
    """Tests that the adapter sets required config attributes correctly."""

    def test_normalization_type_is_ln(self, adapter: CodeGenArchitectureAdapter) -> None:
        cfg = adapter.cfg
        assert cfg.normalization_type == "LN"

    def test_positional_embedding_type_is_rotary(self, adapter: CodeGenArchitectureAdapter) -> None:
        cfg = adapter.cfg
        assert cfg.positional_embedding_type == "rotary"

    def test_final_rms_is_false(self, adapter: CodeGenArchitectureAdapter) -> None:
        cfg = adapter.cfg
        assert cfg.final_rms is False

    def test_gated_mlp_is_false(self, adapter: CodeGenArchitectureAdapter) -> None:
        cfg = adapter.cfg
        assert cfg.gated_mlp is False

    def test_attn_only_is_false(self, adapter: CodeGenArchitectureAdapter) -> None:
        cfg = adapter.cfg
        assert cfg.attn_only is False

    def test_parallel_attn_mlp_is_true(self, adapter: CodeGenArchitectureAdapter) -> None:
        # CodeGen runs attention and MLP in parallel (see the no-ln2 test below).
        cfg = adapter.cfg
        assert cfg.parallel_attn_mlp is True
93+
94+
95+
# ---------------------------------------------------------------------------
96+
# Component mapping structure tests
97+
# ---------------------------------------------------------------------------
98+
99+
100+
class TestCodeGenAdapterComponentMapping:
    """Tests that component_mapping has the correct bridge types and structure."""

    @staticmethod
    def _layer_blocks(adapter: CodeGenArchitectureAdapter) -> Any:
        """Shortcut for the per-layer block entry of the component mapping."""
        return adapter.component_mapping["blocks"]

    def test_embed_is_embedding_bridge(self, adapter: CodeGenArchitectureAdapter) -> None:
        embed = adapter.component_mapping["embed"]
        assert isinstance(embed, EmbeddingBridge)

    def test_embed_name(self, adapter: CodeGenArchitectureAdapter) -> None:
        embed = adapter.component_mapping["embed"]
        assert embed.name == "transformer.wte"

    def test_blocks_is_block_bridge(self, adapter: CodeGenArchitectureAdapter) -> None:
        assert isinstance(self._layer_blocks(adapter), BlockBridge)

    def test_blocks_name(self, adapter: CodeGenArchitectureAdapter) -> None:
        assert self._layer_blocks(adapter).name == "transformer.h"

    def test_ln_final_is_normalization_bridge(self, adapter: CodeGenArchitectureAdapter) -> None:
        ln_final = adapter.component_mapping["ln_final"]
        assert isinstance(ln_final, NormalizationBridge)

    def test_ln_final_name(self, adapter: CodeGenArchitectureAdapter) -> None:
        ln_final = adapter.component_mapping["ln_final"]
        assert ln_final.name == "transformer.ln_f"

    def test_unembed_is_unembedding_bridge(self, adapter: CodeGenArchitectureAdapter) -> None:
        unembed = adapter.component_mapping["unembed"]
        assert isinstance(unembed, UnembeddingBridge)

    def test_unembed_name(self, adapter: CodeGenArchitectureAdapter) -> None:
        unembed = adapter.component_mapping["unembed"]
        assert unembed.name == "lm_head"

    def test_blocks_ln1_is_normalization_bridge(self, adapter: CodeGenArchitectureAdapter) -> None:
        ln1 = self._layer_blocks(adapter).submodules["ln1"]
        assert isinstance(ln1, NormalizationBridge)

    def test_blocks_ln1_name(self, adapter: CodeGenArchitectureAdapter) -> None:
        ln1 = self._layer_blocks(adapter).submodules["ln1"]
        assert ln1.name == "ln_1"

    def test_no_ln2_in_blocks(self, adapter: CodeGenArchitectureAdapter) -> None:
        """CodeGen uses parallel attn+MLP sharing ln_1 — there must be no ln2."""
        submodules = self._layer_blocks(adapter).submodules
        assert "ln2" not in submodules, "CodeGen parallel block must not have ln2"

    def test_attn_is_codegen_attention_bridge(self, adapter: CodeGenArchitectureAdapter) -> None:
        attn = self._layer_blocks(adapter).submodules["attn"]
        assert isinstance(attn, CodeGenAttentionBridge)

    def test_attn_name(self, adapter: CodeGenArchitectureAdapter) -> None:
        attn = self._layer_blocks(adapter).submodules["attn"]
        assert attn.name == "attn"

    def test_mlp_is_mlp_bridge(self, adapter: CodeGenArchitectureAdapter) -> None:
        mlp = self._layer_blocks(adapter).submodules["mlp"]
        assert isinstance(mlp, MLPBridge)

    def test_mlp_name(self, adapter: CodeGenArchitectureAdapter) -> None:
        mlp = self._layer_blocks(adapter).submodules["mlp"]
        assert mlp.name == "mlp"

    def test_mlp_in_name(self, adapter: CodeGenArchitectureAdapter) -> None:
        mlp = self._layer_blocks(adapter).submodules["mlp"]
        assert mlp.submodules["in"].name == "fc_in"

    def test_mlp_out_name(self, adapter: CodeGenArchitectureAdapter) -> None:
        mlp = self._layer_blocks(adapter).submodules["mlp"]
        assert mlp.submodules["out"].name == "fc_out"
163+
164+
165+
# ---------------------------------------------------------------------------
166+
# Weight processing conversion tests
167+
# ---------------------------------------------------------------------------
168+
169+
170+
class TestCodeGenAdapterWeightConversions:
    """Tests that weight_processing_conversions has the expected keys."""

    def test_q_weight_key_present(self, adapter: CodeGenArchitectureAdapter) -> None:
        conversions = adapter.weight_processing_conversions
        assert "blocks.{i}.attn.q.weight" in conversions

    def test_k_weight_key_present(self, adapter: CodeGenArchitectureAdapter) -> None:
        conversions = adapter.weight_processing_conversions
        assert "blocks.{i}.attn.k.weight" in conversions

    def test_v_weight_key_present(self, adapter: CodeGenArchitectureAdapter) -> None:
        conversions = adapter.weight_processing_conversions
        assert "blocks.{i}.attn.v.weight" in conversions

    def test_o_weight_key_present(self, adapter: CodeGenArchitectureAdapter) -> None:
        conversions = adapter.weight_processing_conversions
        assert "blocks.{i}.attn.o.weight" in conversions

    def test_exactly_four_conversion_keys(self, adapter: CodeGenArchitectureAdapter) -> None:
        # Only the four attention projections (q, k, v, o) need conversions.
        conversions = adapter.weight_processing_conversions
        assert len(conversions) == 4
187+
188+
189+
# ---------------------------------------------------------------------------
190+
# split_qkv_matrix numerical correctness tests
191+
# ---------------------------------------------------------------------------
192+
193+
194+
class TestCodeGenSplitQKVMatrix:
    """Numerical tests verifying the mp_num=4 QKV split logic."""

    def _make_adapter_with_dmodel(self, d_model: int, n_heads: int) -> CodeGenArchitectureAdapter:
        """Build an adapter for the given model width and head count."""
        cfg = _make_cfg(d_model=d_model, n_heads=n_heads)
        return CodeGenArchitectureAdapter(cfg)

    def _make_attn_component(self, d_model: int) -> Any:
        """Create a minimal attn component with a qkv_proj linear."""
        attn = SimpleNamespace()
        attn.qkv_proj = nn.Linear(d_model, d_model * 3, bias=False)
        return attn

    def test_returns_three_linear_modules(self) -> None:
        """split_qkv_matrix must return exactly three nn.Linear modules."""
        adapter = self._make_adapter_with_dmodel(64, 4)
        attn = self._make_attn_component(64)
        q, k, v = adapter.split_qkv_matrix(attn)
        assert isinstance(q, nn.Linear)
        assert isinstance(k, nn.Linear)
        assert isinstance(v, nn.Linear)

    def test_output_shapes_are_correct(self) -> None:
        """Each of Q, K, V must have weight shape [n_embd, n_embd]."""
        d_model = 64
        adapter = self._make_adapter_with_dmodel(d_model, 4)
        attn = self._make_attn_component(d_model)
        q, k, v = adapter.split_qkv_matrix(attn)
        assert q.weight.shape == (d_model, d_model)
        assert k.weight.shape == (d_model, d_model)
        assert v.weight.shape == (d_model, d_model)

    def test_no_bias_on_outputs(self) -> None:
        """The split linears must have no bias, matching qkv_proj."""
        adapter = self._make_adapter_with_dmodel(64, 4)
        attn = self._make_attn_component(64)
        q, k, v = adapter.split_qkv_matrix(attn)
        assert q.bias is None
        assert k.bias is None
        assert v.bias is None

    def test_q_k_v_are_distinct(self) -> None:
        """With a non-trivial weight, Q, K, V must differ from each other."""
        adapter = self._make_adapter_with_dmodel(64, 4)
        attn = self._make_attn_component(64)
        # Fill qkv_proj with deterministic, all-distinct values. The previous
        # nn.init.normal_ call was unseeded, making this test nondeterministic
        # in principle; torch.arange guarantees every entry differs, so the
        # Q/K/V slices can never spuriously compare equal.
        out_features, in_features = attn.qkv_proj.weight.shape
        attn.qkv_proj.weight = nn.Parameter(
            torch.arange(out_features * in_features, dtype=torch.float32).reshape(
                out_features, in_features
            )
        )
        q, k, v = adapter.split_qkv_matrix(attn)
        # All three must differ
        assert not torch.allclose(q.weight, k.weight), "Q and K weights must differ"
        assert not torch.allclose(q.weight, v.weight), "Q and V weights must differ"
        assert not torch.allclose(k.weight, v.weight), "K and V weights must differ"

    def test_known_partition_ordering(self) -> None:
        """Verify the mp_num=4 partition layout: within each partition [Q_part, V_part, K_part].

        We construct a weight where partition index and slot index are embedded
        in the values, then verify that Q, K, V extract the correct slices.
        """
        mp_num = 4
        d_model = 64
        n_heads = 4
        local_dim = d_model // mp_num  # 16

        adapter = self._make_adapter_with_dmodel(d_model, n_heads)
        attn = self._make_attn_component(d_model)

        # Build a structured weight: rows are indexed 0..3*d_model-1.
        # Reshape as [mp_num=4, 3, local_dim=16, d_model=64], set each slice
        # to a unique constant so we can track which slot goes where.
        w = torch.zeros(mp_num, 3, local_dim, d_model)
        # slot 0 = Q_part → fill with 1.0
        w[:, 0, :, :] = 1.0
        # slot 1 = V_part → fill with 2.0
        w[:, 1, :, :] = 2.0
        # slot 2 = K_part → fill with 3.0
        w[:, 2, :, :] = 3.0

        # Flatten back to [3*d_model, d_model] as qkv_proj expects
        attn.qkv_proj.weight = nn.Parameter(w.reshape(3 * d_model, d_model))

        q, k, v = adapter.split_qkv_matrix(attn)

        assert torch.all(q.weight == 1.0), "Q should come from slot 0 (Q_part)"
        assert torch.all(k.weight == 3.0), "K should come from slot 2 (K_part)"
        assert torch.all(v.weight == 2.0), "V should come from slot 1 (V_part)"

    def test_forward_output_shape_with_split(self) -> None:
        """After split, Q/K/V linears should produce correct output shapes."""
        d_model = 64
        adapter = self._make_adapter_with_dmodel(d_model, 4)
        attn = self._make_attn_component(d_model)
        q_lin, k_lin, v_lin = adapter.split_qkv_matrix(attn)

        batch, seq = 2, 10
        x = torch.randn(batch, seq, d_model)
        assert q_lin(x).shape == (batch, seq, d_model)
        assert k_lin(x).shape == (batch, seq, d_model)
        assert v_lin(x).shape == (batch, seq, d_model)
293+
294+
295+
# ---------------------------------------------------------------------------
296+
# Factory registration test
297+
# ---------------------------------------------------------------------------
298+
299+
300+
class TestCodeGenFactoryRegistration:
    """Tests that the factory maps CodeGenForCausalLM to the correct adapter.

    Note: Phase D (registration) is required for these tests to pass. They
    are included here so that registration is verified as part of the Phase D
    commit rather than needing a separate test file.
    """

    def test_factory_returns_codegen_adapter(self) -> None:
        """ArchitectureAdapterFactory must return a CodeGenArchitectureAdapter."""
        from transformer_lens.factories.architecture_adapter_factory import (
            ArchitectureAdapterFactory,
        )

        config = _make_cfg()
        selected = ArchitectureAdapterFactory.select_architecture_adapter(config)
        actual_name = type(selected).__name__
        assert isinstance(
            selected, CodeGenArchitectureAdapter
        ), f"Expected CodeGenArchitectureAdapter, got {actual_name}"

    def test_factory_key_is_codegen_for_causal_lm(self) -> None:
        """SUPPORTED_ARCHITECTURES must have a 'CodeGenForCausalLM' key."""
        from transformer_lens.factories.architecture_adapter_factory import (
            SUPPORTED_ARCHITECTURES,
        )

        registered = "CodeGenForCausalLM" in SUPPORTED_ARCHITECTURES
        assert registered, "CodeGenForCausalLM must be registered in SUPPORTED_ARCHITECTURES"

transformer_lens/factories/architecture_adapter_factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
ApertusArchitectureAdapter,
1010
BertArchitectureAdapter,
1111
BloomArchitectureAdapter,
12+
CodeGenArchitectureAdapter,
1213
FalconArchitectureAdapter,
1314
Gemma1ArchitectureAdapter,
1415
Gemma2ArchitectureAdapter,
@@ -53,6 +54,7 @@
5354
"ApertusForCausalLM": ApertusArchitectureAdapter,
5455
"BertForMaskedLM": BertArchitectureAdapter,
5556
"BloomForCausalLM": BloomArchitectureAdapter,
57+
"CodeGenForCausalLM": CodeGenArchitectureAdapter,
5658
"FalconForCausalLM": FalconArchitectureAdapter,
5759
"GemmaForCausalLM": Gemma1ArchitectureAdapter, # Default to Gemma1 as it's the original version
5860
"Gemma1ForCausalLM": Gemma1ArchitectureAdapter,

transformer_lens/model_bridge/generalized_components/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
from transformer_lens.model_bridge.generalized_components.bloom_attention import (
1010
BloomAttentionBridge,
1111
)
12+
from transformer_lens.model_bridge.generalized_components.codegen_attention import (
13+
CodeGenAttentionBridge,
14+
)
1215
from transformer_lens.model_bridge.generalized_components.bloom_block import (
1316
BloomBlockBridge,
1417
)
@@ -78,6 +81,7 @@
7881
"BlockBridge",
7982
"BloomBlockBridge",
8083
"BloomAttentionBridge",
84+
"CodeGenAttentionBridge",
8185
"BloomMLPBridge",
8286
"CLIPVisionEncoderBridge",
8387
"CLIPVisionEncoderLayerBridge",

0 commit comments

Comments
 (0)