Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions tests/integration/model_bridge/test_deepseek_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""Integration tests for DeepSeek V3 architecture adapter."""

import tempfile

import pytest
import torch
from transformers import AutoTokenizer, DeepseekV3Config, DeepseekV3ForCausalLM

from transformer_lens.model_bridge.bridge import TransformerBridge


@pytest.fixture(scope="module")
def tiny_deepseek_bridge():
    """Boot a TransformerBridge around a tiny randomly-initialized DeepSeek V3 model.

    The model is saved to a temporary directory alongside a GPT-2 tokenizer and
    reloaded via ``TransformerBridge.boot_transformers``. Module-scoped so the
    boot cost is paid once for all tests in this file.
    """
    config = DeepseekV3Config(
        hidden_size=256,
        intermediate_size=512,
        num_hidden_layers=4,
        num_attention_heads=8,
        q_lora_rank=64,
        kv_lora_rank=32,
        qk_nope_head_dim=16,
        qk_rope_head_dim=8,
        v_head_dim=16,
        vocab_size=1000,
        first_k_dense_replace=1,
        n_routed_experts=8,
        n_shared_experts=1,
        num_experts_per_tok=2,
        n_group=2,
        topk_group=1,
        max_position_embeddings=128,
        moe_intermediate_size=256,
    )
    model = DeepseekV3ForCausalLM(config)

    with tempfile.TemporaryDirectory() as save_dir:
        model.save_pretrained(save_dir)
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        tokenizer.save_pretrained(save_dir)
        # Yield inside the context manager so the temp dir stays alive
        # for the lifetime of the module-scoped fixture.
        yield TransformerBridge.boot_transformers(save_dir, device="cpu")


class TestDeepSeekBridgeCreation:
    """Smoke tests: the bridge exposes the expected top-level structure."""

    def test_bridge_has_correct_block_count(self, tiny_deepseek_bridge):
        # The fixture configures num_hidden_layers=4.
        assert len(tiny_deepseek_bridge.blocks) == 4

    def test_bridge_has_embed_and_unembed(self, tiny_deepseek_bridge):
        for attr in ("embed", "unembed", "ln_final"):
            assert hasattr(tiny_deepseek_bridge, attr)

    def test_attention_is_mla(self, tiny_deepseek_bridge):
        from transformer_lens.model_bridge.generalized_components.mla_attention import (
            MLAAttentionBridge,
        )

        first_block = tiny_deepseek_bridge.blocks[0]
        assert isinstance(first_block.attn, MLAAttentionBridge)


class TestDeepSeekForwardPass:
    """Forward-pass sanity checks against the wrapped HF model."""

    def test_forward_returns_logits(self, tiny_deepseek_bridge):
        tokens = torch.tensor([[1, 2, 3, 4]])
        with torch.no_grad():
            logits = tiny_deepseek_bridge(tokens)
        # (batch, seq, vocab) with the fixture's vocab_size=1000.
        assert logits.shape == (1, 4, 1000)
        # Finite everywhere: rejects both NaN and +/-inf in one pass.
        assert torch.isfinite(logits).all()

    def test_forward_matches_hf(self, tiny_deepseek_bridge):
        """SDPA vs manual matmul — small float32 differences expected."""
        tokens = torch.tensor([[1, 2, 3, 4]])
        with torch.no_grad():
            bridge_logits = tiny_deepseek_bridge(tokens)
            hf_logits = tiny_deepseek_bridge.original_model(tokens).logits
        max_diff = (bridge_logits - hf_logits).abs().max().item()
        assert max_diff < 0.15, f"Bridge vs HF max diff = {max_diff}"


class TestDeepSeekDenseVsMoELayers:
    """With first_k_dense_replace=1, layer 0 uses a dense MLP while layers 1+
    use MoE. Hook presence in the activation cache should reflect that split.

    The original version re-ran the identical ``run_with_cache`` forward pass
    in every test; a class-scoped fixture now runs it once and shares the cache.
    """

    @pytest.fixture(scope="class")
    def cache(self, tiny_deepseek_bridge):
        """Single cached forward pass shared by all tests in this class."""
        tokens = torch.tensor([[1, 2, 3, 4]])
        _, cache = tiny_deepseek_bridge.run_with_cache(tokens)
        return cache

    def test_dense_layer_has_no_moe_hooks(self, cache):
        cache_keys = set(cache.keys())
        # Layer 0 is dense, so no MoE-specific hooks should appear for it.
        assert not any("blocks.0.mlp.gate" in k for k in cache_keys)
        assert not any("blocks.0.mlp.shared_experts" in k for k in cache_keys)

    def test_moe_layer_has_gate_hooks(self, cache):
        assert any("blocks.1.mlp.gate" in k for k in cache.keys())

    def test_moe_layer_has_shared_experts_hooks(self, cache):
        assert any("blocks.1.mlp.shared_experts" in k for k in cache.keys())

    def test_both_layers_have_mlp_hooks(self, cache):
        # Generic MLP in/out hooks must exist regardless of dense vs MoE.
        for i in [0, 1]:
            assert f"blocks.{i}.mlp.hook_in" in cache
            assert f"blocks.{i}.mlp.hook_out" in cache

    def test_both_layers_produce_non_nan(self, cache):
        for i in [0, 1]:
            assert not torch.isnan(cache[f"blocks.{i}.mlp.hook_out"]).any()


class TestDeepSeekAttentionHooks:
    """Attention hooks should be registered and fire on every layer."""

    def test_attention_hooks_fire_all_layers(self, tiny_deepseek_bridge):
        tokens = torch.tensor([[1, 2, 3, 4]])
        _, cache = tiny_deepseek_bridge.run_with_cache(tokens)
        expected_keys = [
            f"blocks.{layer}.attn.{suffix}"
            for layer in range(4)
            for suffix in ("hook_in", "hook_out")
        ]
        for key in expected_keys:
            assert key in cache

    def test_mla_latent_hooks_fire(self, tiny_deepseek_bridge):
        tokens = torch.tensor([[1, 2, 3, 4]])
        _, cache = tiny_deepseek_bridge.run_with_cache(tokens)
        keys = list(cache.keys())
        # MLA routes Q and KV through compressed latents; both should be hooked.
        assert any("hook_q_latent" in k for k in keys)
        assert any("hook_kv_latent" in k for k in keys)
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
"""Unit tests for MLAAttentionBridge (DeepSeek Multi-Head Latent Attention)."""

import pytest
import torch
from transformers import DeepseekV3Config, DeepseekV3ForCausalLM

from transformer_lens.model_bridge.generalized_components.mla_attention import (
MLAAttentionBridge,
)


@pytest.fixture(scope="module")
def tiny_config():
    """Minimal DeepSeek V3 config: MLA attention, dense layer 0, MoE layers 1+."""
    settings = dict(
        hidden_size=256,
        intermediate_size=512,
        num_hidden_layers=4,
        num_attention_heads=8,
        num_key_value_heads=1,
        q_lora_rank=64,
        kv_lora_rank=32,
        qk_nope_head_dim=16,
        qk_rope_head_dim=8,
        v_head_dim=16,
        vocab_size=1000,
        first_k_dense_replace=1,
        n_routed_experts=8,
        n_shared_experts=1,
        num_experts_per_tok=2,
        n_group=2,
        topk_group=1,
        max_position_embeddings=128,
        moe_intermediate_size=256,
    )
    return DeepseekV3Config(**settings)


@pytest.fixture(scope="module")
def tiny_model(tiny_config):
    """Randomly initialized DeepSeek V3 model built from ``tiny_config``."""
    model = DeepseekV3ForCausalLM(tiny_config)
    return model


@pytest.fixture(scope="module")
def hf_attn(tiny_model):
    """Layer-0 HF self-attention module, used as the reference implementation."""
    first_layer = tiny_model.model.layers[0]
    return first_layer.self_attn


@pytest.fixture(scope="module")
def mla_bridge(tiny_config, hf_attn, tiny_model):
    """MLAAttentionBridge wrapping the layer-0 HF attention module."""
    bridge = MLAAttentionBridge(name="self_attn", config=tiny_config, submodules={})
    bridge.set_original_component(hf_attn)
    # The bridge is handed the model-level rotary embedding (not a per-layer one).
    bridge.set_rotary_emb(tiny_model.model.rotary_emb)
    return bridge


class TestMLAAttentionBridgeHooks:
    """Static checks on the bridge's hook points and weight accessors."""

    # Every hook point an MLA bridge is expected to expose.
    EXPECTED_HOOKS = (
        "hook_in",
        "hook_out",
        "hook_q_latent",
        "hook_kv_latent",
        "hook_q",
        "hook_k",
        "hook_v",
        "hook_rot_q",
        "hook_rot_k",
        "hook_attn_scores",
        "hook_pattern",
        "hook_cos",
        "hook_sin",
    )

    def test_all_expected_hooks_exist(self, mla_bridge):
        missing = [h for h in self.EXPECTED_HOOKS if not hasattr(mla_bridge, h)]
        assert not missing, f"Missing hooks: {missing}"

    def test_W_Q_raises_not_implemented(self, mla_bridge):
        # MLA has no single full-rank W_Q; the accessor must refuse explicitly.
        with pytest.raises(NotImplementedError, match="not available on MLA"):
            _ = mla_bridge.W_Q

    def test_W_K_raises_not_implemented(self, mla_bridge):
        with pytest.raises(NotImplementedError, match="not available on MLA"):
            _ = mla_bridge.W_K


class TestMLAAttentionBridgeForward:
    """Forward-pass tests: output parity with HF attention, hook firing, hook shapes."""

    @pytest.fixture
    def sample_inputs(self, tiny_config, tiny_model):
        """Random hidden states and rotary (cos, sin) embeddings for batch=2, seq=8."""
        batch, seq = 2, 8
        hidden_states = torch.randn(batch, seq, tiny_config.hidden_size)
        position_ids = torch.arange(seq).unsqueeze(0).expand(batch, -1)
        # Position embeddings come from the model-level rotary module, the same
        # one the mla_bridge fixture wires in via set_rotary_emb.
        cos, sin = tiny_model.model.rotary_emb(hidden_states, position_ids)
        return hidden_states, (cos, sin)

    def test_output_matches_hf(self, mla_bridge, hf_attn, sample_inputs, tiny_model):
        """HF uses SDPA, bridge uses manual matmul — small float32 differences expected."""
        hidden_states, position_embeddings = sample_inputs

        with torch.no_grad():
            hf_attn_out = hf_attn(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=None,
            )[0]
            bridge_attn_out = mla_bridge(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=None,
            )[0]

        # Loose tolerances: the two implementations differ numerically in float32.
        max_diff = (hf_attn_out - bridge_attn_out).abs().max().item()
        mean_diff = (hf_attn_out - bridge_attn_out).abs().mean().item()
        assert max_diff < 0.15, f"Output too different: max diff = {max_diff}"
        assert mean_diff < 0.02, f"Output too different: mean diff = {mean_diff}"

    def test_hooks_fire_and_have_correct_shapes(self, mla_bridge, sample_inputs, tiny_config):
        """Register a forward hook on each hook point and check the captured shapes."""
        hidden_states, position_embeddings = sample_inputs
        batch, seq = hidden_states.shape[:2]
        captured = {}

        hooks_to_check = [
            "hook_q_latent",
            "hook_kv_latent",
            "hook_q",
            "hook_k",
            "hook_v",
            "hook_rot_q",
            "hook_rot_k",
            "hook_attn_scores",
            "hook_pattern",
        ]

        handles = []
        for name in hooks_to_check:

            # Factory function avoids the late-binding-closure pitfall:
            # each hook gets its own bound `n`.
            def make_capture(n):
                def hook_fn(module, input, output):
                    captured[n] = output.shape

                return hook_fn

            handles.append(getattr(mla_bridge, name).register_forward_hook(make_capture(name)))

        try:
            with torch.no_grad():
                mla_bridge(
                    hidden_states, position_embeddings=position_embeddings, attention_mask=None
                )
        finally:
            # Always detach the hooks so later tests see an unmodified bridge.
            for h in handles:
                h.remove()

        n_heads = tiny_config.num_attention_heads
        # Q/K head dim = non-rotary part + rotary part, concatenated.
        qk_head_dim = tiny_config.qk_nope_head_dim + tiny_config.qk_rope_head_dim

        for name in hooks_to_check:
            assert name in captured, f"Hook {name} did not fire"

        # Latents are per-token compressed representations (no head axis).
        assert captured["hook_q_latent"] == (batch, seq, tiny_config.q_lora_rank)
        assert captured["hook_kv_latent"] == (batch, seq, tiny_config.kv_lora_rank)
        assert captured["hook_q"] == (batch, n_heads, seq, qk_head_dim)
        assert captured["hook_k"] == (batch, n_heads, seq, qk_head_dim)
        assert captured["hook_v"] == (batch, n_heads, seq, tiny_config.v_head_dim)
        assert captured["hook_attn_scores"] == (batch, n_heads, seq, seq)
        assert captured["hook_pattern"] == (batch, n_heads, seq, seq)

    def test_hook_q_is_post_rope(self, mla_bridge, sample_inputs):
        """hook_q's rope portion should match hook_rot_q."""
        hidden_states, position_embeddings = sample_inputs
        q_values: list[torch.Tensor] = []
        rot_q_values: list[torch.Tensor] = []

        h1 = mla_bridge.hook_q.register_forward_hook(lambda m, i, o: q_values.append(o.clone()))
        h2 = mla_bridge.hook_rot_q.register_forward_hook(
            lambda m, i, o: rot_q_values.append(o.clone())
        )

        try:
            with torch.no_grad():
                mla_bridge(
                    hidden_states, position_embeddings=position_embeddings, attention_mask=None
                )
        finally:
            h1.remove()
            h2.remove()

        # The rotary channels occupy the trailing slice of hook_q's head dim;
        # if hook_q fired before RoPE was applied, this comparison would fail.
        # NOTE(review): reads the bridge's private _qk_rope_head_dim attribute.
        qk_rope_dim = mla_bridge._qk_rope_head_dim
        assert torch.allclose(q_values[0][..., -qk_rope_dim:], rot_q_values[0], atol=1e-5)
34 changes: 34 additions & 0 deletions transformer_lens/benchmarks/component_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,17 @@ def _test_component_recursive(
if last_part in ["o", "out"]:
return

# Skip MLA intermediates (expect compressed-dim inputs, not hidden_states)
if last_part in [
"q_a_proj",
"q_a_layernorm",
"q_b_proj",
"kv_a_proj_with_mqa",
"kv_a_layernorm",
"kv_b_proj",
]:
return

# Skip virtual splits from fused projections (no standalone HF equivalent)
if last_part in ["q", "k", "v", "gate", "in"]:
parent_path = ".".join(path_parts[:-1])
Expand All @@ -459,6 +470,29 @@ def _test_component_recursive(
except Exception:
pass

# Skip components not wired on this layer (per-layer or per-config variation).
# Only report as failure if the HF model has it but the bridge doesn't.
try:
self.adapter.get_component(self.bridge_model, component_path)
except (AttributeError, ValueError):
parts = component_path.split(".")
if len(parts) >= 3 and parts[1].isdigit():
subpath = ".".join([parts[0]] + ["{layer}"] + parts[2:])
# Per-layer variation: exists on some other layer (e.g., MoE vs dense)
for probe_layer in range(self.cfg.n_layers):
probe_path = subpath.replace("{layer}", str(probe_layer))
try:
self.adapter.get_component(self.bridge_model, probe_path)
return # Found on another layer — skip this one
except (AttributeError, ValueError):
continue
# Per-config absence: HF model also lacks it (e.g., q_lora_rank=None)
try:
self.adapter.get_component(self.hf_model, component_path)
except (AttributeError, ValueError):
return
# Bridge is missing a component that HF has — likely misconfiguration

# Test this component
result = self._test_component(component_path, component, test_inputs)
if result is not None:
Expand Down
2 changes: 2 additions & 0 deletions transformer_lens/factories/architecture_adapter_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
ApertusArchitectureAdapter,
BertArchitectureAdapter,
BloomArchitectureAdapter,
DeepSeekV3ArchitectureAdapter,
Gemma1ArchitectureAdapter,
Gemma2ArchitectureAdapter,
Gemma3ArchitectureAdapter,
Expand Down Expand Up @@ -52,6 +53,7 @@
"ApertusForCausalLM": ApertusArchitectureAdapter,
"BertForMaskedLM": BertArchitectureAdapter,
"BloomForCausalLM": BloomArchitectureAdapter,
"DeepseekV3ForCausalLM": DeepSeekV3ArchitectureAdapter,
"GemmaForCausalLM": Gemma1ArchitectureAdapter, # Default to Gemma1 as it's the original version
"Gemma1ForCausalLM": Gemma1ArchitectureAdapter,
"Gemma2ForCausalLM": Gemma2ArchitectureAdapter,
Expand Down
Loading
Loading