Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions tests/integration/model_bridge/test_falcon_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""Integration tests for Falcon architecture adapter."""

import pytest
import torch

from transformer_lens.model_bridge.bridge import TransformerBridge

MODEL = "optimum-intel-internal-testing/really-tiny-falcon-testing"


@pytest.fixture(scope="module")
def falcon_bridge():
    """Boot the tiny Falcon checkpoint once per module (CPU only)."""
    bridge = TransformerBridge.boot_transformers(MODEL, device="cpu")
    return bridge


class TestFalconBridgeCreation:
    """Structural checks on a freshly booted Falcon bridge."""

    def test_block_count(self, falcon_bridge):
        # The tiny test checkpoint ships exactly two transformer blocks.
        n_blocks = len(falcon_bridge.blocks)
        assert n_blocks == 2

    def test_parallel_mode(self, falcon_bridge):
        # Falcon computes attention and MLP as parallel branches.
        assert falcon_bridge.cfg.parallel_attn_mlp is True

    def test_has_core_components(self, falcon_bridge):
        # Embedding, unembedding, and final layer norm must all be wired up.
        for attr in ("embed", "unembed", "ln_final"):
            assert hasattr(falcon_bridge, attr)


class TestFalconForwardPass:
    """Smoke and parity checks for the bridge forward pass."""

    def test_forward_returns_logits(self, falcon_bridge):
        # Shape must be [batch, seq, ...] for a 1x4 token input, with no NaNs.
        tokens = torch.tensor([[1, 2, 3, 4]])
        with torch.no_grad():
            logits = falcon_bridge(tokens)
        assert logits.shape[0] == 1
        assert logits.shape[1] == 4
        assert not torch.isnan(logits).any()

    def test_forward_matches_hf(self, falcon_bridge):
        """Bridge delegates to HF native forward — output should be identical."""
        tokens = torch.tensor([[1, 2, 3, 4]])
        hf_model = falcon_bridge.original_model
        with torch.no_grad():
            hf_out = hf_model(tokens).logits
            bridge_out = falcon_bridge(tokens)
        max_diff = (bridge_out - hf_out).abs().max().item()
        assert max_diff < 1e-4, f"Bridge vs HF max diff = {max_diff}"


class TestFalconParallelHooks:
    """Hook availability for Falcon's parallel attention/MLP blocks."""

    def test_no_hook_resid_mid(self, falcon_bridge):
        # Parallel blocks have no sequential mid-point in the residual stream,
        # so no resid_mid hook should appear in the cache.
        tokens = torch.tensor([[1, 2, 3, 4]])
        _, cache = falcon_bridge.run_with_cache(tokens)
        mid_keys = [k for k in cache.keys() if "hook_resid_mid" in k]
        assert not mid_keys

    def test_attn_and_mlp_hooks_fire(self, falcon_bridge):
        # Both branches expose in/out hooks in every block.
        tokens = torch.tensor([[1, 2, 3, 4]])
        _, cache = falcon_bridge.run_with_cache(tokens)
        for i in range(2):
            assert f"blocks.{i}.attn.hook_in" in cache
            assert f"blocks.{i}.attn.hook_out" in cache
            assert f"blocks.{i}.mlp.hook_in" in cache
            assert f"blocks.{i}.mlp.hook_out" in cache

    def test_residual_hooks_fire(self, falcon_bridge):
        # Residual-stream hooks bracket each block.
        tokens = torch.tensor([[1, 2, 3, 4]])
        _, cache = falcon_bridge.run_with_cache(tokens)
        for i in range(2):
            assert f"blocks.{i}.hook_resid_pre" in cache
            assert f"blocks.{i}.hook_resid_post" in cache
127 changes: 127 additions & 0 deletions tests/unit/model_bridge/generalized_components/test_alibi_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""Tests for shared ALiBi utility functions."""

import torch

from transformer_lens.model_bridge.generalized_components.alibi_utils import (
build_alibi_slopes,
build_alibi_tensor,
)


class TestBuildAlibiSlopes:
    """Test ALiBi slope computation against hand-derived values.

    Slope recipe (Press et al., "Train Short, Test Long"): for n_heads a power
    of two, base = 2 ** (-(2 ** -(log2(n_heads) - 3))) and head i (1-indexed)
    gets slope base ** i. Non-powers of two take slopes from the closest
    smaller power of two, extended with odd powers of the next power-of-two base.
    """

    def test_power_of_2_heads_4(self):
        # base = 2^(-(2^-(log2(4)-3))) = 2^(-(2^-(2-3))) = 2^(-(2^1)) = 2^(-2) = 0.25
        slopes = build_alibi_slopes(4, torch.device("cpu"))
        expected = torch.tensor([0.25, 0.0625, 0.015625, 0.00390625])
        assert torch.allclose(slopes, expected, atol=1e-7)

    def test_power_of_2_heads_8(self):
        # base = 2^(-(2^-(3-3))) = 2^(-(2^0)) = 2^(-1) = 0.5
        slopes = build_alibi_slopes(8, torch.device("cpu"))
        expected = torch.tensor([0.5**i for i in range(1, 9)])
        assert torch.allclose(slopes, expected, atol=1e-7)

    def test_non_power_of_2_heads_6(self):
        # closest_pow2=4, base=0.25 → first 4
        # extra_base=0.5, extra_powers=[1,3] → 0.5^1, 0.5^3
        slopes = build_alibi_slopes(6, torch.device("cpu"))
        expected = torch.tensor([0.25, 0.0625, 0.015625, 0.00390625, 0.5, 0.125])
        assert slopes.shape == (6,)
        assert torch.allclose(slopes, expected, atol=1e-7)

    def test_slopes_length_matches_num_heads(self):
        # Covers powers of two and non-powers of two alike.
        for n in [1, 2, 3, 5, 7, 16, 32]:
            slopes = build_alibi_slopes(n, torch.device("cpu"))
            assert slopes.shape == (n,), f"Failed for n_heads={n}"

    def test_all_slopes_positive(self):
        # Every slope is a positive power of a positive base.
        slopes = build_alibi_slopes(32, torch.device("cpu"))
        assert (slopes > 0).all()


class TestBuildAlibiTensor:
    """Test full ALiBi tensor generation."""

    def test_output_shape(self):
        # Expected layout: [batch, n_heads, 1, seq_len].
        attention_mask = torch.ones(2, 8, dtype=torch.long)
        result = build_alibi_tensor(attention_mask, 4, torch.float32)
        assert result.shape == (2, 4, 1, 8)

    def test_first_position_is_zero(self):
        """Position 0 should always have zero bias regardless of slope."""
        attention_mask = torch.ones(1, 4, dtype=torch.long)
        result = build_alibi_tensor(attention_mask, 8, torch.float32)
        first_position = result[:, :, :, 0]
        assert (first_position == 0).all()

    def test_values_against_hand_computation(self):
        """Verify against manually computed values for 2 heads, seq_len=4."""
        # base = 0.0625, slopes = [0.0625, 0.00390625]
        # positions = [0, 1, 2, 3]
        attention_mask = torch.ones(1, 4, dtype=torch.long)
        result = build_alibi_tensor(attention_mask, 2, torch.float32)
        # shape: [1, 2, 1, 4] — one bias row per head
        expected_per_head = [
            torch.tensor([0.0, 0.0625, 0.125, 0.1875]),
            torch.tensor([0.0, 0.00390625, 0.0078125, 0.01171875]),
        ]
        for head, expected in enumerate(expected_per_head):
            assert torch.allclose(result[0, head, 0], expected, atol=1e-6)

    def test_masked_positions_are_zero(self):
        """Positions where attention_mask=0 should produce zero bias."""
        attention_mask = torch.tensor([[1, 1, 0, 0]])  # last 2 positions masked
        result = build_alibi_tensor(attention_mask, 4, torch.float32)
        masked_part = result[:, :, :, 2:]
        assert (masked_part == 0).all()

    def test_batch_independence(self):
        """Each batch element should be computed independently."""
        attention_mask = torch.ones(3, 6, dtype=torch.long)
        result = build_alibi_tensor(attention_mask, 4, torch.float32)
        # Identical masks across the batch → identical bias tensors.
        for first, second in ((0, 1), (1, 2)):
            assert torch.allclose(result[first], result[second])

    def test_matches_hf_falcon_slopes(self):
        """Verify slopes match HF Falcon (the bfloat16-free part of their implementation).

        HF Falcon applies a bfloat16 cast to slopes before multiplying with positions,
        so full tensor values diverge slightly. We verify that the underlying slope
        values (which determine the relative bias per head) are identical.
        """
        from transformers.models.falcon.modeling_falcon import (
            build_alibi_tensor as hf_falcon_alibi,
        )

        mask = torch.ones(1, 4, dtype=torch.long)
        for n_heads in (8, 16, 32):
            ours = build_alibi_tensor(mask, n_heads, torch.float32)
            hf = hf_falcon_alibi(mask, n_heads, torch.float32)
            # Extract slope per head from position 1 (avoids bfloat16 compounding)
            ours_slopes = ours.reshape(n_heads, 1, 4)[:, 0, 1]
            hf_slopes = hf[:, 0, 1]
            assert torch.allclose(
                ours_slopes, hf_slopes, rtol=0.01
            ), f"Slope mismatch for {n_heads} heads: max_diff={(ours_slopes - hf_slopes).abs().max()}"

    def test_matches_hf_bloom(self):
        """Verify against HuggingFace Bloom's ALiBi implementation."""
        from transformers.models.bloom.modeling_bloom import (
            build_alibi_tensor as hf_bloom_alibi,
        )

        mask = torch.ones(1, 16, dtype=torch.long)
        for n_heads in (8, 16, 32):
            ours = build_alibi_tensor(mask, n_heads, torch.float32)
            hf = hf_bloom_alibi(mask, n_heads, torch.float32)
            # Bloom returns [batch * n_heads, 1, seq]; collapse ours to match.
            ours_flat = ours.reshape(n_heads, 1, 16)
            assert torch.allclose(
                ours_flat, hf, atol=1e-5
            ), f"Mismatch for {n_heads} heads: max_diff={(ours_flat - hf).abs().max()}"

    def test_dtype_preserved(self):
        # Output dtype must follow the requested dtype, not default to float32.
        attention_mask = torch.ones(1, 4, dtype=torch.long)
        for dtype in (torch.float32, torch.float16, torch.bfloat16):
            result = build_alibi_tensor(attention_mask, 4, dtype)
            assert result.dtype == dtype
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
"""Unit tests for FalconALiBiAttentionBridge.

Exercises the reimplemented ALiBi attention with mock weights — no model download needed.
Covers MHA, MQA, and GQA head configurations to catch shape mismatches.
"""

import torch

from transformer_lens.model_bridge.generalized_components.falcon_alibi_attention import (
FalconALiBiAttentionBridge,
)


class _MockConfig:
"""Minimal config for FalconALiBiAttentionBridge."""

def __init__(self, n_heads: int, d_model: int, n_key_value_heads: int | None = None):
self.n_heads = n_heads
self.d_model = d_model
self.n_key_value_heads = n_key_value_heads


class _MockAttention(torch.nn.Module):
    """Stub original component so the bridge's forward doesn't raise."""

    def __init__(self):
        super().__init__()
        # Identity stands in for the real attention-dropout module the bridge applies.
        self.attn_dropout = torch.nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pass-through: the bridge reimplements attention itself, so the
        # original component's forward is never relied on for real math.
        return x


def _build_bridge(
    n_heads: int, d_model: int, n_key_value_heads: int | None = None
) -> FalconALiBiAttentionBridge:
    """Build a wired-up FalconALiBiAttentionBridge with random Q/K/V weights."""
    cfg = _MockConfig(n_heads, d_model, n_key_value_heads)
    head_dim = d_model // n_heads
    kv_heads = n_key_value_heads or n_heads

    # Randomly initialized projections sized for the requested head layout.
    projections = {
        "q": torch.nn.Linear(d_model, n_heads * head_dim),
        "k": torch.nn.Linear(d_model, kv_heads * head_dim),
        "v": torch.nn.Linear(d_model, kv_heads * head_dim),
    }
    output_proj = torch.nn.Linear(d_model, d_model)

    def split_qkv(_component):
        # The bridge calls this to obtain the three separate projections.
        return projections["q"], projections["k"], projections["v"]

    bridge = FalconALiBiAttentionBridge(
        name="self_attention",
        config=cfg,
        split_qkv_matrix=split_qkv,
    )
    stub = _MockAttention()
    stub.dense = output_proj
    bridge.set_original_component(stub)
    return bridge


def _random_inputs(bridge: FalconALiBiAttentionBridge, batch: int = 2, seq: int = 6):
    """Generate random inputs via the bridge's own method."""
    inputs = bridge.get_random_inputs(batch_size=batch, seq_len=seq)
    return inputs


class TestFalconALiBiForward:
    """Forward pass runs and produces valid output for all head configs."""

    @staticmethod
    def _call(bridge, inputs):
        # Split hidden states from the remaining keyword inputs and run once.
        kwargs = {k: v for k, v in inputs.items() if k != "hidden_states"}
        with torch.no_grad():
            return bridge(inputs["hidden_states"], **kwargs)

    def test_mha_forward(self):
        """Standard MHA: n_heads == n_kv_heads."""
        bridge = _build_bridge(n_heads=4, d_model=32)
        output, weights = self._call(bridge, _random_inputs(bridge))
        assert output.shape == (2, 6, 32)
        assert not torch.isnan(output).any()

    def test_mqa_forward(self):
        """Multi-query: K/V have 1 head, Q has n_heads."""
        bridge = _build_bridge(n_heads=8, d_model=64, n_key_value_heads=1)
        output, weights = self._call(bridge, _random_inputs(bridge))
        assert output.shape == (2, 6, 64)
        assert not torch.isnan(output).any()
        # Attention weights should have full n_heads after expansion
        assert weights.shape[1] == 8

    def test_gqa_forward(self):
        """Grouped-query: K/V have fewer heads than Q (but more than 1)."""
        bridge = _build_bridge(n_heads=8, d_model=64, n_key_value_heads=2)
        output, weights = self._call(bridge, _random_inputs(bridge))
        assert output.shape == (2, 6, 64)
        assert not torch.isnan(output).any()
        assert weights.shape[1] == 8


class TestALiBiEffect:
    """ALiBi bias actually affects attention scores."""

    def test_alibi_changes_output(self):
        """Output with ALiBi should differ from output without."""
        bridge = _build_bridge(n_heads=4, d_model=32)
        inputs = _random_inputs(bridge)
        hidden_states = inputs["hidden_states"]
        attention_mask = inputs["attention_mask"]

        with torch.no_grad():
            biased, _ = bridge(hidden_states, alibi=inputs["alibi"], attention_mask=attention_mask)
            unbiased, _ = bridge(hidden_states, attention_mask=attention_mask)

        assert not torch.allclose(biased, unbiased), "ALiBi should change the output"

    def test_pattern_is_causal(self):
        """Upper triangle of attention pattern should be zero (causal masking)."""
        bridge = _build_bridge(n_heads=4, d_model=32)
        inputs = _random_inputs(bridge, batch=1, seq=4)
        kwargs = {k: v for k, v in inputs.items() if k != "hidden_states"}

        with torch.no_grad():
            _, weights = bridge(inputs["hidden_states"], **kwargs)
        # weights: [batch, heads, seq, seq] — everything above the diagonal must be zero
        above_diag = torch.triu(weights[0, 0], diagonal=1)
        assert not above_diag.any()


class TestHooksFireInForward:
    """Hooks fire correctly during the reimplemented attention forward."""

    @staticmethod
    def _capture(hook_name, key):
        """Run one forward pass with a cloning hook attached; return what it saw."""
        bridge = _build_bridge(n_heads=4, d_model=32)
        inputs = _random_inputs(bridge, batch=1, seq=4)
        captured = {}

        def hook_fn(tensor, hook):
            captured[key] = tensor.clone()
            return tensor

        getattr(bridge, hook_name).add_hook(hook_fn)
        kwargs = {k: v for k, v in inputs.items() if k != "hidden_states"}
        with torch.no_grad():
            bridge(inputs["hidden_states"], **kwargs)
        return captured

    def test_attn_scores_hook(self):
        captured = self._capture("hook_attn_scores", "attn_scores")
        assert "attn_scores" in captured
        assert captured["attn_scores"].shape == (1, 4, 4, 4)

    def test_pattern_hook(self):
        captured = self._capture("hook_pattern", "pattern")
        assert "pattern" in captured
        # Pattern rows should sum to 1 (softmax output)
        row_sums = captured["pattern"].sum(dim=-1)
        assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-5)
2 changes: 2 additions & 0 deletions transformer_lens/factories/architecture_adapter_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
ApertusArchitectureAdapter,
BertArchitectureAdapter,
BloomArchitectureAdapter,
FalconArchitectureAdapter,
Gemma1ArchitectureAdapter,
Gemma2ArchitectureAdapter,
Gemma3ArchitectureAdapter,
Expand Down Expand Up @@ -52,6 +53,7 @@
"ApertusForCausalLM": ApertusArchitectureAdapter,
"BertForMaskedLM": BertArchitectureAdapter,
"BloomForCausalLM": BloomArchitectureAdapter,
"FalconForCausalLM": FalconArchitectureAdapter,
"GemmaForCausalLM": Gemma1ArchitectureAdapter, # Default to Gemma1 as it's the original version
"Gemma1ForCausalLM": Gemma1ArchitectureAdapter,
"Gemma2ForCausalLM": Gemma2ArchitectureAdapter,
Expand Down
Loading
Loading