diff --git a/chroma/data/system.py b/chroma/data/system.py index 4d42c5e..2a33854 100644 --- a/chroma/data/system.py +++ b/chroma/data/system.py @@ -20,7 +20,7 @@ import warnings from dataclasses import dataclass from functools import partial -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Optional import numpy as np import torch diff --git a/chroma/data/xcs.py b/chroma/data/xcs.py index 3883376..7ce063d 100644 --- a/chroma/data/xcs.py +++ b/chroma/data/xcs.py @@ -28,7 +28,7 @@ `C` (LongTensor), the chain map encoding per-residue chain assignments with shape `(num_batch, num_residues)`.The chain map codes positions as `0` - when masked, poitive integers for chain indices, and negative integers + when masked, positive integers for chain indices, and negative integers to represent missing residues (of the corresponding positive integers). `S` (LongTensor), the sequence of the protein as alphabet indices with diff --git a/chroma/layers/attention.py b/chroma/layers/attention.py index d673c38..12b0a8d 100644 --- a/chroma/layers/attention.py +++ b/chroma/layers/attention.py @@ -64,11 +64,11 @@ class MultiHeadAttention(nn.Module): for details and intuition. 
Args: - n_head (int): number of attention heads - d_k (int): dimension of the keys and queries in each attention head - d_v (int): dimension of the values in each attention head - d_model (int): input and output dimension for the layer - dropout (float): dropout rate, default is 0.1 + n_head (int): number of attention heads + d_k (int): dimension of the keys and queries in each attention head + d_v (int): dimension of the values in each attention head + d_model (int): input and output dimension for the layer + dropout (float): dropout rate, default is 0.1 Inputs: Q (torch.tensor): query tensor of shape ```(batch_size, sequence_length_q, d_model)``` diff --git a/chroma/layers/structure/protein_graph.py b/chroma/layers/structure/protein_graph.py index f2ff0a2..ba06344 100644 --- a/chroma/layers/structure/protein_graph.py +++ b/chroma/layers/structure/protein_graph.py @@ -101,7 +101,7 @@ class ProteinFeatureGraph(nn.Module): for the the third dimension are PDB order (`[N, CA, C, O]`). C (LongTensor, optional): Chain map with shape `(num_batch, num_residues)`. The chain map codes positions as `0` - when masked, poitive integers for chain indices, and negative + when masked, positive integers for chain indices, and negative integers to represent missing residues of the corresponding positive integers. custom_D (Tensor, optional): Pre-computed custom distance map diff --git a/chroma/models/graph_design.py b/chroma/models/graph_design.py index c74e93f..9b50e80 100644 --- a/chroma/models/graph_design.py +++ b/chroma/models/graph_design.py @@ -1954,8 +1954,12 @@ def sample( smoothing values less than 1.0 are recommended. top_p (float, optional): Top-p cutoff for Nucleus Sampling, see Holtzman et al ICLR 2020. - ban_S (tuple, optional): An optional set of token indices from - `chroma.constants.AA20` to ban during sampling. 
+ mask_S (torch.Tensor, optional): Binary tensor mask indicating + masked/banned tokens during sampling at each residue with shape + `(num_batch, num_residues, num_alphabet)`. + bias (torch.Tensor, optional): Bias for each token at + each residue added to log probabilities with shape + `(num_batch, num_residues, num_alphabet)`. Returns: S_sample (torch.LongTensor): Sampled sequence of shape `(num_batch,