diff --git a/chroma/models/chroma.py b/chroma/models/chroma.py index 00d4802..7c9e3fd 100644 --- a/chroma/models/chroma.py +++ b/chroma/models/chroma.py @@ -109,6 +109,7 @@ def sample( sde_func: Literal["langevin", "reverse_sde", "ode"] = "reverse_sde", trajectory_length: int = 200, full_output: bool = False, + batch_size: int = 128, # Sidechain Args design_ban_S: Optional[List[str]] = None, design_method: Literal["potts", "autoregressive"] = "potts", @@ -163,6 +164,7 @@ def sample( Default is (1.0, 0.001). trajectory_length (int, optional): The number of sampled steps in the trajectory output. Maximum is `steps`. Default 200. + batch_size (int, optional): The batch size for sampling. Default 128. **kwargs: Additional keyword arguments for the integration function. Sequence and sidechain sampling: @@ -309,6 +311,7 @@ def _sample( sde_func: Literal["langevin", "reverse_sde", "ode"] = "reverse_sde", trajectory_length: int = 200, full_output: bool = False, + batch_size: int = 128, **kwargs, ) -> Union[ Tuple[List[Protein], List[Protein]], @@ -344,6 +347,7 @@ def _sample( Default is (1.0, 0.001). trajectory_length (int, optional): The number of sampled steps in the trajectory output. Maximum is `steps`. Default 200. + batch_size (int, optional): The batch size for sampling. Default 128. **kwargs: Additional keyword arguments for the integration function. Returns: @@ -355,23 +359,43 @@ def _sample( if protein_init is not None: X_unc, C_unc, S_unc = protein_init.to_XCS() + X_unc = X_unc.repeat(samples, 1, 1, 1) + C_unc = C_unc.repeat(samples, 1) + S_unc = S_unc.repeat(samples, 1) else: X_unc, C_unc, S_unc = self._init_backbones(samples, chain_lengths) - outs = self.backbone_network.sample_sde( - C_unc, - X_init=X_unc, - conditioner=conditioner, - tspan=tspan, - langevin_isothermal=langevin_isothermal, - integrate_func=integrate_func, - sde_func=sde_func, - langevin_factor=langevin_factor, - inverse_temperature=inverse_temperature, - N=steps, - initialize_noise=initialize_noise, - **kwargs, - ) + num_batches = X_unc.shape[0] // batch_size + if X_unc.shape[0] % batch_size != 0: + num_batches += 1 + + outs = { + "C": torch.tensor([], device=X_unc.device), + "X_sample": torch.tensor([], device=X_unc.device), + "X_trajectory": [torch.tensor([], device=X_unc.device) for i in range(steps)], + "Xhat_trajectory": [torch.tensor([], device=X_unc.device) for i in range(steps)], + "Xunc_trajectory": [torch.tensor([], device=X_unc.device) for i in range(steps)], + } + for b in range(num_batches): + outs_ = self.backbone_network.sample_sde( + C_unc[b * batch_size : (b + 1) * batch_size], + X_init=X_unc[b * batch_size : (b + 1) * batch_size], + conditioner=conditioner, + tspan=tspan, + langevin_isothermal=langevin_isothermal, + integrate_func=integrate_func, + sde_func=sde_func, + langevin_factor=langevin_factor, + inverse_temperature=inverse_temperature, + N=steps, + initialize_noise=initialize_noise, + **kwargs, + ) + outs["C"] = torch.cat([outs["C"], outs_["C"]], dim=0) + outs["X_sample"] = torch.cat([outs["X_sample"], outs_["X_sample"]], dim=0) + for key in ['X_trajectory', 'Xhat_trajectory', 'Xunc_trajectory']: + for i in range(steps): + outs[key][i] = torch.cat([outs[key][i], outs_[key][i]], dim=0) if S_unc.shape != outs["C"].shape: S = torch.zeros_like(outs["C"]).long()