From 54a2862df841d6847de9e59e252270dc2ccf5f47 Mon Sep 17 00:00:00 2001 From: Igor Morozov Date: Wed, 1 Apr 2026 16:41:23 +0300 Subject: [PATCH] feat: add compositional guided sampling (--composition_bias) Soft bias for atom-type logits during autoregressive decoding. Steers generation toward desired compositions without retraining. - parse_composition_bias() in elements.py: "Fe:2.0,S:1.5" -> array - composition_bias param in make_sample_crystal(): added to a_logit - --composition_bias CLI arg in main.py - 17 tests: parsing (13, incl. out-of-vocab regression) + sampling (4) Raise ValueError when element index exceeds atom_types vocabulary instead of silently dropping the bias entry. Usage: python main.py --composition_bias "Fe:2.0,S:1.5,O:-1.0" ... Ref: deepmodeling/CrystalFormer#13 Co-Authored-By: Claude Opus 4.6 (1M context) --- crystalformer/src/elements.py | 39 +++++++ crystalformer/src/sample.py | 9 +- main.py | 9 +- tests/test_composition_bias.py | 179 +++++++++++++++++++++++++++++++++ 4 files changed, 233 insertions(+), 3 deletions(-) create mode 100644 tests/test_composition_bias.py diff --git a/crystalformer/src/elements.py b/crystalformer/src/elements.py index 8c99a44..3477b9d 100644 --- a/crystalformer/src/elements.py +++ b/crystalformer/src/elements.py @@ -37,6 +37,45 @@ noble_gas_dict = {e: element_dict[e] for e in noble_gas} +def parse_composition_bias(bias_string, atom_types=119): + """Parse a composition bias string into a numeric array. + + Parameters: + bias_string: Comma-separated ``Element:weight`` pairs, e.g. + ``"Fe:2.0,S:1.5,O:-1.0"``. Positive weight increases + sampling probability; negative decreases. ``None`` or + empty string returns all zeros (no bias). + atom_types: Length of the returned array (default 119). + + Returns: + numpy array of shape ``(atom_types,)`` with bias values. + + Raises: + ValueError: Unknown element symbol or malformed entry. + """ + import numpy as np + bias = np.zeros(atom_types) + if not bias_string: + return bias + for item in bias_string.split(","): + parts = item.strip().split(":") + if len(parts) != 2: + raise ValueError( + f"Invalid bias format: {item!r}. Expected 'Element:weight'" + ) + element, weight = parts[0].strip(), parts[1].strip() + if element not in element_dict: + raise ValueError(f"Unknown element: {element!r}") + idx = element_dict[element] + if idx >= atom_types: + raise ValueError( + f"Element {element!r} (index {idx}) is outside model " + f"vocabulary (atom_types={atom_types})" + ) + bias[idx] = float(weight) + return bias + + if __name__=="__main__": print (len(element_list)) print (element_dict["H"]) diff --git a/crystalformer/src/sample.py b/crystalformer/src/sample.py index b4052fb..ec75131 100644 --- a/crystalformer/src/sample.py +++ b/crystalformer/src/sample.py @@ -61,13 +61,18 @@ def sample_x(key, h_x, Kx, top_p, temperature, batchsize): return key, x -def make_sample_crystal(transformer, n_max, atom_types, wyck_types, Kx, Kl, w_mask, top_p, temperature, K=0, g=None, atom_mask=None, spg_mask=None): +def make_sample_crystal(transformer, n_max, atom_types, wyck_types, Kx, Kl, w_mask, top_p, temperature, K=0, g=None, atom_mask=None, spg_mask=None, composition_bias=None): if atom_mask is None: user_atom_mask = jnp.ones((atom_types,), dtype=bool) else: user_atom_mask = atom_mask.astype(bool) + if composition_bias is not None: + _composition_bias = jnp.asarray(composition_bias, dtype=jnp.float32) + else: + _composition_bias = jnp.zeros((atom_types,)) + @partial(jax.jit, static_argnums=2) def sample_crystal(key, params, batchsize, composition): @@ -98,7 +103,7 @@ def body_fn(i, state): a_logit = h_al[:, :atom_types] key, subkey = jax.random.split(key) - a_logit = a_logit + jnp.where(atom_mask, 0.0, -1e10) # enhance the probability of masked atoms (do not need to normalize since we only use it for sampling, not computing logp) + a_logit = a_logit + jnp.where(atom_mask, 0.0, -1e10) + _composition_bias # enhance the probability of masked atoms (do not need to normalize since we only use it for sampling, not computing logp) a = sample_top_p(subkey, a_logit, top_p, temperature) A = A.at[:, i].set(a) diff --git a/main.py b/main.py index f117f36..5a78bb5 100644 --- a/main.py +++ b/main.py @@ -77,6 +77,7 @@ group.add_argument('--output_filename', type=str, default='output.csv', help='outfile to save sampled structures') group.add_argument('--verbose', type=int, default=0, help='verbose level') group.add_argument('--remove_radioactive', action='store_true', help='remove radioactive elements and noble gas, only valid when formula is None') +group.add_argument('--composition_bias', type=str, default=None, help='Soft bias for atom types during sampling, e.g. "Fe:2.0,S:1.5,O:-1.0". Positive values increase probability, negative decrease.') args = parser.parse_args() @@ -205,7 +206,13 @@ else: print ('targeting spacegroup No.', args.spacegroup) - sample_crystal = make_sample_crystal(transformer, args.n_max, args.atom_types, args.wyck_types, args.Kx, args.Kl, w_mask, args.top_p, args.temperature, args.K, args.spacegroup, atom_mask) + comp_bias = None + if args.composition_bias is not None: + from crystalformer.src.elements import parse_composition_bias + comp_bias = jnp.array(parse_composition_bias(args.composition_bias, args.atom_types)) + print('composition bias:', args.composition_bias) + + sample_crystal = make_sample_crystal(transformer, args.n_max, args.atom_types, args.wyck_types, args.Kx, args.Kl, w_mask, args.top_p, args.temperature, args.K, args.spacegroup, atom_mask, composition_bias=comp_bias) if args.seed is not None: key = jax.random.PRNGKey(args.seed) # reset key for sampling if seed is provided diff --git a/tests/test_composition_bias.py b/tests/test_composition_bias.py new file mode 100644 index 0000000..32af3e5 --- /dev/null +++ b/tests/test_composition_bias.py @@ -0,0 +1,179 @@ +"""Tests for compositional guided sampling. + +Tests are split into two groups: +- Parsing tests: pure Python/numpy, no JAX required. +- Sampling integration tests: require JAX + haiku (skipped if unavailable). +""" + +import pytest +import numpy as np + + +# ------------------------------------------------------------------ # +# Parsing tests (no JAX needed) # +# ------------------------------------------------------------------ # + +from crystalformer.src.elements import element_list, element_dict, parse_composition_bias + + +class TestParseCompositionBias: + """parse_composition_bias('Fe:2.0,S:1.5') -> float array.""" + + def test_single_element(self): + bias = parse_composition_bias("Fe:2.0") + assert bias[element_dict["Fe"]] == pytest.approx(2.0) + assert bias.sum() == pytest.approx(2.0) # only Fe is nonzero + + def test_multiple_elements(self): + bias = parse_composition_bias("Fe:2.0,S:1.5,Ni:0.5") + assert bias[element_dict["Fe"]] == pytest.approx(2.0) + assert bias[element_dict["S"]] == pytest.approx(1.5) + assert bias[element_dict["Ni"]] == pytest.approx(0.5) + + def test_negative_bias(self): + bias = parse_composition_bias("O:-1.0") + assert bias[element_dict["O"]] == pytest.approx(-1.0) + + def test_empty_string_returns_zeros(self): + bias = parse_composition_bias("") + assert np.all(bias == 0) + + def test_none_returns_zeros(self): + bias = parse_composition_bias(None) + assert np.all(bias == 0) + + def test_output_shape(self): + bias = parse_composition_bias("Li:1.0") + assert bias.shape == (119,) + + def test_custom_atom_types(self): + bias = parse_composition_bias("H:1.0", atom_types=10) + assert bias.shape == (10,) + assert bias[element_dict["H"]] == pytest.approx(1.0) + + def test_unknown_element_raises(self): + with pytest.raises(ValueError, match="Unknown element"): + parse_composition_bias("Xx:1.0") + + def test_bad_format_raises(self): + with pytest.raises(ValueError, match="Invalid bias format"): + parse_composition_bias("Fe2.0") # missing colon + + def test_element_outside_atom_types_raises(self): + """Regression: element index >= atom_types must raise, not silently drop.""" + with pytest.raises(ValueError, match="outside model vocabulary"): + parse_composition_bias("Fe:2.0", atom_types=10) # Fe is index 26 + + def test_whitespace_tolerance(self): + bias = parse_composition_bias(" Fe : 2.0 , S : 1.5 ") + assert bias[element_dict["Fe"]] == pytest.approx(2.0) + assert bias[element_dict["S"]] == pytest.approx(1.5) + + def test_zero_bias_is_noop(self): + bias = parse_composition_bias("Fe:0.0") + assert np.all(bias == 0) + + def test_padding_element_unaffected(self): + """Index 0 ('X', padding) should stay zero regardless of input.""" + bias = parse_composition_bias("Fe:2.0,S:1.5") + assert bias[0] == 0.0 + + +# ------------------------------------------------------------------ # +# Sampling integration tests (need JAX) # +# ------------------------------------------------------------------ # + +try: + import jax + import jax.numpy as jnp + HAS_JAX = True +except ImportError: + HAS_JAX = False + + +@pytest.mark.skipif(not HAS_JAX, reason="JAX not installed") +class TestBiasAppliedToLogits: + """Verify that composition_bias shifts atom sampling distribution.""" + + def test_positive_bias_increases_probability(self): + """Element with positive bias should be sampled more often.""" + from crystalformer.src.sample import sample_top_p + + key = jax.random.PRNGKey(42) + n_samples = 5000 + n_types = 10 + + # Uniform logits + logits = jnp.zeros((n_samples, n_types)) + + # No bias -> roughly uniform + counts_no_bias = np.bincount( + np.array(sample_top_p(key, logits, 1.0, 1.0)), minlength=n_types + ) + + # Bias toward element 3 + bias = jnp.zeros(n_types).at[3].set(3.0) + biased_logits = logits + bias + counts_biased = np.bincount( + np.array(sample_top_p(key, biased_logits, 1.0, 1.0)), + minlength=n_types, + ) + + # Element 3 should be sampled much more with bias + assert counts_biased[3] > counts_no_bias[3] * 1.5 + + def test_negative_bias_decreases_probability(self): + """Element with negative bias should be sampled less often.""" + from crystalformer.src.sample import sample_top_p + + key = jax.random.PRNGKey(42) + n_samples = 5000 + n_types = 10 + + logits = jnp.zeros((n_samples, n_types)) + + counts_no_bias = np.bincount( + np.array(sample_top_p(key, logits, 1.0, 1.0)), minlength=n_types + ) + + bias = jnp.zeros(n_types).at[3].set(-3.0) + biased_logits = logits + bias + counts_biased = np.bincount( + np.array(sample_top_p(key, biased_logits, 1.0, 1.0)), + minlength=n_types, + ) + + assert counts_biased[3] < counts_no_bias[3] * 0.5 + + def test_atom_mask_overrides_positive_bias(self): + """A masked-out element (atom_mask=False) should stay blocked + even with large positive bias.""" + from crystalformer.src.sample import sample_top_p + + key = jax.random.PRNGKey(0) + n_samples = 2000 + n_types = 5 + + logits = jnp.zeros((n_samples, n_types)) + + # Mask out element 2 (hard block) + atom_mask = jnp.array([True, True, False, True, True]) + mask_penalty = jnp.where(atom_mask, 0.0, -1e10) + + # Large positive bias on the masked element + bias = jnp.zeros(n_types).at[2].set(10.0) + + final_logits = logits + mask_penalty + bias + samples = np.array(sample_top_p(key, final_logits, 1.0, 1.0)) + + # Element 2 should never be sampled + assert np.sum(samples == 2) == 0 + + def test_make_sample_crystal_accepts_composition_bias(self): + """make_sample_crystal should accept composition_bias parameter + without error (smoke test — does not run sampling).""" + from crystalformer.src.sample import make_sample_crystal + + import inspect + sig = inspect.signature(make_sample_crystal) + assert "composition_bias" in sig.parameters