
feat: compositional guided sampling (--composition_bias)#14

Closed
exopoiesis wants to merge 1 commit into deepmodeling:main from exopoiesis:feat/composition-bias-upstream

Conversation

@exopoiesis

Summary

Adds a soft bias to the atom-type logits during autoregressive decoding, steering generation toward desired compositions without retraining or fine-tuning.

This is the smallest self-contained piece from the discussion in #13. Instead of building a full screening pipeline into the core repo (which we agree would be out of scope), we extracted this single feature — it lives entirely within src/ and touches only 3 existing files.

Motivation

When generating structures for a specific application (e.g., Fe-S sulfides for ionic conductors), most random generations contain irrelevant compositions. With composition bias, we increased our Fe-S hit rate from ~12% to ~60% — a 5x improvement in useful output with zero model changes.

This is useful for any targeted generation task: battery cathodes (bias Li, transition metals), thermoelectrics (bias Bi, Te), catalysts (bias Pt-group), etc.

How it works

At each atom-type sampling step, a user-provided bias vector is added to logits before softmax:

```
a_logit = a_logit + atom_mask_penalty + composition_bias
```
  • Positive bias → element sampled more often
  • Negative bias → element sampled less often
  • Zero (default) → unchanged behavior
  • atom_mask still overrides bias (hard block wins over soft nudge)
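
The biased step can be sketched in plain NumPy (the actual implementation uses JAX; the function name and shapes here are illustrative assumptions, only the formula above is from the PR):

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_atom_type(a_logit, atom_mask_penalty, composition_bias):
    # Hard mask (-inf for blocked types) and soft bias are both added
    # to the logits before softmax, so a -inf mask always wins: no
    # finite bias can resurrect a masked element.
    logits = a_logit + atom_mask_penalty + composition_bias
    p = np.exp(logits - logits.max())
    p /= p.sum()
    return rng.choice(len(p), p=p)
```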

Changes

| File | Change |
| --- | --- |
| `crystalformer/src/elements.py` | `parse_composition_bias()`: `"Fe:2.0,S:1.5"` → numpy array |
| `crystalformer/src/sample.py` | `composition_bias` param in `make_sample_crystal()` |
| `main.py` | `--composition_bias` CLI argument |
| `tests/test_composition_bias.py` | 17 tests: 13 parsing + 4 sampling integration |
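
A hypothetical re-implementation of the parsing idea, using a toy element table for illustration (the real vocabulary, default size, and error messages in `elements.py` may differ):

```python
import numpy as np

# Tiny subset standing in for the full periodic-table mapping.
ELEMENTS = {"H": 1, "O": 8, "S": 16, "Fe": 26}

def parse_composition_bias(spec, atom_types=119):
    """Parse 'Fe:2.0,S:1.5' into a dense bias vector over atom types."""
    bias = np.zeros(atom_types)
    if not spec:
        return bias  # None/empty -> zero vector, sampling unchanged
    for item in spec.split(","):
        elem, value = item.strip().split(":")
        if elem not in ELEMENTS:
            raise ValueError(f"unknown element: {elem}")
        idx = ELEMENTS[elem]
        if idx >= atom_types:
            raise ValueError(f"element index {idx} out of vocabulary")
        bias[idx] = float(value)
    return bias
```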

Usage

```
# CLI
python main.py --composition_bias "Fe:2.0,S:1.5,O:-1.0" ...

# Python API
sample_fn = make_sample_crystal(..., composition_bias=bias_array)
```

Tests

17 tests covering:

  • Parsing: single/multiple elements, negatives, whitespace, edge cases, error handling
  • Sampling integration (requires JAX): positive bias increases probability, negative decreases, atom_mask overrides bias, API signature check
  • Parsing tests run without JAX; integration tests are skipped if JAX is unavailable
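
The JAX-optional split could be organized along these lines (an illustrative sketch, not the actual test file; the test name is borrowed from the bullet above):

```python
import importlib.util
import unittest

HAS_JAX = importlib.util.find_spec("jax") is not None

class TestSamplingIntegration(unittest.TestCase):
    # Integration tests only run when JAX is importable; parsing
    # tests (not shown) would carry no such guard.
    @unittest.skipUnless(HAS_JAX, "JAX not installed")
    def test_positive_bias_increases_probability(self):
        import jax.nn
        import jax.numpy as jnp
        logits = jnp.zeros(4)
        bias = jnp.zeros(4).at[1].set(2.0)
        self.assertGreater(jax.nn.softmax(logits + bias)[1],
                           jax.nn.softmax(logits)[1])
```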

Design decisions

  • Additive logit bias (not multiplicative probability scaling) — numerically stable, composable with existing mask, standard practice in LLM sampling
  • No model changes — bias is applied at inference time only
  • Raises ValueError for unknown elements or out-of-vocabulary indices (not silent drop)
  • Zero overhead when unused — composition_bias=None (default) creates a zero vector, addition is a no-op in practice
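
On the first point: adding a bias b to the logits is mathematically the same as multiplying each probability by exp(b) and renormalizing, which is why it composes cleanly with the -inf mask (exp(-inf) = 0, so no finite bias can un-mask an element). A quick NumPy check of that identity:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

logits = np.array([1.0, 2.0, 0.5, -np.inf])  # last type hard-masked
bias = np.array([0.0, 0.0, 1.5, 9.0])        # large bias on the masked type

# Additive view: bias logits, then softmax.
p_biased = softmax(logits + bias)

# Equivalent multiplicative view: scale probabilities by exp(bias).
p_manual = softmax(logits) * np.exp(bias)
p_manual /= p_manual.sum()
```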

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Soft bias for atom-type logits during autoregressive decoding.
Steers generation toward desired compositions without retraining.

- parse_composition_bias() in elements.py: "Fe:2.0,S:1.5" -> array
- composition_bias param in make_sample_crystal(): added to a_logit
- --composition_bias CLI arg in main.py
- 17 tests: parsing (13, incl. out-of-vocab regression) + sampling (4)

Raise ValueError when element index exceeds atom_types vocabulary
instead of silently dropping the bias entry.

Usage: python main.py --composition_bias "Fe:2.0,S:1.5,O:-1.0" ...

Ref: deepmodeling#13

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@exopoiesis
Author

Moving composition_bias to crystalformer-x per discussion in #13. Thanks for the guidance!

@exopoiesis exopoiesis closed this Apr 2, 2026