Skip to content

Commit 93e5b4c

Browse files
authored
Initial setup of the falcon adapter (#1241)
1 parent 8ff5637 commit 93e5b4c

File tree

16 files changed

+31010
-25543
lines changed

16 files changed

+31010
-25543
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""Integration tests for Falcon architecture adapter."""
2+
3+
import pytest
4+
import torch
5+
6+
from transformer_lens.model_bridge.bridge import TransformerBridge
7+
8+
MODEL = "optimum-intel-internal-testing/really-tiny-falcon-testing"
9+
10+
11+
@pytest.fixture(scope="module")
def falcon_bridge():
    """Boot the tiny Falcon test model once and share it across this module."""
    bridge = TransformerBridge.boot_transformers(MODEL, device="cpu")
    return bridge
15+
16+
class TestFalconBridgeCreation:
    """Sanity checks on the structure of a freshly booted Falcon bridge."""

    def test_block_count(self, falcon_bridge):
        """The tiny test model exposes exactly two transformer blocks."""
        n_blocks = len(falcon_bridge.blocks)
        assert n_blocks == 2

    def test_parallel_mode(self, falcon_bridge):
        """Falcon runs attention and MLP in parallel; the config must reflect that."""
        assert falcon_bridge.cfg.parallel_attn_mlp is True

    def test_has_core_components(self, falcon_bridge):
        """Embedding, unembedding, and final layer norm are all exposed."""
        for component in ("embed", "unembed", "ln_final"):
            assert hasattr(falcon_bridge, component)
28+
29+
class TestFalconForwardPass:
    """Forward-pass behaviour of the Falcon bridge."""

    def test_forward_returns_logits(self, falcon_bridge):
        """A forward pass yields finite logits preserving batch and seq dims."""
        input_ids = torch.tensor([[1, 2, 3, 4]])
        with torch.no_grad():
            logits = falcon_bridge(input_ids)
        assert logits.shape[:2] == (1, 4)
        assert not torch.isnan(logits).any()

    def test_forward_matches_hf(self, falcon_bridge):
        """Bridge delegates to HF native forward — output should be identical."""
        input_ids = torch.tensor([[1, 2, 3, 4]])
        hf_model = falcon_bridge.original_model
        with torch.no_grad():
            bridge_logits = falcon_bridge(input_ids)
            hf_logits = hf_model(input_ids).logits
        max_diff = (bridge_logits - hf_logits).abs().max().item()
        assert max_diff < 1e-4, f"Bridge vs HF max diff = {max_diff}"
49+
class TestFalconParallelHooks:
    """Hook availability for Falcon's parallel attention/MLP block layout."""

    def test_no_hook_resid_mid(self, falcon_bridge):
        """Parallel blocks have no mid-residual point, so no resid_mid hooks."""
        input_ids = torch.tensor([[1, 2, 3, 4]])
        _, cache = falcon_bridge.run_with_cache(input_ids)
        assert all("hook_resid_mid" not in name for name in cache.keys())

    def test_attn_and_mlp_hooks_fire(self, falcon_bridge):
        """Attention and MLP in/out hooks appear for every block."""
        input_ids = torch.tensor([[1, 2, 3, 4]])
        _, cache = falcon_bridge.run_with_cache(input_ids)
        for layer in range(2):
            for component in ("attn", "mlp"):
                for direction in ("in", "out"):
                    assert f"blocks.{layer}.{component}.hook_{direction}" in cache

    def test_residual_hooks_fire(self, falcon_bridge):
        """Pre- and post-residual hooks exist for every block."""
        input_ids = torch.tensor([[1, 2, 3, 4]])
        _, cache = falcon_bridge.run_with_cache(input_ids)
        for layer in range(2):
            assert f"blocks.{layer}.hook_resid_pre" in cache
            assert f"blocks.{layer}.hook_resid_post" in cache
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
"""Tests for shared ALiBi utility functions."""
2+
3+
import torch
4+
5+
from transformer_lens.model_bridge.generalized_components.alibi_utils import (
6+
build_alibi_slopes,
7+
build_alibi_tensor,
8+
)
9+
10+
11+
class TestBuildAlibiSlopes:
    """Test ALiBi slope computation against hand-derived values."""

    def test_power_of_2_heads_4(self):
        """4 heads: geometric sequence with base 0.25."""
        # base = 2^(-(2^-(log2(4)-3))) = 2^(-(2^(-1))) = 2^(-2) = 0.25
        slopes = build_alibi_slopes(4, torch.device("cpu"))
        expected = torch.tensor([0.25**power for power in range(1, 5)])
        assert torch.allclose(slopes, expected, atol=1e-7)

    def test_power_of_2_heads_8(self):
        """8 heads: geometric sequence with base 0.5."""
        # base = 2^(-(2^-(3-3))) = 2^(-1) = 0.5
        slopes = build_alibi_slopes(8, torch.device("cpu"))
        expected = torch.tensor([0.5**power for power in range(1, 9)])
        assert torch.allclose(slopes, expected, atol=1e-7)

    def test_non_power_of_2_heads_6(self):
        """6 heads: 4 slopes from base 0.25 plus 2 interpolated from base 0.5."""
        # closest_pow2=4, base=0.25 → first 4
        # extra_base=0.5, extra_powers=[1,3] → 0.5^1, 0.5^3
        slopes = build_alibi_slopes(6, torch.device("cpu"))
        expected = torch.tensor([0.25**1, 0.25**2, 0.25**3, 0.25**4, 0.5**1, 0.5**3])
        assert slopes.shape == (6,)
        assert torch.allclose(slopes, expected, atol=1e-7)

    def test_slopes_length_matches_num_heads(self):
        """Output length equals n_heads for powers of two and otherwise."""
        for head_count in (1, 2, 3, 5, 7, 16, 32):
            slopes = build_alibi_slopes(head_count, torch.device("cpu"))
            assert slopes.shape == (head_count,), f"Failed for n_heads={head_count}"

    def test_all_slopes_positive(self):
        """Every slope is a strictly positive power of the base."""
        slopes = build_alibi_slopes(32, torch.device("cpu"))
        assert bool((slopes > 0).all())
43+
44+
class TestBuildAlibiTensor:
    """Test full ALiBi tensor generation."""

    def test_output_shape(self):
        """Output is laid out as [batch, n_heads, 1, seq_len]."""
        attn_mask = torch.ones(2, 8, dtype=torch.long)
        alibi = build_alibi_tensor(attn_mask, 4, torch.float32)
        assert alibi.shape == (2, 4, 1, 8)

    def test_first_position_is_zero(self):
        """Position 0 should always have zero bias regardless of slope."""
        attn_mask = torch.ones(1, 4, dtype=torch.long)
        alibi = build_alibi_tensor(attn_mask, 8, torch.float32)
        first_column = alibi[:, :, :, 0]
        assert (first_column == 0).all()

    def test_values_against_hand_computation(self):
        """Verify against manually computed values for 2 heads, seq_len=4."""
        # base = 0.0625, slopes = [0.0625, 0.00390625]
        # positions = [0, 1, 2, 3]
        attn_mask = torch.ones(1, 4, dtype=torch.long)
        alibi = build_alibi_tensor(attn_mask, 2, torch.float32)  # [1, 2, 1, 4]
        positions = torch.arange(4, dtype=torch.float32)
        assert torch.allclose(alibi[0, 0, 0], 0.0625 * positions, atol=1e-6)
        assert torch.allclose(alibi[0, 1, 0], 0.00390625 * positions, atol=1e-6)

    def test_masked_positions_are_zero(self):
        """Positions where attention_mask=0 should produce zero bias."""
        attn_mask = torch.tensor([[1, 1, 0, 0]])  # last 2 positions masked
        alibi = build_alibi_tensor(attn_mask, 4, torch.float32)
        masked_tail = alibi[:, :, :, 2:]
        assert (masked_tail == 0).all()

    def test_batch_independence(self):
        """Each batch element should be computed independently."""
        attn_mask = torch.ones(3, 6, dtype=torch.long)
        alibi = build_alibi_tensor(attn_mask, 4, torch.float32)
        # Identical masks per batch element → identical bias tensors.
        for left, right in ((0, 1), (1, 2)):
            assert torch.allclose(alibi[left], alibi[right])

    def test_matches_hf_falcon_slopes(self):
        """Verify slopes match HF Falcon (the bfloat16-free part of their implementation).

        HF Falcon applies a bfloat16 cast to slopes before multiplying with positions,
        so full tensor values diverge slightly. We verify that the underlying slope
        values (which determine the relative bias per head) are identical.
        """
        from transformers.models.falcon.modeling_falcon import (
            build_alibi_tensor as hf_falcon_alibi,
        )

        attn_mask = torch.ones(1, 4, dtype=torch.long)
        for n_heads in (8, 16, 32):
            theirs = hf_falcon_alibi(attn_mask, n_heads, torch.float32)
            # Extract slope per head from position 1 (avoids bfloat16 compounding)
            ours_slopes = build_alibi_tensor(attn_mask, n_heads, torch.float32).reshape(
                n_heads, 1, 4
            )[:, 0, 1]
            hf_slopes = theirs[:, 0, 1]
            assert torch.allclose(
                ours_slopes, hf_slopes, rtol=0.01
            ), f"Slope mismatch for {n_heads} heads: max_diff={(ours_slopes - hf_slopes).abs().max()}"

    def test_matches_hf_bloom(self):
        """Verify against HuggingFace Bloom's ALiBi implementation."""
        from transformers.models.bloom.modeling_bloom import (
            build_alibi_tensor as hf_bloom_alibi,
        )

        attn_mask = torch.ones(1, 16, dtype=torch.long)
        for n_heads in (8, 16, 32):
            hf = hf_bloom_alibi(attn_mask, n_heads, torch.float32)
            ours_flat = build_alibi_tensor(attn_mask, n_heads, torch.float32).reshape(
                n_heads, 1, 16
            )
            assert torch.allclose(
                ours_flat, hf, atol=1e-5
            ), f"Mismatch for {n_heads} heads: max_diff={(ours_flat - hf).abs().max()}"

    def test_dtype_preserved(self):
        """Requested dtype is carried through to the output tensor."""
        attn_mask = torch.ones(1, 4, dtype=torch.long)
        for wanted_dtype in (torch.float32, torch.float16, torch.bfloat16):
            assert build_alibi_tensor(attn_mask, 4, wanted_dtype).dtype == wanted_dtype
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
"""Unit tests for FalconALiBiAttentionBridge.
2+
3+
Exercises the reimplemented ALiBi attention with mock weights — no model download needed.
4+
Covers MHA, MQA, and GQA head configurations to catch shape mismatches.
5+
"""
6+
7+
import torch
8+
9+
from transformer_lens.model_bridge.generalized_components.falcon_alibi_attention import (
10+
FalconALiBiAttentionBridge,
11+
)
12+
13+
14+
class _MockConfig:
15+
"""Minimal config for FalconALiBiAttentionBridge."""
16+
17+
def __init__(self, n_heads: int, d_model: int, n_key_value_heads: int | None = None):
18+
self.n_heads = n_heads
19+
self.d_model = d_model
20+
self.n_key_value_heads = n_key_value_heads
21+
22+
23+
class _MockAttention(torch.nn.Module):
24+
"""Stub original component so the bridge's forward doesn't raise."""
25+
26+
def __init__(self):
27+
super().__init__()
28+
self.attn_dropout = torch.nn.Identity()
29+
30+
def forward(self, x: torch.Tensor) -> torch.Tensor:
31+
return x
32+
33+
34+
def _build_bridge(
    n_heads: int, d_model: int, n_key_value_heads: int | None = None
) -> FalconALiBiAttentionBridge:
    """Build a wired-up FalconALiBiAttentionBridge with random Q/K/V weights."""
    head_dim = d_model // n_heads
    # Falsy (None/0) n_key_value_heads falls back to full MHA.
    kv_heads = n_key_value_heads or n_heads

    q_proj = torch.nn.Linear(d_model, n_heads * head_dim)
    k_proj = torch.nn.Linear(d_model, kv_heads * head_dim)
    v_proj = torch.nn.Linear(d_model, kv_heads * head_dim)

    def split_qkv(_component):
        # The bridge calls this to obtain separate Q/K/V projections.
        return q_proj, k_proj, v_proj

    bridge = FalconALiBiAttentionBridge(
        name="self_attention",
        config=_MockConfig(n_heads, d_model, n_key_value_heads),
        split_qkv_matrix=split_qkv,
    )
    stub_attn = _MockAttention()
    stub_attn.dense = torch.nn.Linear(d_model, d_model)
    bridge.set_original_component(stub_attn)
    return bridge
60+
61+
def _random_inputs(bridge: FalconALiBiAttentionBridge, batch: int = 2, seq: int = 6):
    """Generate random inputs via the bridge's own method."""
    inputs = bridge.get_random_inputs(batch_size=batch, seq_len=seq)
    return inputs
65+
66+
class TestFalconALiBiForward:
    """Forward pass runs and produces valid output for all head configs."""

    @staticmethod
    def _forward(bridge):
        """Run the bridge on its own random inputs; return (output, weights)."""
        kwargs = dict(_random_inputs(bridge))
        hidden = kwargs.pop("hidden_states")
        with torch.no_grad():
            return bridge(hidden, **kwargs)

    def test_mha_forward(self):
        """Standard MHA: n_heads == n_kv_heads."""
        bridge = _build_bridge(n_heads=4, d_model=32)
        output, _ = self._forward(bridge)
        assert output.shape == (2, 6, 32)
        assert not torch.isnan(output).any()

    def test_mqa_forward(self):
        """Multi-query: K/V have 1 head, Q has n_heads."""
        bridge = _build_bridge(n_heads=8, d_model=64, n_key_value_heads=1)
        output, weights = self._forward(bridge)
        assert output.shape == (2, 6, 64)
        assert not torch.isnan(output).any()
        # Attention weights should have full n_heads after expansion
        assert weights.shape[1] == 8

    def test_gqa_forward(self):
        """Grouped-query: K/V have fewer heads than Q (but more than 1)."""
        bridge = _build_bridge(n_heads=8, d_model=64, n_key_value_heads=2)
        output, weights = self._forward(bridge)
        assert output.shape == (2, 6, 64)
        assert not torch.isnan(output).any()
        assert weights.shape[1] == 8
105+
106+
class TestALiBiEffect:
    """ALiBi bias actually affects attention scores."""

    def test_alibi_changes_output(self):
        """Output with ALiBi should differ from output without."""
        bridge = _build_bridge(n_heads=4, d_model=32)
        inputs = _random_inputs(bridge)
        hidden = inputs["hidden_states"]
        mask = inputs["attention_mask"]

        with torch.no_grad():
            biased, _ = bridge(hidden, alibi=inputs["alibi"], attention_mask=mask)
            unbiased, _ = bridge(hidden, attention_mask=mask)

        assert not torch.allclose(biased, unbiased), "ALiBi should change the output"

    def test_pattern_is_causal(self):
        """Upper triangle of attention pattern should be zero (causal masking)."""
        bridge = _build_bridge(n_heads=4, d_model=32)
        kwargs = dict(_random_inputs(bridge, batch=1, seq=4))
        hidden = kwargs.pop("hidden_states")

        with torch.no_grad():
            _, weights = bridge(hidden, **kwargs)

        # weights: [batch, heads, seq, seq] — entries above the diagonal must be 0
        above_diag = torch.triu(weights[0, 0], diagonal=1)
        assert (above_diag == 0).all()
135+
136+
class TestHooksFireInForward:
    """Hooks fire correctly during the reimplemented attention forward."""

    @staticmethod
    def _capture(hook_attr):
        """Run a 4-head bridge (batch=1, seq=4) with a capturing hook attached.

        Returns the list of tensors the named hook observed, in call order.
        """
        bridge = _build_bridge(n_heads=4, d_model=32)
        kwargs = dict(_random_inputs(bridge, batch=1, seq=4))
        hidden = kwargs.pop("hidden_states")
        seen = []

        def grab(tensor, hook):
            seen.append(tensor.clone())
            return tensor

        getattr(bridge, hook_attr).add_hook(grab)
        with torch.no_grad():
            bridge(hidden, **kwargs)
        return seen

    def test_attn_scores_hook(self):
        """Raw-score hook fires with a [batch, heads, seq, seq] tensor."""
        seen = self._capture("hook_attn_scores")
        assert len(seen) >= 1
        assert seen[-1].shape == (1, 4, 4, 4)

    def test_pattern_hook(self):
        """Pattern hook observes softmax output: every row sums to 1."""
        seen = self._capture("hook_pattern")
        assert len(seen) >= 1
        # Pattern rows should sum to 1 (softmax output)
        row_sums = seen[-1].sum(dim=-1)
        assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-5)

transformer_lens/factories/architecture_adapter_factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
ApertusArchitectureAdapter,
1010
BertArchitectureAdapter,
1111
BloomArchitectureAdapter,
12+
FalconArchitectureAdapter,
1213
Gemma1ArchitectureAdapter,
1314
Gemma2ArchitectureAdapter,
1415
Gemma3ArchitectureAdapter,
@@ -52,6 +53,7 @@
5253
"ApertusForCausalLM": ApertusArchitectureAdapter,
5354
"BertForMaskedLM": BertArchitectureAdapter,
5455
"BloomForCausalLM": BloomArchitectureAdapter,
56+
"FalconForCausalLM": FalconArchitectureAdapter,
5557
"GemmaForCausalLM": Gemma1ArchitectureAdapter, # Default to Gemma1 as it's the original version
5658
"Gemma1ForCausalLM": Gemma1ArchitectureAdapter,
5759
"Gemma2ForCausalLM": Gemma2ArchitectureAdapter,

0 commit comments

Comments
 (0)