diff --git a/tests/models/test_chroma.py b/tests/models/test_chroma.py index c7f0409..2864e13 100644 --- a/tests/models/test_chroma.py +++ b/tests/models/test_chroma.py @@ -63,8 +63,21 @@ def test_chroma(chroma): conditioners.SymmetryConditioner(G="C_3", num_chain_neighbors=1), ], ) -def test_sample(chroma, conditioner): - chroma.sample(steps=3, conditioner=conditioner, design_method=None) +@pytest.mark.parametrize( + "conditioner", + [ + conditioners.Identity(), + conditioners.SymmetryConditioner(G="C_3", num_chain_neighbors=1), + ], +) +@pytest.mark.parametrize("batch_size", [1, 2, 4]) +def test_sample(chroma, conditioner, batch_size): + # Generate a batch of proteins with the specified batch size + proteins = [Protein.from_CIF(PROTEIN_SAMPLE) for _ in range(batch_size)] + # Stack proteins into a batch + protein_batch = Protein.stack(proteins) + # Sample with the specified conditioner and batch of proteins + chroma.sample(steps=3, conditioner=conditioner, protein_batch=protein_batch, design_method=None) @pytest.mark.parametrize(