Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions crystalformer/src/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,45 @@
noble_gas_dict = {e: element_dict[e] for e in noble_gas}


def parse_composition_bias(bias_string, atom_types=119):
    """Parse a composition bias string into a numeric array.

    Parameters:
        bias_string: Comma-separated ``Element:weight`` pairs, e.g.
            ``"Fe:2.0,S:1.5,O:-1.0"``. Positive weight increases
            sampling probability; negative decreases. ``None`` or
            empty string returns all zeros (no bias). Whitespace around
            symbols, weights and separators is ignored. If an element is
            listed more than once, the last weight wins.
        atom_types: Length of the returned array (default 119).

    Returns:
        numpy array of shape ``(atom_types,)`` with bias values.

    Raises:
        ValueError: Malformed entry, a weight that is not a finite
            number (non-numeric text, ``nan``, ``inf``, ``-inf``),
            unknown element symbol, or an element whose index is not
            below ``atom_types``.
    """
    import numpy as np

    bias = np.zeros(atom_types)
    if not bias_string:
        return bias

    for item in bias_string.split(","):
        parts = item.strip().split(":")
        if len(parts) != 2:
            raise ValueError(
                f"Invalid bias format: {item!r}. Expected 'Element:weight'"
            )
        element, weight = parts[0].strip(), parts[1].strip()

        # Validate the weight explicitly so the error names the offending
        # entry instead of surfacing float()'s generic message.
        try:
            value = float(weight)
        except ValueError:
            raise ValueError(
                f"Invalid bias weight {weight!r} for element {element!r}: "
                f"expected a number"
            ) from None
        # float() happily parses "nan" and "inf"; such a value used as an
        # additive sampling bias would no longer be a soft preference, so
        # reject it here rather than let it propagate.
        if not np.isfinite(value):
            raise ValueError(
                f"Invalid bias weight {weight!r} for element {element!r}: "
                f"weight must be finite"
            )

        if element not in element_dict:
            raise ValueError(f"Unknown element: {element!r}")
        idx = element_dict[element]
        if idx >= atom_types:
            raise ValueError(
                f"Element {element!r} (index {idx}) is outside model "
                f"vocabulary (atom_types={atom_types})"
            )
        bias[idx] = value
    return bias


if __name__=="__main__":
print (len(element_list))
print (element_dict["H"])
Expand Down
9 changes: 7 additions & 2 deletions crystalformer/src/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,18 @@ def sample_x(key, h_x, Kx, top_p, temperature, batchsize):
return key, x


def make_sample_crystal(transformer, n_max, atom_types, wyck_types, Kx, Kl, w_mask, top_p, temperature, K=0, g=None, atom_mask=None, spg_mask=None):
def make_sample_crystal(transformer, n_max, atom_types, wyck_types, Kx, Kl, w_mask, top_p, temperature, K=0, g=None, atom_mask=None, spg_mask=None, composition_bias=None):

if atom_mask is None:
user_atom_mask = jnp.ones((atom_types,), dtype=bool)
else:
user_atom_mask = atom_mask.astype(bool)

if composition_bias is not None:
_composition_bias = jnp.asarray(composition_bias, dtype=jnp.float32)
else:
_composition_bias = jnp.zeros((atom_types,))

@partial(jax.jit, static_argnums=2)
def sample_crystal(key, params, batchsize, composition):

Expand Down Expand Up @@ -98,7 +103,7 @@ def body_fn(i, state):
a_logit = h_al[:, :atom_types]

key, subkey = jax.random.split(key)
a_logit = a_logit + jnp.where(atom_mask, 0.0, -1e10) # enhance the probability of masked atoms (do not need to normalize since we only use it for sampling, not computing logp)
a_logit = a_logit + jnp.where(atom_mask, 0.0, -1e10) + _composition_bias # enhance the probability of masked atoms (do not need to normalize since we only use it for sampling, not computing logp)
a = sample_top_p(subkey, a_logit, top_p, temperature)
A = A.at[:, i].set(a)

Expand Down
9 changes: 8 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
group.add_argument('--output_filename', type=str, default='output.csv', help='outfile to save sampled structures')
group.add_argument('--verbose', type=int, default=0, help='verbose level')
group.add_argument('--remove_radioactive', action='store_true', help='remove radioactive elements and noble gas, only valid when formula is None')
group.add_argument('--composition_bias', type=str, default=None, help='Soft bias for atom types during sampling, e.g. "Fe:2.0,S:1.5,O:-1.0". Positive values increase probability, negative decrease.')


args = parser.parse_args()
Expand Down Expand Up @@ -205,7 +206,13 @@
else:
print ('targeting spacegroup No.', args.spacegroup)

sample_crystal = make_sample_crystal(transformer, args.n_max, args.atom_types, args.wyck_types, args.Kx, args.Kl, w_mask, args.top_p, args.temperature, args.K, args.spacegroup, atom_mask)
comp_bias = None
if args.composition_bias is not None:
from crystalformer.src.elements import parse_composition_bias
comp_bias = jnp.array(parse_composition_bias(args.composition_bias, args.atom_types))
print('composition bias:', args.composition_bias)

sample_crystal = make_sample_crystal(transformer, args.n_max, args.atom_types, args.wyck_types, args.Kx, args.Kl, w_mask, args.top_p, args.temperature, args.K, args.spacegroup, atom_mask, composition_bias=comp_bias)

if args.seed is not None:
key = jax.random.PRNGKey(args.seed) # reset key for sampling if seed is provided
Expand Down
179 changes: 179 additions & 0 deletions tests/test_composition_bias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""Tests for compositional guided sampling.

Tests are split into two groups:
- Parsing tests: pure Python/numpy, no JAX required.
- Sampling integration tests: require JAX + haiku (skipped if unavailable).
"""

import pytest
import numpy as np


# ------------------------------------------------------------------ #
# Parsing tests (no JAX needed) #
# ------------------------------------------------------------------ #

from crystalformer.src.elements import element_list, element_dict, parse_composition_bias


class TestParseCompositionBias:
    """Unit tests for turning an ``Element:weight`` string into a bias array."""

    def test_single_element(self):
        parsed = parse_composition_bias("Fe:2.0")
        fe_index = element_dict["Fe"]
        assert parsed[fe_index] == pytest.approx(2.0)
        # The array total matches the Fe weight, so nothing else contributed.
        assert parsed.sum() == pytest.approx(2.0)

    def test_multiple_elements(self):
        parsed = parse_composition_bias("Fe:2.0,S:1.5,Ni:0.5")
        expected = {"Fe": 2.0, "S": 1.5, "Ni": 0.5}
        for symbol, weight in expected.items():
            assert parsed[element_dict[symbol]] == pytest.approx(weight)

    def test_negative_bias(self):
        oxygen = element_dict["O"]
        assert parse_composition_bias("O:-1.0")[oxygen] == pytest.approx(-1.0)

    def test_empty_string_returns_zeros(self):
        parsed = parse_composition_bias("")
        assert (parsed == 0).all()

    def test_none_returns_zeros(self):
        parsed = parse_composition_bias(None)
        assert (parsed == 0).all()

    def test_output_shape(self):
        # Default vocabulary size is 119 entries.
        assert parse_composition_bias("Li:1.0").shape == (119,)

    def test_custom_atom_types(self):
        parsed = parse_composition_bias("H:1.0", atom_types=10)
        assert parsed.shape == (10,)
        assert parsed[element_dict["H"]] == pytest.approx(1.0)

    def test_unknown_element_raises(self):
        with pytest.raises(ValueError, match="Unknown element"):
            parse_composition_bias("Xx:1.0")

    def test_bad_format_raises(self):
        # No colon between the symbol and the weight.
        with pytest.raises(ValueError, match="Invalid bias format"):
            parse_composition_bias("Fe2.0")

    def test_element_outside_atom_types_raises(self):
        """Regression: an index >= atom_types must raise, not be dropped."""
        # Fe has index 26, which does not fit in a 10-entry vocabulary.
        with pytest.raises(ValueError, match="outside model vocabulary"):
            parse_composition_bias("Fe:2.0", atom_types=10)

    def test_whitespace_tolerance(self):
        parsed = parse_composition_bias(" Fe : 2.0 , S : 1.5 ")
        for symbol, weight in (("Fe", 2.0), ("S", 1.5)):
            assert parsed[element_dict[symbol]] == pytest.approx(weight)

    def test_zero_bias_is_noop(self):
        parsed = parse_composition_bias("Fe:0.0")
        assert (parsed == 0).all()

    def test_padding_element_unaffected(self):
        """Slot 0 ('X', padding) stays zero when it is not named in the input."""
        parsed = parse_composition_bias("Fe:2.0,S:1.5")
        assert parsed[0] == 0.0


# ------------------------------------------------------------------ #
# Sampling integration tests (need JAX) #
# ------------------------------------------------------------------ #

# JAX is optional for this test module: record whether it imported so the
# sampling tests can be skipped when it is unavailable.
try:
    import jax
    import jax.numpy as jnp
except ImportError:
    HAS_JAX = False
else:
    HAS_JAX = True


@pytest.mark.skipif(not HAS_JAX, reason="JAX not installed")
class TestBiasAppliedToLogits:
    """Check how an additive bias on the logits changes what gets sampled."""

    @staticmethod
    def _histogram(key, logits, n_types):
        """Draw one sample per logits row and count how often each type occurs."""
        from crystalformer.src.sample import sample_top_p

        drawn = sample_top_p(key, logits, 1.0, 1.0)
        return np.bincount(np.array(drawn), minlength=n_types)

    def test_positive_bias_increases_probability(self):
        """A positively biased element should be drawn more often."""
        rng = jax.random.PRNGKey(42)
        n_samples, n_types = 5000, 10

        # All-zero logits: every type starts with the same logit.
        flat = jnp.zeros((n_samples, n_types))
        baseline = self._histogram(rng, flat, n_types)

        # Same key and logits, with +3.0 added to type 3 only.
        shift = jnp.zeros(n_types).at[3].set(3.0)
        boosted = self._histogram(rng, flat + shift, n_types)

        assert boosted[3] > baseline[3] * 1.5

    def test_negative_bias_decreases_probability(self):
        """A negatively biased element should be drawn less often."""
        rng = jax.random.PRNGKey(42)
        n_samples, n_types = 5000, 10

        flat = jnp.zeros((n_samples, n_types))
        baseline = self._histogram(rng, flat, n_types)

        # Same key and logits, with -3.0 added to type 3 only.
        shift = jnp.zeros(n_types).at[3].set(-3.0)
        suppressed = self._histogram(rng, flat + shift, n_types)

        assert suppressed[3] < baseline[3] * 0.5

    def test_atom_mask_overrides_positive_bias(self):
        """An element blocked by atom_mask stays blocked despite a large
        positive bias."""
        from crystalformer.src.sample import sample_top_p

        rng = jax.random.PRNGKey(0)
        n_samples, n_types = 2000, 5
        flat = jnp.zeros((n_samples, n_types))

        # Hard-block type 2 with the -1e10 penalty, then push +10.0 on it.
        allowed = jnp.array([True, True, False, True, True])
        penalty = jnp.where(allowed, 0.0, -1e10)
        boost = jnp.zeros(n_types).at[2].set(10.0)

        drawn = np.array(sample_top_p(rng, flat + penalty + boost, 1.0, 1.0))

        # Type 2 must not appear in any draw.
        assert np.sum(drawn == 2) == 0

    def test_make_sample_crystal_accepts_composition_bias(self):
        """Smoke test: the factory signature exposes ``composition_bias``
        (no sampling is run)."""
        import inspect

        from crystalformer.src.sample import make_sample_crystal

        parameters = inspect.signature(make_sample_crystal).parameters
        assert "composition_bias" in parameters
Loading