Skip to content

Commit d3ee9c3

Browse files
author
Igor Morozov
committed
fix: resolve NameError when MultiHeadAttention is called with w_init=None
`w_init_scale` was referenced but never defined, causing a NameError whenever MultiHeadAttention is instantiated with the default w_init=None. The fix replaces the undefined variable with the literal 1.0, which matches the upstream haiku VarianceScaling default.

Adds two focused regression tests:
- test_w_init_none_does_not_raise: exercises the formerly-broken code path
- test_w_init_explicit_still_works: confirms explicit w_init is unaffected

Fixes #11
1 parent b062a56 commit d3ee9c3

2 files changed

Lines changed: 56 additions & 1 deletion

File tree

crystalformer/src/attention.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,7 @@ def __init__(
8484
self.dropout_rate = dropout_rate
8585

8686
if w_init is None:
87-            w_init = hk.initializers.VarianceScaling(w_init_scale)
87+            w_init = hk.initializers.VarianceScaling(1.0)
8888
self.w_init = w_init
8989
self.with_bias = with_bias
9090
self.b_init = b_init

tests/test_attention.py

Lines changed: 55 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,55 @@
1+
"""Tests for MultiHeadAttention -- focuses on the w_init=None fix.
2+
3+
Before the fix, calling MultiHeadAttention with w_init=None (the default)
4+
raised NameError: name 'w_init_scale' is not defined.
5+
"""
6+
from config import *
7+
8+
from crystalformer.src.attention import MultiHeadAttention
9+
10+
11+
def test_w_init_none_does_not_raise():
12+
"""w_init=None (the default) must not raise NameError.
13+
14+
Regression test for: NameError: name 'w_init_scale' is not defined.
15+
The fix replaces the undefined variable with the literal value 1.0,
16+
matching the upstream haiku default for VarianceScaling.
17+
"""
18+
def fn(q, k, v):
19+
mha = MultiHeadAttention(
20+
num_heads=2,
21+
key_size=8,
22+
model_size=16,
23+
w_init=None, # default -- was broken before fix
24+
)
25+
return mha(q, k, v)
26+
27+
f = hk.without_apply_rng(hk.transform(fn))
28+
key = jax.random.PRNGKey(0)
29+
x = jax.random.normal(key, (4, 16))
30+
params = f.init(key, x, x, x)
31+
out = f.apply(params, x, x, x)
32+
33+
assert out.shape == (4, 16)
34+
assert jnp.isfinite(out).all(), "w_init=None path produces NaN/Inf"
35+
36+
37+
def test_w_init_explicit_still_works():
38+
"""Explicit w_init continues to work after the fix."""
39+
def fn(q, k, v):
40+
mha = MultiHeadAttention(
41+
num_heads=2,
42+
key_size=8,
43+
model_size=16,
44+
w_init=hk.initializers.VarianceScaling(1.0),
45+
)
46+
return mha(q, k, v)
47+
48+
f = hk.without_apply_rng(hk.transform(fn))
49+
key = jax.random.PRNGKey(1)
50+
x = jax.random.normal(key, (4, 16))
51+
params = f.init(key, x, x, x)
52+
out = f.apply(params, x, x, x)
53+
54+
assert out.shape == (4, 16)
55+
assert jnp.isfinite(out).all()

0 commit comments

Comments
 (0)