From ea6aef111c568d2b3a8e60f2807478467928b86f Mon Sep 17 00:00:00 2001 From: sweagent Date: Thu, 11 Apr 2024 03:20:38 +0000 Subject: [PATCH] Fix: Batched conditional generation Closes #25 --- tests/models/test_chroma.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/tests/models/test_chroma.py b/tests/models/test_chroma.py index c7f0409..2864e13 100644 --- a/tests/models/test_chroma.py +++ b/tests/models/test_chroma.py @@ -63,8 +63,21 @@ def test_chroma(chroma): conditioners.SymmetryConditioner(G="C_3", num_chain_neighbors=1), ], ) -def test_sample(chroma, conditioner): - chroma.sample(steps=3, conditioner=conditioner, design_method=None) +@pytest.mark.parametrize( + "conditioner", + [ + conditioners.Identity(), + conditioners.SymmetryConditioner(G="C_3", num_chain_neighbors=1), + ], +) +@pytest.mark.parametrize("batch_size", [1, 2, 4]) +def test_sample(chroma, conditioner, batch_size): + # Generate a batch of proteins with the specified batch size + proteins = [Protein.from_CIF(PROTEIN_SAMPLE) for _ in range(batch_size)] + # Stack proteins into a batch + protein_batch = Protein.stack(proteins) + # Sample with the specified conditioner and batch of proteins + chroma.sample(steps=3, conditioner=conditioner, protein_batch=protein_batch, design_method=None) @pytest.mark.parametrize(