diff --git a/docs/si_guide.md b/docs/si_guide.md new file mode 100644 index 00000000..9152edfd --- /dev/null +++ b/docs/si_guide.md @@ -0,0 +1,332 @@ +# RFdiffusion3 SI算法实现指南 + +本文档提供了论文补充材料(SI)第2节 "RFdiffusion3 Architecture and Inference" 中所有算法的代码实现位置映射。 + +## 目录结构 + +主要代码位于 `./models/rfd3/src/rfd3/model/` 目录下: +- `RFD3.py` - 主模型类 +- `RFD3_diffusion_module.py` - 扩散模块 +- `inference_sampler.py` - 推理采样器 +- `layers/` - 各种网络层实现 + - `encoders.py` - Token和Atom初始化器 + - `blocks.py` - Transformer块和其他构建模块 + - `attention.py` - 注意力机制 + - `pairformer_layers.py` - Pairformer层 + - `layer_utils.py` - 基础层工具 + +--- + +## 算法映射 + +### Section 2.1: Main Inference Loop + +#### **Algorithm 1: Inference diffusion loop** (SI p.56) + +**实现位置:** +- 文件: `./models/rfd3/src/rfd3/model/inference_sampler.py` +- 类: `ConditionalDiffusionSampler` +- 方法: `sample_diffusion_like_af3()` +- 关键代码段: 行 ~200-300 + +**算法步骤说明:** +1. **Token and atom embedding (步骤1-2)**: 在 `RFD3.py` 的 `forward()` 方法中调用 `token_initializer` +2. **Initialize structure (步骤3-11)**: 在 `inference_sampler.py` 的主循环中 +3. **Modulation of ODE/SDE (步骤4-5)**: 噪声调制参数计算 +4. **Noise injection (步骤6-7)**: 向扩散坐标添加噪声 +5. **DiffusionModule (步骤8)**: 调用 `RFD3DiffusionModule` +6. **Coordinate update (步骤9-10)**: 更新原子坐标 + +--- + +### Section 2.2: Inference Loop for Symmetric Design + +#### **Algorithm 2: Symmetric inference diffusion loop** (SI p.57) + +**实现位置:** +- 文件: `./models/rfd3/src/rfd3/inference/symmetry/symmetry_utils.py` +- 函数: `apply_symmetry_to_xyz_atomwise()` +- 辅助文件: `./models/rfd3/src/rfd3/model/inference_sampler.py` +- 配置: `SampleDiffusionConfig.kind = "symmetry"` + +**关键特性:** +- 对称性应用 (步骤9-10): 使用旋转矩阵 `R={Rn}` 对非对称单元进行对称化 +- 可配置的对称停止比例 `σsym` (默认0.9) + +--- + +### Section 2.3: Token Initializers + +#### **Algorithm 3: Token initializer** (SI p.59) + +**实现位置:** +- 文件: `./models/rfd3/src/rfd3/model/layers/encoders.py` +- 类: `TokenInitializer` +- 方法: `forward()` + +**关键组件:** +- **步骤1-4**: `OneDFeatureEmbedder` 用于1D特征嵌入 +- **步骤5**: `Transition` 层 +- **步骤6-8**: 配对特征初始化 (`to_z_init_i`, `to_z_init_j`, `relative_position_encoding`) +- **步骤9**: `PositionPairDistEmbedder` 用于距离编码 +- **步骤10-12**: `PairformerBlock` 在 `transformer_stack` 中 +- **步骤13-19**: 配对特征后处理 + +#### **Algorithm 4: Atom initializer** (SI p.60) + +**实现位置:** +- 文件: `./models/rfd3/src/rfd3/model/layers/encoders.py` +- 类: `TokenInitializer` (包含atom初始化逻辑) +- 方法: `forward()` 的后半部分 + +**关键组件:** +- **步骤1**: `atom_1d_embedder_2` - Atom级1D特征嵌入 +- **步骤2**: 从token特征投影: `process_s_trunk` +- **步骤3**: `SinusoidalDistEmbed` - Motif位置嵌入 +- **步骤4**: `PositionPairDistEmbedder` - 参考位置嵌入 +- **步骤5-7**: Atom配对特征的MLP处理: `process_single_l`, `process_single_m`, `process_z`, `pair_mlp` + +--- + +### Section 2.4: DiffusionModule + +#### **Algorithm 5: Diffusion forward pass with recycling** (SI p.61) + +**实现位置:** +- 文件: `./models/rfd3/src/rfd3/model/RFD3_diffusion_module.py` +- 类: `RFD3DiffusionModule` +- 方法: `forward()` + +**关键步骤:** +- **步骤1**: 坐标缩放: `process_r` +- **步骤2-3**: `Downcast` 池化到序列级别 +- **步骤4-7**: 时间/批次嵌入: `fourier_embedding`, `process_n` +- **步骤8**: Local atom transformer: `encoder` (LocalAtomTransformer) +- **步骤9**: Downcast: `downcast_q` +- **步骤10-17**: 循环处理 (nrecycle=2): + - **步骤12**: `DiffusionTokenEncoder` - 嵌入噪声尺度和循环distogram + - **步骤13**: `transformer` (LocalTokenTransformer) - Token级稀疏注意力 + - **步骤14**: `decoder` (CompactStreamingDecoder) - 上投影并解码为结构 + - **步骤15-16**: 坐标更新: `to_r_update` + +--- + +### Section 2.4.1 & 2.4.2: Transformers and Attention + +#### **Algorithm 6: Local token transformer** (SI p.64) + +**实现位置:** +- 文件: `./models/rfd3/src/rfd3/model/layers/blocks.py` +- 类: `LocalTokenTransformer` +- 方法: `forward()` + +**关键特性:** +- **步骤1**: `create_attention_indices()` - 创建SL2稀疏注意力索引 +- **步骤2-8**: 循环通过多个transformer块 +- **步骤4**: `Upcast` - 如果提供了cskip +- **步骤6**: `SparseAttentionPairBias` - 带配对偏置的稀疏注意力 +- **步骤7**: `ConditionedTransitionBlock` - 条件化的transition块 + +#### **Algorithm 7: TransformerBlock** (SI p.64) + +**实现位置:** +- 文件: `./models/rfd3/src/rfd3/model/layers/pairformer_layers.py` +- 类: `PairformerBlock` +- 方法: `forward()` + +**特性:** +- 全注意力版本的transformer +- 用于处理配对特征到单轨迹 + +#### **Algorithm 8: Sparse attention with pair bias** (SI p.65) + +**实现位置:** +- 文件: `./models/rfd3/src/rfd3/model/layers/attention.py` +- 类: `LocalAttentionPairBias` +- 方法: `forward()` + +**关键步骤:** +- **步骤1-5**: AdaLN或RMSNorm标准化 +- **步骤6-9**: Q, K, V, gate投影 +- **步骤10**: `sparse_pairbias_attention()` - 带配对偏置的稀疏注意力 +- **步骤12-14**: 可选的门控与条件信号 + +--- + +### Section 2.5: Cross Attention Pooling + +#### **Algorithm 9: Downcast** (SI p.66) + +**实现位置:** +- 文件: `./models/rfd3/src/rfd3/model/layers/blocks.py` +- 类: `Downcast` +- 方法: `forward()` + +**关键操作:** +- **步骤1**: `group_atoms()` - 按token ID分组原子 +- **步骤2**: `GatedCrossAttention` - Q=ai, KV=qia +- **步骤4**: 添加token特征 (si) + +#### **Algorithm 10: Upcast** (SI p.66) + +**实现位置:** +- 文件: `./models/rfd3/src/rfd3/model/layers/blocks.py` +- 类: `Upcast` (在 `CompactStreamingDecoder` 中使用) +- 相关代码在多个transformer块中 + +**关键操作:** +- **步骤1**: `reshape()` - 分割token并按token ID分组原子 +- **步骤3**: `GatedCrossAttention` - Q=qia, KV=a_split + +#### **Algorithm 11: Gated cross attention** (SI p.67) + +**实现位置:** +- 文件: `./models/rfd3/src/rfd3/model/layers/attention.py` +- 类: `GatedCrossAttention` +- 方法: `forward()` + +**关键步骤:** +- **步骤1-2**: RMSNorm标准化 +- **步骤3-6**: Q, K, V, gate投影 +- **步骤7-13**: Scaled dot-product attention with valid masking +- **步骤14**: 合并头部并投影 + +--- + +### Section 2.6: Embedders + +#### **Algorithm 12: Diffusion token encoder** (SI p.68) + +**实现位置:** +- 文件: `./models/rfd3/src/rfd3/model/layers/encoders.py` +- 类: `DiffusionTokenEncoder` +- 方法: `forward()` + +**关键步骤:** +- **步骤1-3**: Transition层处理 +- **步骤4**: `bucketize_scaled_distogram()` - 将noisy和self坐标的distogram离散化 +- **步骤9-11**: `PairformerBlock` 循环 + +#### **Algorithm 13: Sinusoidal Distance Embedding** (SI p.69) + +**实现位置:** +- 文件: `./models/rfd3/src/rfd3/model/layers/blocks.py` +- 类: `SinusoidalDistEmbed` +- 方法: `forward()` + +**关键步骤:** +- **步骤1**: 计算配对距离 +- **步骤2-4**: 正弦嵌入: `ω_k`, `θ_lmk`, `e_lm = [sin(θ) || cos(θ)]` +- **步骤5-7**: 应用并嵌入mask + +#### **Algorithm 14: Position Pair Distance Embedding** (SI p.69) + +**实现位置:** +- 文件: `./models/rfd3/src/rfd3/model/layers/blocks.py` +- 类: `PositionPairDistEmbedder` +- 方法: `forward()` + +**关键步骤:** +- **步骤1**: 计算配对距离 +- **步骤2**: 编码逆配对距离: `1/(1 + ||d_lm||^2)` +- **步骤3-4**: 嵌入mask + +#### **Algorithm 15: One-dimension Feature Embedder** (SI p.69) + +**实现位置:** +- 文件: `./models/rfd3/src/rfd3/model/layers/blocks.py` +- 类: `OneDFeatureEmbedder` +- 方法: `forward()` + +**特性:** +- 对token级和atom级特征使用相同的实现 +- 对每个1D特征嵌入并求和结果 + +--- + +### Section 2.7: Classifier Free Guidance + +#### **Algorithm 16: FourierEmbedding** (SI p.70) + +**实现位置:** +- 文件: `./src/foundry/model/layers/blocks.py` (foundry包) +- 类: `FourierEmbedding` + +**用法:** 在 `RFD3_diffusion_module.py` 中的时间嵌入 + +#### **Algorithm 17: Processing of noise conditioning features** (SI p.70) + +**实现位置:** +- 文件: `./models/rfd3/src/rfd3/model/RFD3_diffusion_module.py` +- 在 `forward()` 方法中的时间处理部分 + +**关键步骤:** +- **步骤1**: `fourier_embedding` - Fourier特征 +- **步骤2**: 乘以时间mask `(tl > 0)` + +**分类器自由引导实现:** +- 文件: `./models/rfd3/src/rfd3/model/cfg_utils.py` +- 函数: `strip_f()`, `strip_X()` +- 配置: `use_classifier_free_guidance=True`, `cfg_scale` + +--- + +## 其他重要文件 + +### 训练相关 +- `./models/rfd3/src/rfd3/trainer/` - 训练器实现 +- `./models/rfd3/src/rfd3/transforms/` - 数据转换和条件化 + +### 推理相关 +- `./models/rfd3/src/rfd3/run_inference.py` - 推理脚本 +- `./models/rfd3/src/rfd3/utils/inference.py` - 推理工具函数 + +### 对称性支持 +- `./models/rfd3/src/rfd3/inference/symmetry/` - 对称性推理 +- `./models/rfd3/src/rfd3/transforms/symmetry.py` - 对称性转换 + +### 配置 +- `./models/rfd3/configs/` - Hydra配置文件 + +--- + +## 注意事项 + +1. **坐标表示**: 所有坐标使用atom级别表示 `[B, L, 3]`, 其中L是原子数 +2. **Token vs Atom**: + - Token (I): 残基或小分子作为单个token + - Atom (L): 每个原子独立表示 +3. **稀疏注意力 (SL2)**: + - 序列局部: `n_attn_seq_neighbours = 32` + - 结构局部: `n_attn_keys = 128` +4. **循环 (Recycling)**: 默认 `n_recycle = 2` +5. **低内存模式**: 设置 `RFD3_LOW_MEMORY_MODE=1` 使用分块的配对特征处理 + +--- + +## 快速索引 + +| 算法 | 主要实现类/函数 | 文件路径 | +|------|----------------|----------| +| Algorithm 1 | `ConditionalDiffusionSampler.sample_diffusion_like_af3()` | `inference_sampler.py` | +| Algorithm 2 | `apply_symmetry_to_xyz_atomwise()` | `inference/symmetry/symmetry_utils.py` | +| Algorithm 3 | `TokenInitializer` | `layers/encoders.py` | +| Algorithm 4 | `TokenInitializer` (atom部分) | `layers/encoders.py` | +| Algorithm 5 | `RFD3DiffusionModule.forward()` | `RFD3_diffusion_module.py` | +| Algorithm 6 | `LocalTokenTransformer` | `layers/blocks.py` | +| Algorithm 7 | `PairformerBlock` | `layers/pairformer_layers.py` | +| Algorithm 8 | `LocalAttentionPairBias` | `layers/attention.py` | +| Algorithm 9 | `Downcast` | `layers/blocks.py` | +| Algorithm 10 | `Upcast` in `CompactStreamingDecoder` | `layers/blocks.py` | +| Algorithm 11 | `GatedCrossAttention` | `layers/attention.py` | +| Algorithm 12 | `DiffusionTokenEncoder` | `layers/encoders.py` | +| Algorithm 13 | `SinusoidalDistEmbed` | `layers/blocks.py` | +| Algorithm 14 | `PositionPairDistEmbedder` | `layers/blocks.py` | +| Algorithm 15 | `OneDFeatureEmbedder` | `layers/blocks.py` | +| Algorithm 16 | `FourierEmbedding` | `foundry/model/layers/blocks.py` | +| Algorithm 17 | Time processing in `RFD3DiffusionModule` | `RFD3_diffusion_module.py` | + +--- + +**生成日期**: 2025-12-23 +**基于论文**: RFdiffusion3 Supplementary Information, Section 2 diff --git a/models/rfd3/src/rfd3/model/RFD3_diffusion_module.py b/models/rfd3/src/rfd3/model/RFD3_diffusion_module.py index 3730695e..0ee0ce69 100644 --- a/models/rfd3/src/rfd3/model/RFD3_diffusion_module.py +++ b/models/rfd3/src/rfd3/model/RFD3_diffusion_module.py @@ -31,6 +31,38 @@ class RFD3DiffusionModule(nn.Module): + """ + RFDiffusion3 扩散模块 (Algorithm 5: Diffusion forward pass with recycling) + RFDiffusion3 Diffusion Module + + 这是RFD3的核心扩散模块,实现了SI中的Algorithm 5。 + 采用UNet风格的架构,在token和atom两个层级处理特征。 + This is the core diffusion module of RFD3, implementing Algorithm 5 from SI. + Uses a UNet-style architecture for processing features across tokens and atoms. + + 主要组件 / Main Components: + - Encoder: 局部atom transformer用于atom级特征编码 + - Encoder: Local atom transformer for atom-level feature encoding + - Diffusion Token Encoder: 嵌入噪声尺度和循环distogram + - Diffusion Token Encoder: Embeds noise scale and recycled distogram + - Diffusion Transformer: 在token级进行稀疏注意力 + - Diffusion Transformer: Sparse attention at token level + - Decoder: 解码为结构更新 + - Decoder: Decodes to structure updates + + 参数 / Parameters: + c_atom: Atom级别特征维度 / Atom-level feature dimension + c_atompair: Atom配对特征维度 / Atom pairwise feature dimension + c_token: Token级别特征维度 / Token-level feature dimension + c_s: 单轨迹特征维度 / Single track feature dimension + c_z: 配对轨迹特征维度 / Pair track feature dimension + c_t_embed: 时间嵌入维度 / Time embedding dimension + sigma_data: EDM数据方差 / EDM data variance (default: 16) + f_pred: 预测类型 ("edm", "unconditioned", "noise_pred") / Prediction type + n_attn_seq_neighbours: 序列局部注意力邻居数 (默认32) / Sequence-local attention neighbors + n_attn_keys: 结构局部注意力键数 (默认128) / Structure-local attention keys + n_recycle: 循环次数 (默认2) / Number of recycling iterations + """ def __init__( self, *, @@ -66,41 +98,57 @@ def __init__( self.n_attn_keys = n_attn_keys self.use_local_token_attention = use_local_token_attention - # Auxiliary - self.process_r = linearNoBias(3, c_atom) + # ===== 辅助模块 / Auxiliary Modules ===== + # Algorithm 5 步骤1: 坐标缩放 / Step 1: Scale positions + self.process_r = linearNoBias(3, c_atom) # [3] -> [c_atom] + # Algorithm 5 步骤15: 坐标更新投影 / Step 15: Coordinate update projection self.to_r_update = nn.Sequential(RMSNorm((c_atom,)), linearNoBias(c_atom, 3)) + # 序列预测头 / Sequence prediction head self.sequence_head = LinearSequenceHead(c_token=c_token) - self.n_recycle = n_recycle - self.n_bins = 65 + # 循环和distogram参数 / Recycling and distogram parameters + self.n_recycle = n_recycle # 默认2次循环 / Default 2 recycling iterations + self.n_bins = 65 # Distogram的离散化bin数 / Number of distogram bins self.bucketize_fn = functools.partial( bucketize_scaled_distogram, - min_dist=1, - max_dist=30, + min_dist=1, # 最小距离1Å / Minimum distance 1Å + max_dist=30, # 最大距离30Å / Maximum distance 30Å sigma_data=1, n_bins=self.n_bins, ) - # Time processing + # ===== 时间处理 (Algorithm 16 & 17) / Time Processing ===== + # Algorithm 16: FourierEmbedding - 用于时间编码 + # Algorithm 16: FourierEmbedding - for time encoding self.fourier_embedding = nn.ModuleList( - [FourierEmbedding(c_t_embed), FourierEmbedding(c_t_embed)] + [FourierEmbedding(c_t_embed), FourierEmbedding(c_t_embed)] # [0]: atom级, [1]: token级 ) + # Algorithm 17: Processing of noise conditioning features + # 步骤1: 处理Fourier特征并投影 / Step 1: Process Fourier features and project self.process_n = nn.ModuleList( [ - nn.Sequential(RMSNorm(c_t_embed), linearNoBias(c_t_embed, c_atom)), - nn.Sequential(RMSNorm(c_t_embed), linearNoBias(c_t_embed, c_s)), + nn.Sequential(RMSNorm(c_t_embed), linearNoBias(c_t_embed, c_atom)), # [0]: 用于atom + nn.Sequential(RMSNorm(c_t_embed), linearNoBias(c_t_embed, c_s)), # [1]: 用于token ] ) + # ===== Algorithm 9: Downcast (池化操作) / Pooling Operations ===== + # 将atom特征池化到token特征 / Pool atom features to token features self.downcast_c = Downcast(c_atom=c_atom, c_token=c_s, c_s=None, **downcast) self.downcast_q = Downcast(c_atom=c_atom, c_token=c_token, c_s=c_s, **downcast) - self.process_a = LinearEmbedWithPool(c_token) - self.process_c = nn.Sequential(RMSNorm(c_atom), linearNoBias(c_atom, c_atom)) + self.process_a = LinearEmbedWithPool(c_token) # Algorithm 5 步骤2 + self.process_c = nn.Sequential(RMSNorm(c_atom), linearNoBias(c_atom, c_atom)) # 步骤7 - # UNet-like architecture for processing across tokens and atoms + # ===== UNet风格架构 / UNet-style Architecture ===== + # 在token和atom两个层级处理特征 / Process features across tokens and atoms + + # Algorithm 5 步骤8: Local-atom transformer (编码器) + # Algorithm 5 Step 8: Local-atom transformer (encoder) self.encoder = LocalAtomTransformer( c_atom=c_atom, c_s=c_atom, c_atompair=c_atompair, **atom_attention_encoder ) + # Algorithm 5 步骤12: Diffusion token encoder + # 嵌入噪声尺度和循环distogram / Embed noise scale and recycled distogram self.diffusion_token_encoder = DiffusionTokenEncoder( c_s=c_s, c_token=c_token, @@ -109,6 +157,8 @@ def __init__( **diffusion_token_encoder, ) + # Algorithm 5 步骤13: Sparse attention at token level + # Token级稀疏注意力 / Token-level sparse attention self.diffusion_transformer = LocalTokenTransformer( c_token=c_token, c_tokenpair=c_z, @@ -116,6 +166,8 @@ def __init__( **diffusion_transformer, ) + # Algorithm 5 步骤14: Up-projection and decode to structure + # 上投影并解码为结构 / Up-projection and decode to structure self.decoder = CompactStreamingDecoder( c_atom=c_atom, c_atompair=c_atompair, @@ -126,16 +178,34 @@ def __init__( ) def scale_positions_in(self, X_noisy_L: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + """ + 输入坐标缩放 (EDM预条件化) / Input coordinate scaling (EDM preconditioning) + + 根据EDM框架对噪声坐标进行缩放,使其适合神经网络输入。 + Scale noisy coordinates according to EDM framework for neural network input. + + 参数 / Args: + X_noisy_L: [B, L, 3] 噪声坐标 / Noisy coordinates + t: [B] 或 [B, L] 噪声水平 / Noise level + + 返回 / Returns: + R_noisy_L: [B, L, 3] 缩放后的坐标 / Scaled coordinates + """ + # 扩展时间维度以匹配坐标形状 / Expand time dimension to match coordinate shape if t.ndim == 1: - t = t[..., None, None] # [B, (n_atoms), (3)] + t = t[..., None, None] # [B] -> [B, 1, 1] elif t.ndim == 2: - t = t[..., None] # [B, n_atoms, (3)] + t = t[..., None] # [B, L] -> [B, L, 1] + # EDM预条件化公式: c_in(σ) * X / Preconditioning formula if self.f_pred == "edm": + # c_in(σ) = 1 / sqrt(σ^2 + σ_data^2) R_noisy_L = X_noisy_L / torch.sqrt(t**2 + self.sigma_data**2) elif self.f_pred == "unconditioned": + # 无条件:清零输入 / Unconditional: zero input R_noisy_L = torch.zeros_like(X_noisy_L) elif self.f_pred == "noise_pred": + # 噪声预测:直接使用噪声坐标 / Noise prediction: use noisy coordinates directly R_noisy_L = X_noisy_L else: raise Exception(f"{self.f_pred=} unrecognized") @@ -144,31 +214,69 @@ def scale_positions_in(self, X_noisy_L: torch.Tensor, t: torch.Tensor) -> torch. def scale_positions_out( self, R_update_L: torch.Tensor, X_noisy_L: torch.Tensor, t: torch.Tensor ) -> torch.Tensor: + """ + 输出坐标缩放 (EDM去预条件化) / Output coordinate scaling (EDM de-preconditioning) + + 将神经网络输出转换为去噪后的坐标,遵循EDM框架。 + Convert neural network output to denoised coordinates following EDM framework. + + 参数 / Args: + R_update_L: [B, L, 3] 网络预测的更新 / Network predicted update + X_noisy_L: [B, L, 3] 输入的噪声坐标 / Input noisy coordinates + t: [B] 或 [B, L] 噪声水平 / Noise level + + 返回 / Returns: + X_out_L: [B, L, 3] 去噪后的坐标 / Denoised coordinates + """ + # 扩展时间维度 / Expand time dimension if t.ndim == 1: - t = t[..., None, None] + t = t[..., None, None] # [B] -> [B, 1, 1] elif t.ndim == 2: - t = t[..., None] # [B, n_atoms, (3)] + t = t[..., None] # [B, L] -> [B, L, 1] + # EDM去预条件化公式 / EDM de-preconditioning formula if self.f_pred == "edm": + # c_skip(σ) * X + c_out(σ) * F_θ + # c_skip = σ_data^2 / (σ^2 + σ_data^2) + # c_out = σ * σ_data / sqrt(σ^2 + σ_data^2) X_out_L = (self.sigma_data**2 / (self.sigma_data**2 + t**2)) * X_noisy_L + ( self.sigma_data * t / (self.sigma_data**2 + t**2) ** 0.5 ) * R_update_L elif self.f_pred == "unconditioned": + # 无条件:直接使用预测 / Unconditional: use prediction directly X_out_L = R_update_L elif self.f_pred == "noise_pred": + # 噪声预测: X - ε / Noise prediction: X - ε X_out_L = X_noisy_L + R_update_L else: raise Exception(f"{self.f_pred=} unrecognized") return X_out_L def process_time_(self, t_L: torch.Tensor, i: int) -> torch.Tensor: + """ + 时间条件化特征处理 (Algorithm 17) / Time conditioning feature processing + + 将噪声水平 t 编码为特征,用于条件化扩散过程。 + Encode noise level t into features for conditioning the diffusion process. + + 参数 / Args: + t_L: [B, L] 每个atom/token的噪声水平 / Noise level per atom/token + i: 0=atom级, 1=token级 / 0=atom-level, 1=token-level + + 返回 / Returns: + C_L: [B, L, c_atom] 或 [B, L, c_s] 时间条件化特征 / Time conditioning features + """ + # Algorithm 17 步骤1: 对数时间编码 + Fourier特征 + # Log-time encoding + Fourier features + # log(t/σ_data) / 4 转换为频率范围 C_L = self.process_n[i]( self.fourier_embedding[i]( 1 / 4 * torch.log(torch.clamp(t_L, min=1e-20) / self.sigma_data) ) ) - # Mask out zero-time features; - C_L = C_L * (t_L > 0).float()[..., None] # [B, L, C_atom] + # Algorithm 17 步骤2: 屏蔽零时间(固定motif区域) + # Mask out zero-time features (fixed motif regions) + C_L = C_L * (t_L > 0).float()[..., None] # [B, L, c_atom/c_s] return C_L def forward( @@ -189,62 +297,89 @@ def forward( **kwargs: Any, ) -> Dict[str, torch.Tensor]: """ - Diffusion forward pass with recycling. + 扩散前向传播 (Algorithm 5: Diffusion forward pass with recycling) + Diffusion forward pass with recycling + + 给定噪声坐标和编码特征,计算去噪后的位置。 Computes denoised positions given encoded features and noisy coordinates. + + 参数 / Args: + X_noisy_L: [B, L, 3] 噪声坐标 / Noisy coordinates + t: [B] 噪声水平 / Noise level + f: 特征字典 / Feature dictionary + Q_L_init: [L, c_atom] 初始atom特征 / Initial atom features (from TokenInitializer) + C_L: [L, c_atom] 条件化atom特征 / Conditioned atom features + P_LL: [L, L, c_atompair] Atom配对特征 / Atom pairwise features + S_I: [I, c_s] Token单轨迹特征 / Token single features + Z_II: [I, I, c_z] Token配对特征 / Token pair features + n_recycle: 循环次数 / Number of recycling iterations + + 返回 / Returns: + outputs: 包含去噪坐标和序列预测的字典 / Dictionary with denoised coordinates and sequence predictions """ - # ... Collect inputs - tok_idx = f["atom_to_token_map"] - L = len(tok_idx) - I = tok_idx.max() + 1 # Number of tokens + # ===== 步骤1: 收集输入和创建注意力索引 / Step 1: Collect inputs and create attention indices ===== + tok_idx = f["atom_to_token_map"] # [L] atom到token的映射 / Atom to token mapping + L = len(tok_idx) # 原子总数 / Total number of atoms + I = tok_idx.max() + 1 # Token总数 / Total number of tokens + + # 创建SL2稀疏注意力索引 (序列局部 + 结构局部) + # Create SL2 sparse attention indices (sequence-local + structure-local) f["attn_indices"] = create_attention_indices( X_L=X_noisy_L, f=f, - n_attn_keys=self.n_attn_keys, - n_attn_seq_neighbours=self.n_attn_seq_neighbours, + n_attn_keys=self.n_attn_keys, # 结构局部键数 (默认128) + n_attn_seq_neighbours=self.n_attn_seq_neighbours, # 序列局部邻居 (默认32) ) - # ... Expand t tensors + # ===== 步骤2-3: 扩展时间张量并屏蔽固定区域 / Step 2-3: Expand time tensors and mask fixed regions ===== + # t_L: [B, L] 每个atom的噪声水平,motif区域为0 t_L = t.unsqueeze(-1).expand(-1, L) * ( ~f["is_motif_atom_with_fixed_coord"] ).float().unsqueeze(0) + # t_I: [B, I] 每个token的噪声水平,完全固定的token为0 t_I = t.unsqueeze(-1).expand(-1, I) * ( ~f["is_motif_token_with_fully_fixed_coord"] ).float().unsqueeze(0) - # ... Create scaled positions - R_L_uniform = self.scale_positions_in(X_noisy_L, t) - R_noisy_L = self.scale_positions_in(X_noisy_L, t_L) + # ===== 步骤4: 坐标缩放 (EDM预条件化) / Step 4: Scale positions (EDM preconditioning) ===== + R_L_uniform = self.scale_positions_in(X_noisy_L, t) # [B, L, 3] 均匀缩放用于distogram + R_noisy_L = self.scale_positions_in(X_noisy_L, t_L) # [B, L, 3] 每atom缩放用于特征 - # ... Pool initial representation to sequence level - A_I = self.process_a(R_noisy_L, tok_idx=tok_idx) - S_I = self.downcast_c(C_L, S_I, tok_idx=tok_idx) + # ===== 步骤5: 池化初始表示到token级 (Algorithm 9: Downcast) / Step 5: Pool initial representation to token level ===== + A_I = self.process_a(R_noisy_L, tok_idx=tok_idx) # [B, I, c_token] 从坐标池化的token特征 + S_I = self.downcast_c(C_L, S_I, tok_idx=tok_idx) # [I, c_s] 从atom特征池化的token特征 - # ... Add batch-wise features to inputs - Q_L = Q_L_init.unsqueeze(0) + self.process_r(R_noisy_L) - C_L = C_L.unsqueeze(0) + self.process_time_(t_L, i=0) - S_I = S_I.unsqueeze(0) + self.process_time_(t_I, i=1) - C_L = C_L + self.process_c(C_L) + # ===== 步骤6-7: 添加批次级特征 (时间条件化) / Step 6-7: Add batch-wise features (time conditioning) ===== + # Algorithm 5 步骤1: 坐标投影 + 初始化特征 + Q_L = Q_L_init.unsqueeze(0) + self.process_r(R_noisy_L) # [B, L, c_atom] + # Algorithm 17: 添加时间条件化特征 + C_L = C_L.unsqueeze(0) + self.process_time_(t_L, i=0) # [B, L, c_atom] atom级 + S_I = S_I.unsqueeze(0) + self.process_time_(t_I, i=1) # [B, I, c_s] token级 + C_L = C_L + self.process_c(C_L) # [B, L, c_atom] 额外的MLP处理 - # ... Run Local-Atom Self Attention and Pool + # ===== 步骤8: Local-Atom Self Attention (编码器) / Step 8: Local-Atom Self Attention (encoder) ===== + # Algorithm 5 步骤8: 局部atom transformer if chunked_pairwise_embedder is not None: - # Chunked mode: pass chunked embedder and feature dict + # 分块模式:传递分块嵌入器用于内存优化 / Chunked mode: pass chunked embedder for memory optimization Q_L = self.encoder( Q_L, C_L, P_LL=None, indices=f["attn_indices"], - f=f, # Pass feature dict for chunked computation + f=f, # 传递特征字典用于分块计算 chunked_pairwise_embedder=chunked_pairwise_embedder, initializer_outputs=initializer_outputs, ) else: - # Standard mode: use full P_LL + # 标准模式:使用完整的P_LL / Standard mode: use full P_LL Q_L = self.encoder(Q_L, C_L, P_LL, indices=f["attn_indices"]) - A_I = self.downcast_q(Q_L, A_I=A_I, S_I=S_I, tok_idx=tok_idx) - # Debug chunked parameters + # ===== 步骤9: 池化到token级准备transformer / Step 9: Pool to token level for transformer ===== + # Algorithm 9: Downcast - 将atom特征池化为token特征 + A_I = self.downcast_q(Q_L, A_I=A_I, S_I=S_I, tok_idx=tok_idx) # [B, I, c_token] - # ... Run forward with recycling + # ===== 步骤10-17: 循环处理 (Recycling Loop) / Step 10-17: Recycling loop ===== + # Algorithm 5 步骤10-17: 带distogram循环的迭代细化 recycled_features = self.forward_with_recycle( n_recycle, X_noisy_L=X_noisy_L, @@ -261,11 +396,11 @@ def forward( initializer_outputs=initializer_outputs, ) - # ... Collect outputs + # ===== 收集输出 / Collect outputs ===== outputs = { - "X_L": recycled_features["X_L"], # [B, L, 3] denoised positions - "sequence_indices_I": recycled_features["sequence_indices_I"], - "sequence_logits_I": recycled_features["sequence_logits_I"], + "X_L": recycled_features["X_L"], # [B, L, 3] 去噪后的坐标 / Denoised positions + "sequence_indices_I": recycled_features["sequence_indices_I"], # 序列索引 / Sequence indices + "sequence_logits_I": recycled_features["sequence_logits_I"], # 序列logits / Sequence logits } return outputs @@ -274,27 +409,46 @@ def forward_with_recycle( n_recycle: Optional[int], **kwargs: Any, ) -> Dict[str, torch.Tensor]: + """ + 循环前向传播包装器 (Algorithm 5 步骤10-17) / Recycling forward pass wrapper + + 迭代细化结构预测,使用前一次循环的distogram作为条件。 + Iteratively refines structure prediction using previous cycle's distogram as conditioning. + + 参数 / Args: + n_recycle: 循环次数,训练时必须提供,推理时使用self.n_recycle (默认2) + Number of recycling iterations, must be provided during training + + 返回 / Returns: + recycled_features: 最终循环的输出 / Output from final recycling iteration + """ + # 推理时使用默认循环次数,训练时必须明确指定 + # Use default n_recycle during inference, must be explicit during training if not self.training: n_recycle = self.n_recycle else: assert n_recycle is not None - recycled_features = {} + recycled_features = {} # 存储上一次循环的输出 / Store previous cycle outputs for i in range(n_recycle): with ExitStack() as stack: + # 只在最后一次循环保留梯度 / Only keep gradients in final cycle last = not (i < n_recycle - 1) if not last: + # 中间循环使用no_grad节省内存 / Use no_grad for intermediate cycles to save memory stack.enter_context(torch.no_grad()) - # Clear the autocast cache if gradients are enabled (workaround for autocast bug) + # 清除autocast缓存(PyTorch bug的解决方法) + # Clear autocast cache (workaround for PyTorch bug) # See: https://github.com/pytorch/pytorch/issues/65766 if torch.is_grad_enabled(): torch.clear_autocast_cache() - # Run forward + # 运行单次循环迭代 / Run single recycling iteration + # 将前一次的distogram (D_II_self) 和坐标 (X_L) 作为条件 recycled_features = self.process_( - D_II_self=recycled_features.get("D_II_self"), - X_L_self=recycled_features.get("X_L"), + D_II_self=recycled_features.get("D_II_self"), # [B, I, I, n_bins] 或 None + X_L_self=recycled_features.get("X_L"), # [B, L, 3] 或 None **kwargs, ) @@ -319,72 +473,109 @@ def process_( initializer_outputs: Optional[Dict[str, Any]] = None, **_: Any, ) -> Dict[str, torch.Tensor]: - # ... Embed token level features with atom level encodings + """ + 单次循环迭代处理 (Algorithm 5 步骤12-16) / Single recycling iteration processing + + 执行一次完整的transformer前向传播:编码器 -> transformer -> 解码器 -> 更新。 + Performs one complete transformer forward pass: encoder -> transformer -> decoder -> update. + + 参数 / Args: + D_II_self: [B, I, I, n_bins] 或 None - 前一次循环的distogram / Previous cycle's distogram + X_L_self: [B, L, 3] 或 None - 前一次循环的坐标 / Previous cycle's coordinates + R_L_uniform: [B, L, 3] 均匀缩放的坐标 / Uniformly scaled coordinates + X_noisy_L: [B, L, 3] 原始噪声坐标 / Original noisy coordinates + t_L: [B, L] 时间条件 / Time conditioning + f: 特征字典 / Feature dictionary + Q_L: [B, L, c_atom] Atom查询特征 / Atom query features + C_L: [B, L, c_atom] Atom条件特征 / Atom conditioning features + P_LL: [L, L, c_atompair] Atom配对特征 / Atom pairwise features + A_I: [B, I, c_token] Token特征 / Token features + S_I: [B, I, c_s] Token单轨迹特征 / Token single features + Z_II: [I, I, c_z] Token配对特征 / Token pair features + + 返回 / Returns: + 包含更新坐标、distogram和序列预测的字典 / Dictionary with updated coordinates, distogram, and sequence predictions + """ + # ===== 步骤12: DiffusionTokenEncoder - 嵌入噪声尺度和循环distogram ===== + # Step 12: DiffusionTokenEncoder - Embed noise scale and recycled distogram + # Algorithm 12: 将当前坐标的distogram和前一次循环的distogram嵌入到Z_II中 S_I, Z_II = self.diffusion_token_encoder( f=f, - R_L=R_L_uniform, - D_II_self=D_II_self, - S_init_I=S_I, - Z_init_II=Z_II, - C_L=C_L, - P_LL=P_LL, + R_L=R_L_uniform, # [B, L, 3] 用于计算当前distogram + D_II_self=D_II_self, # [B, I, I, n_bins] 前一次循环的self-conditioning distogram + S_init_I=S_I, # [B, I, c_s] 初始token特征 + Z_init_II=Z_II, # [I, I, c_z] 初始token配对特征 + C_L=C_L, # [B, L, c_atom] Atom条件特征 + P_LL=P_LL, # [L, L, c_atompair] Atom配对特征 ) - # ... Diffusion transformer + # ===== 步骤13: DiffusionTransformer - Token级稀疏注意力 ===== + # Step 13: DiffusionTransformer - Token-level sparse attention + # Algorithm 6: LocalTokenTransformer with SL2 sparse attention A_I = self.diffusion_transformer( - A_I, - S_I, - Z_II, + A_I, # [B, I, c_token] Token特征 + S_I, # [B, I, c_s] 单轨迹特征 + Z_II, # [B, I, I, c_z] 配对特征(用作注意力偏置) f=f, X_L=( + # 使用当前坐标或前一次循环的坐标(仅C-alpha)用于结构局部注意力 X_noisy_L[..., f["is_ca"], :] if X_L_self is None else X_L_self[..., f["is_ca"], :] ), - full=not (os.environ.get("RFD3_LOW_MEMORY_MODE", None) == "1"), + full=not (os.environ.get("RFD3_LOW_MEMORY_MODE", None) == "1"), # 低内存模式标志 ) - # ... Decoder readout - # Check if using chunked P_LL mode + # ===== 步骤14: Decoder - 上投影并解码为结构 ===== + # Step 14: Decoder - Up-projection and decode to structure + # CompactStreamingDecoder: Token -> Atom特征,包含Algorithm 10 (Upcast) if chunked_pairwise_embedder is not None: - # Chunked mode: pass embedder and no P_LL + # 分块模式:传递嵌入器,不使用预计算的P_LL / Chunked mode: pass embedder, no pre-computed P_LL A_I, Q_L, o = self.decoder( A_I, S_I, Z_II, Q_L, C_L, - P_LL=None, # Not used in chunked mode + P_LL=None, # 分块模式不使用 / Not used in chunked mode tok_idx=f["atom_to_token_map"], indices=f["attn_indices"], - f=f, # Pass f for chunked computation + f=f, # 传递特征字典用于按需计算 / Pass f for on-demand computation chunked_pairwise_embedder=chunked_pairwise_embedder, initializer_outputs=initializer_outputs, ) else: - # Original mode: use full P_LL + # 标准模式:使用完整的P_LL / Standard mode: use full P_LL A_I, Q_L, o = self.decoder( A_I, S_I, Z_II, Q_L, C_L, - P_LL=P_LL, + P_LL=P_LL, # [L, L, c_atompair] 预计算的atom配对特征 tok_idx=f["atom_to_token_map"], indices=f["attn_indices"], ) - # ... Process outputs to positions update - R_update_L = self.to_r_update(Q_L) - X_out_L = self.scale_positions_out(R_update_L, X_noisy_L, t_L) + # ===== 步骤15-16: 坐标更新和去预条件化 ===== + # Step 15-16: Coordinate update and de-preconditioning + # Algorithm 5 步骤15: 投影atom特征到3D坐标更新 + R_update_L = self.to_r_update(Q_L) # [B, L, 3] 预测的坐标更新 + # 步骤16: EDM去预条件化,得到去噪后的坐标 + X_out_L = self.scale_positions_out(R_update_L, X_noisy_L, t_L) # [B, L, 3] + # ===== 辅助输出:序列预测和distogram ===== + # Auxiliary outputs: sequence prediction and distogram + # 序列预测头:从token特征预测残基类型 sequence_logits_I, sequence_indices_I = self.sequence_head(A_I=A_I) - D_II_self = self.bucketize_fn(X_out_L[..., f["is_ca"], :].detach()) + # 计算distogram用于下一次循环的self-conditioning + # 使用detach()防止梯度回传到前一次循环 + D_II_self = self.bucketize_fn(X_out_L[..., f["is_ca"], :].detach()) # [B, I, I, n_bins] return { - "X_L": X_out_L, - "D_II_self": D_II_self, - "sequence_logits_I": sequence_logits_I, - "sequence_indices_I": sequence_indices_I, - } | o + "X_L": X_out_L, # [B, L, 3] 更新后的坐标 + "D_II_self": D_II_self, # [B, I, I, n_bins] 用于下一次循环的distogram + "sequence_logits_I": sequence_logits_I, # [B, I, 21] 序列预测logits + "sequence_indices_I": sequence_indices_I, # [B, I] 序列索引 + } | o # 合并解码器的额外输出 diff --git a/models/rfd3/src/rfd3/model/layers/blocks.py b/models/rfd3/src/rfd3/model/layers/blocks.py index eaf08093..de13419c 100644 --- a/models/rfd3/src/rfd3/model/layers/blocks.py +++ b/models/rfd3/src/rfd3/model/layers/blocks.py @@ -63,13 +63,33 @@ def forward( class PositionPairDistEmbedder(nn.Module): + """ + 位置配对距离嵌入器 (Algorithm 14: Position Pair Distance Embedding) + Position Pair Distance Embedder + + 将原子对的参考位置编码为配对特征。使用逆配对距离编码空间关系。 + Encodes reference positions of atom pairs into pairwise features. + Uses inverse pairwise distance to encode spatial relationships. + + 算法步骤 / Algorithm Steps: + 1. 计算配对距离向量 d_lm = ref_pos_l - ref_pos_m + 2. 编码逆配对距离: 1/(1 + ||d_lm||^2) + 3-4. 嵌入mask信息并应用到特征上 + + 参数 / Parameters: + c_atompair: 配对特征维度 / Pairwise feature dimension + embed_frame: 是否嵌入完整的距离向量(3D) / Whether to embed full distance vector (3D) + """ def __init__(self, c_atompair, embed_frame=True): super().__init__() self.embed_frame = embed_frame if embed_frame: + # 嵌入完整的3D距离向量 / Embed full 3D distance vector self.process_d = linearNoBias(3, c_atompair) + # Algorithm 14 步骤2: 编码逆配对距离 / Step 2: Encode inverse pairwise distance self.process_inverse_dist = linearNoBias(1, c_atompair) + # Algorithm 14 步骤3-4: 嵌入mask / Step 3-4: Embed mask self.process_valid_mask = linearNoBias(1, c_atompair) def forward_af3(self, D_LL, V_LL): @@ -98,33 +118,60 @@ def forward_af3(self, D_LL, V_LL): return P_LL def forward(self, ref_pos, valid_mask): - D_LL = ref_pos.unsqueeze(-2) - ref_pos.unsqueeze(-3) - V_LL = valid_mask + """ + 算法步骤 / Algorithm steps (Algorithm 14): + 1. 计算配对距离 / Compute pairwise distances + 2. 编码逆配对距离 / Encode inverse pairwise distance + 3-4. 嵌入mask / Embed mask + """ + # 步骤1: 计算配对距离向量 / Step 1: Compute pairwise distance vectors + D_LL = ref_pos.unsqueeze(-2) - ref_pos.unsqueeze(-3) # [L, L, 3] or [B, L, L, 3] + V_LL = valid_mask # [L, L, 1] or [B, L, L, 1] if self.embed_frame: - # Embed pairwise distances + # 嵌入完整的3D距离框架 / Embed full 3D distance frame return self.forward_af3(D_LL, V_LL) + + # 步骤2: 编码逆配对距离 / Step 2: Encode inverse pairwise distance + # 计算距离的平方: ||d_lm||^2 norm = torch.linalg.norm(D_LL, dim=-1, keepdim=True) ** 2 - norm = torch.clamp(norm, min=1e-6) + norm = torch.clamp(norm, min=1e-6) # 避免除以零 / Avoid division by zero + # 逆距离编码: 1/(1 + ||d_lm||^2) inv_dist = 1 / (1 + norm) P_LL = self.process_inverse_dist(inv_dist) * V_LL + + # 步骤3-4: 添加mask嵌入 / Step 3-4: Add mask embedding P_LL = P_LL + self.process_valid_mask(V_LL.to(P_LL.dtype)) * V_LL return P_LL class OneDFeatureEmbedder(nn.Module): """ - Embeds 1D features into a single vector. + 一维特征嵌入器 (Algorithm 15: One-dimension Feature Embedder) + One-dimension Feature Embedder + + 将多个1D特征(残基类型、原子类型等)嵌入并求和为单个向量。 + Embeds and sums multiple 1D features (residue type, atom type, etc.) into a single vector. - Args: - features (dict): Dictionary of feature names and their number of channels. - output_channels (int): Output dimension of the projected embedding. + 算法步骤 / Algorithm Steps: + 1. 对每个1D特征f_i进行独立嵌入: e_i = Embed(f_i) + 2. 求和所有嵌入: e = Σ e_i + + 这种加法聚合允许模型组合多个特征源。 + This additive aggregation allows the model to compose multiple feature sources. + + 参数 / Args: + features (dict): 特征名称及其通道数的字典 / Dictionary of feature names and their number of channels + output_channels (int): 输出嵌入维度 / Output dimension of the projected embedding """ def __init__(self, features, output_channels): super().__init__() + # 过滤存在的特征 / Filter existing features self.features = {k: v for k, v in features.items() if exists(v)} total_embedding_input_features = sum(self.features.values()) + + # 为每个特征创建独立的嵌入层 / Create independent embedding layer for each feature self.embedders = nn.ModuleDict( { feature: EmbeddingLayer( @@ -135,6 +182,17 @@ def __init__(self, features, output_channels): ) def forward(self, f, collapse_length): + """ + 前向传播 / Forward pass + + 参数 / Args: + f: 特征字典 / Feature dictionary + collapse_length: 折叠长度(I或L) / Collapse length (I or L) + + 返回 / Returns: + 嵌入特征的总和 / Sum of embedded features: [collapse_length, output_channels] + """ + # Algorithm 15: 对每个1D特征嵌入并求和 / Embed each 1D feature and sum return sum( tuple( self.embedders[feature](collapse(f[feature].float(), collapse_length)) @@ -146,37 +204,54 @@ def forward(self, f, collapse_length): class SinusoidalDistEmbed(nn.Module): """ - Applies sinusoidal embedding to pairwise distances and projects to c_atompair. - - Args: - c_atompair (int): Output dimension of the projected embedding (must be even). + 正弦距离嵌入 (Algorithm 13: Sinusoidal Distance Embedding) + Sinusoidal Distance Embedding + + 对配对距离应用正弦嵌入,类似于Transformer中的位置编码。 + Applies sinusoidal embedding to pairwise distances, similar to positional encoding in Transformers. + + 算法步骤 / Algorithm Steps: + 1. 计算配对距离 ||p_l - p_m|| + 2-4. 正弦嵌入: + - 频率: ω_k = 1 / (10000^(2k/D)) + - 角度: θ_lmk = d_lm * ω_k + - 嵌入: e_lm = [sin(θ) || cos(θ)] + 5-7. 应用并嵌入mask + + 参数 / Args: + c_atompair (int): 输出投影嵌入维度(必须为偶数) / Output dimension (must be even) + n_freqs (int): sin/cos对数,总正弦维度 = 2 * n_freqs / Number of sin/cos pairs """ def __init__(self, c_atompair, n_freqs=32): super().__init__() assert c_atompair % 2 == 0, "Output embedding dim must be even" - self.n_freqs = ( - n_freqs # Number of sin/cos pairs → total sinusoidal dim = 2 * n_freqs - ) + self.n_freqs = n_freqs # Number of sin/cos pairs → total sinusoidal dim = 2 * n_freqs self.c_atompair = c_atompair + # 投影正弦嵌入到输出维度 / Project sinusoidal embedding to output dimension self.output_proj = linearNoBias(2 * n_freqs, c_atompair) + # Algorithm 13 步骤5-7: 嵌入mask / Step 5-7: Embed mask self.process_valid_mask = linearNoBias(1, c_atompair) def forward(self, pos, valid_mask): """ - Args: - pos: [L, 3] or [B, L, 3] ground truth atom positions - valid_mask: [L, L, 1] or [B, L, L, 1] boolean mask - Returns: - P_LL: [L, L, c_atompair] or [B, L, L, c_atompair] + 前向传播 / Forward pass + + 参数 / Args: + pos: [L, 3] 或 [B, L, 3] 原子位置 / Atom positions + valid_mask: [L, L, 1] 或 [B, L, L, 1] 有效性mask / Validity mask + + 返回 / Returns: + P_LL: [L, L, c_atompair] 或 [B, L, L, c_atompair] 嵌入的配对特征 """ - # Compute pairwise distances + # ===== 步骤1: 计算配对距离 / Step 1: Compute pairwise distances ===== D_LL = pos.unsqueeze(-2) - pos.unsqueeze(-3) # [L, L, 3] or [B, L, L, 3] dist_matrix = torch.linalg.norm(D_LL, dim=-1) # [L, L] or [B, L, L] - # Sinusoidal embedding + # ===== 步骤2-4: 正弦嵌入 / Step 2-4: Sinusoidal embedding ===== + # 步骤2: 计算频率 ω_k = 1 / (10000^(2k/D)) half_dim = self.n_freqs freq = torch.exp( -math.log(10000.0) @@ -184,16 +259,19 @@ def forward(self, pos, valid_mask): / half_dim ).to(dist_matrix.device) # [n_freqs] - angles = dist_matrix.unsqueeze(-1) * freq # [..., D/2] - sin_embed = torch.sin(angles) - cos_embed = torch.cos(angles) - sincos_embed = torch.cat([sin_embed, cos_embed], dim=-1) # [..., D] + # 步骤3: 计算角度 θ_lmk = d_lm * ω_k + angles = dist_matrix.unsqueeze(-1) * freq # [..., n_freqs] + + # 步骤4: 应用sin和cos生成嵌入 / Apply sin and cos to generate embedding + sin_embed = torch.sin(angles) # [..., n_freqs] + cos_embed = torch.cos(angles) # [..., n_freqs] + sincos_embed = torch.cat([sin_embed, cos_embed], dim=-1) # [..., 2*n_freqs] - # Linear projection + # 线性投影到输出维度 / Linear projection to output dimension P_LL = self.output_proj(sincos_embed) # [..., c_atompair] P_LL = P_LL * valid_mask - # Add linear embedding of valid mask + # ===== 步骤5-7: 添加mask嵌入 / Step 5-7: Add mask embedding ===== P_LL = P_LL + self.process_valid_mask(valid_mask.to(P_LL.dtype)) * valid_mask return P_LL @@ -516,7 +594,26 @@ def forward(self, Q_L, A_I, tok_idx): class Downcast(nn.Module): - """Downcast modules for when atoms are already reshaped from N_atoms -> (N_tokens, 14)""" + """ + 下投影 (Algorithm 9: Downcast) + Downcast + + 将atom级特征池化到token级特征。使用交叉注意力或平均池化。 + Pools atom-level features to token-level features using cross-attention or mean pooling. + + 算法步骤 / Algorithm Steps (Algorithm 9): + 1. 按token ID分组原子: group_atoms(q_ia) + 2. 交叉注意力池化: GatedCrossAttention (Q=a_i, KV=q_ia) + 或平均池化: mean(q_ia) per token + 3. (可选) 添加单轨迹特征 s_i + 4. 返回更新的token特征 a_i + + 参数 / Parameters: + c_atom: Atom特征维度 / Atom feature dimension + c_token: Token特征维度 / Token feature dimension + c_s: (可选) 单轨迹特征维度 / Optional single track feature dimension + method: "mean" (平均池化) 或 "cross_attention" / Pooling method + """ def __init__( self, c_atom, c_token, c_s=None, method="mean", cross_attention_block=None @@ -525,6 +622,8 @@ def __init__( self.method = method self.c_token = c_token self.c_atom = c_atom + + # 可选: 处理单轨迹特征 / Optional: process single track features if c_s is not None: self.process_s = nn.Sequential( RMSNorm((c_s,)), @@ -533,9 +632,12 @@ def __init__( else: self.process_s = None + # 池化方法 / Pooling method if self.method == "mean": + # 平均池化: 投影并求平均 / Mean pooling: project and average self.project = linearNoBias(c_atom, c_token) elif self.method == "cross_attention": + # Algorithm 11: GatedCrossAttention - Q=token, KV=atoms self.gca = GatedCrossAttention( c_query=c_token, c_kv=c_atom, @@ -545,24 +647,58 @@ def __init__( raise ValueError(f"Unknown downcast method: {self.method}") def forward_(self, Q_IA, A_I, S_I=None, valid_mask=None): + """ + 核心Downcast操作 / Core downcast operation + + 参数 / Args: + Q_IA: [B, I, max_atoms, c_atom] 分组的atom特征 / Grouped atom features + A_I: [B, I, c_token] 当前token特征 / Current token features + S_I: [B, I, c_s] (可选) 单轨迹特征 / Optional single track features + valid_mask: [I, max_atoms] 有效atom mask / Valid atom mask + + 返回 / Returns: + A_I: [B, I, c_token] 更新后的token特征 / Updated token features + """ + # ===== Algorithm 9 步骤2: 池化操作 / Step 2: Pooling operation ===== if self.method == "mean": + # 平均池化: project并除以有效atom数 / Mean pooling: project and divide by valid atom count A_I_update = self.project(Q_IA).sum(-2) / valid_mask.sum(-1, keepdim=True) elif self.method == "cross_attention": + # Algorithm 11: GatedCrossAttention assert exists(A_I) and exists(valid_mask) - # Attention mask: ..., 1, n_atom_per_tok (1 querying token to atoms in token) + # Attention mask: ..., 1, n_atom_per_tok (1个查询token对应token内的atoms) attn_mask = valid_mask[..., None, :] A_I_update = self.gca( - q=A_I[..., None, :], kv=Q_IA, attn_mask=attn_mask + q=A_I[..., None, :], # Q: 单个token特征 + kv=Q_IA, # KV: token内的所有atom特征 + attn_mask=attn_mask ).squeeze(-2) + # 残差连接 / Residual connection A_I = A_I + A_I_update if exists(A_I) else A_I_update + # ===== Algorithm 9 步骤4: (可选) 添加单轨迹特征 / Step 4: (Optional) Add single track features ===== if self.process_s is not None: A_I = A_I + self.process_s(S_I) return A_I def forward(self, Q_L, A_I, S_I=None, tok_idx=None): + """ + 前向传播:将atom特征池化到token特征 / Forward: pool atom features to token features + + 参数 / Args: + Q_L: [B, L, c_atom] 或 [L, c_atom] Atom特征 / Atom features + A_I: [B, I, c_token] 或 [I, c_token] 当前token特征 / Current token features + S_I: [B, I, c_s] 或 [I, c_s] (可选) 单轨迹特征 / Optional single track features + tok_idx: [L] atom到token的映射 / Atom to token mapping + + 返回 / Returns: + A_I: 更新后的token特征 / Updated token features + """ + # ===== Algorithm 9 步骤1: 按token ID分组原子 / Step 1: Group atoms by token ID ===== valid_mask = build_valid_mask(tok_idx) + + # 处理批次维度 / Handle batch dimension if Q_L.ndim == 2: squeeze = True Q_L = Q_L.unsqueeze(0) @@ -572,8 +708,10 @@ def forward(self, Q_L, A_I, S_I=None, tok_idx=None): A_I = A_I.unsqueeze(0) if exists(A_I) and A_I.ndim == 2 else A_I S_I = S_I.unsqueeze(0) if exists(S_I) and S_I.ndim == 2 else S_I + # 将atom特征重新组织为 [B, I, max_atoms, c_atom] Q_IA = ungroup_atoms(Q_L, valid_mask) + # 执行池化操作 / Perform pooling operation A_I = self.forward_(Q_IA, A_I, S_I, valid_mask=valid_mask) if squeeze: @@ -587,6 +725,33 @@ def forward(self, Q_L, A_I, S_I=None, tok_idx=None): class LocalTokenTransformer(nn.Module): + """ + 局部Token Transformer (Algorithm 6: Local token transformer) + Local Token Transformer + + 在token级别应用SL2稀疏注意力(序列局部 + 结构局部)。 + Applies SL2 sparse attention (sequence-local + structure-local) at token level. + + 算法步骤 / Algorithm Steps (Algorithm 6): + 1. 创建SL2稀疏注意力索引 (序列局部 + 结构局部) + 2-8. 循环通过多个transformer块: + 4. (可选) Upcast - 如果提供了c_skip + 6. SparseAttentionPairBias - 带配对偏置的稀疏注意力 (Algorithm 8) + 7. ConditionedTransitionBlock - 条件化的transition + + 关键特性 / Key Features: + - SL2稀疏注意力:仅关注序列邻居和结构邻居 + - 内存高效:避免完整的I×I注意力矩阵 + - 配对偏置:Z_II作为注意力偏置引导注意力 + + 参数 / Parameters: + c_token: Token特征维度 / Token feature dimension + c_tokenpair: Token配对特征维度 / Token pairwise feature dimension + c_s: 单轨迹特征维度 / Single track feature dimension + n_block: Transformer块数量 / Number of transformer blocks + n_local_tokens: 序列局部邻居数 (默认8) / Number of sequence-local neighbors + n_keys: 结构局部键数 (默认32) / Number of structure-local keys + """ def __init__( self, c_token, @@ -599,8 +764,9 @@ def __init__( n_keys=32, ): super().__init__() - self.n_local_tokens = n_local_tokens - self.n_keys = n_keys + self.n_local_tokens = n_local_tokens # 序列局部注意力邻居数 + self.n_keys = n_keys # 结构局部注意力键数 + # 创建transformer块栈 / Create transformer block stack self.blocks = nn.ModuleList( [ StructureLocalAtomTransformerBlock( @@ -614,26 +780,45 @@ def __init__( ) def forward(self, A_I, S_I, Z_II, f, X_L, full=False): + """ + 前向传播 / Forward pass + + 参数 / Args: + A_I: [B, I, c_token] Token特征 / Token features + S_I: [B, I, c_s] 单轨迹特征 / Single track features + Z_II: [B, I, I, c_tokenpair] Token配对特征(用作注意力偏置) / Token pair features (as attention bias) + f: 特征字典 / Feature dictionary + X_L: [B, I, 3] Token坐标(C-alpha位置) / Token coordinates (C-alpha positions) + full: 是否使用完整注意力(非稀疏) / Whether to use full attention (non-sparse) + + 返回 / Returns: + A_I: [B, I, c_token] 更新后的token特征 / Updated token features + """ + # ===== Algorithm 6 步骤1: 创建SL2稀疏注意力索引 / Step 1: Create SL2 sparse attention indices ===== + # 结合序列局部和结构局部注意力 indices = create_attention_indices( - X_L=X_L, + X_L=X_L, # 用于计算结构局部邻居 / For computing structure-local neighbors f=f, tok_idx=torch.arange(A_I.shape[1], device=A_I.device), - n_attn_keys=self.n_keys, - n_attn_seq_neighbours=self.n_local_tokens, + n_attn_keys=self.n_keys, # 结构局部键数 / Structure-local keys + n_attn_seq_neighbours=self.n_local_tokens, # 序列局部邻居数 / Sequence-local neighbors ) + # ===== Algorithm 6 步骤2-8: 循环通过transformer块 / Step 2-8: Loop through transformer blocks ===== for i, block in enumerate(self.blocks): - # Set checkpointing + # 设置checkpointing以节省内存 / Set checkpointing to save memory block.attention_pair_bias.use_checkpointing = not DISABLE_CHECKPOINTING - # A_I: [B, L, C_token] - # S_I: [B, L, C_s] - # Z_II: [B, L, L, C_tokenpair] + + # 步骤6: SparseAttentionPairBias (Algorithm 8) + 步骤7: ConditionedTransitionBlock + # A_I: [B, I, c_token] Token特征 + # S_I: [B, I, c_s] 单轨迹特征(用于条件化) + # Z_II: [B, I, I, c_tokenpair] 配对特征(用作注意力偏置) A_I = block( A_I, S_I, Z_II, - indices=indices, - full=full, # (self.training and torch.is_grad_enabled()), # Does not accelerate inference, but memory *does* scale better + indices=indices, # SL2稀疏注意力索引 + full=full, # 是否使用完整注意力(内存换速度) ) return A_I diff --git a/models/rfd3/src/rfd3/model/layers/encoders.py b/models/rfd3/src/rfd3/model/layers/encoders.py index a9596f13..a926e834 100644 --- a/models/rfd3/src/rfd3/model/layers/encoders.py +++ b/models/rfd3/src/rfd3/model/layers/encoders.py @@ -35,7 +35,32 @@ class TokenInitializer(nn.Module): """ - Token embedding module for RFD3 + Token初始化器 (Algorithm 3 & 4: Token and Atom initializer) + Token Initializer + + RFD3的初始化模块,实现SI中的Algorithm 3和4。 + 负责从原始特征生成初始的token级和atom级表示。 + Initialization module for RFD3, implementing Algorithms 3 and 4 from SI. + Responsible for generating initial token-level and atom-level representations from raw features. + + 主要功能 / Main Functions: + 1. Token初始化 (Algorithm 3): + - 嵌入1D特征 (残基类型、配对特征等) + - 相对位置编码 + - Pairformer处理配对特征 + - 生成 S_I (token单特征) 和 Z_II (token配对特征) + + 2. Atom初始化 (Algorithm 4): + - 嵌入atom级1D特征 + - Motif位置和参考坐标编码 + - 生成 Q_L (atom初始特征), C_L (atom条件特征), P_LL (atom配对特征) + + 参数 / Parameters: + c_s: Token单轨迹特征维度 / Token single track feature dimension + c_z: Token配对特征维度 / Token pair feature dimension + c_atom: Atom特征维度 / Atom feature dimension + c_atompair: Atom配对特征维度 / Atom pairwise feature dimension + use_chunked_pll: 是否使用分块P_LL计算(内存优化) / Whether to use chunked P_LL computation """ def __init__( @@ -55,15 +80,18 @@ def __init__( ): super().__init__() - # Store chunked mode flag + # 存储分块模式标志 / Store chunked mode flag self.use_chunked_pll = use_chunked_pll - # Features - self.atom_1d_embedder_1 = OneDFeatureEmbedder(atom_1d_features, c_s) - self.atom_1d_embedder_2 = OneDFeatureEmbedder(atom_1d_features, c_atom) + # ===== Algorithm 3: Token Initializer - 1D特征嵌入 / 1D Feature Embedding ===== + # Algorithm 15: OneDFeatureEmbedder - 嵌入原始1D特征 + self.atom_1d_embedder_1 = OneDFeatureEmbedder(atom_1d_features, c_s) # 用于token初始化 + self.atom_1d_embedder_2 = OneDFeatureEmbedder(atom_1d_features, c_atom) # 用于atom初始化 self.token_1d_embedder = OneDFeatureEmbedder(token_1d_features, c_s) + # Algorithm 9: Downcast - 从atom特征池化到token特征 self.downcast_atom = Downcast(c_atom=c_s, c_token=c_s, c_s=None, **downcast) + # Algorithm 3 步骤5: Transition层用于特征混合 self.transition_post_token = Transition(c=c_s, n=2) self.transition_post_atom = Transition(c=c_s, n=2) self.process_s_init = nn.Sequential( @@ -71,20 +99,24 @@ def __init__( linearNoBias(c_s, c_s), ) - # Operations to mix into Z_II and S_I - self.to_z_init_i = linearNoBias(c_s, c_z) - self.to_z_init_j = linearNoBias(c_s, c_z) + # ===== Algorithm 3 步骤6-8: 配对特征初始化 / Pair Feature Initialization ===== + # 步骤6-7: 从单特征投影到配对特征 (outer sum) + self.to_z_init_i = linearNoBias(c_s, c_z) # S_I -> Z_II (i维度) + self.to_z_init_j = linearNoBias(c_s, c_z) # S_I -> Z_II (j维度) + # 步骤8: 相对位置编码 self.relative_position_encoding = RelativePositionEncodingWithIndexRemoval( c_z=c_z, **relative_position_encoding ) self.relative_position_encoding2 = RelativePositionEncodingWithIndexRemoval( c_z=c_z, **relative_position_encoding ) + # Token间的化学键编码 self.process_token_bonds = linearNoBias(1, c_z) - # Processing of Z_init + # ===== Algorithm 3 步骤9-12: 配对特征处理 / Pair Feature Processing ===== + # 步骤13-19: 混合多个配对特征源并后处理 self.process_z_init = nn.Sequential( - RMSNorm(c_z * 2), + RMSNorm(c_z * 2), # 处理拼接的配对特征 linearNoBias(c_z * 2, c_z), ) self.transition_1 = nn.ModuleList( @@ -93,9 +125,11 @@ def __init__( Transition(c=c_z, n=2), ] ) + # Algorithm 14: PositionPairDistEmbedder - 参考坐标的距离编码 self.ref_pos_embedder_tok = PositionPairDistEmbedder(c_z, embed_frame=False) - # Pairformer without triangle updates + # ===== Algorithm 3 步骤10-12: Pairformer块 / Pairformer Blocks ===== + # Algorithm 7: TransformerBlock - 全注意力transformer处理token特征 self.transformer_stack = nn.ModuleList( [ PairformerBlock(c_s=c_s, c_z=c_z, **pairformer_block) @@ -103,20 +137,29 @@ def __init__( ] ) - ############################################################################# - # Token track processing + # ===== Algorithm 4: Atom Initializer - Atom级特征处理 / Atom-level Feature Processing ===== + # 步骤2: 从token特征投影到atom特征 / Project from token features to atom features self.process_s_trunk = nn.Sequential(RMSNorm(c_s), linearNoBias(c_s, c_atom)) + + # 步骤5-7: Atom配对特征的MLP处理 / MLP processing for atom pairwise features + # 步骤5: 处理单atom特征(l维度) / Process single atom features (l dimension) self.process_single_l = nn.Sequential( nn.ReLU(), linearNoBias(c_atom, c_atompair) ) + # 步骤5: 处理单atom特征(m维度) / Process single atom features (m dimension) self.process_single_m = nn.Sequential( nn.ReLU(), linearNoBias(c_atom, c_atompair) ) + # 步骤6: 从token配对特征投影 / Project from token pair features self.process_z = nn.Sequential(RMSNorm(c_z), linearNoBias(c_z, c_atompair)) - # ALWAYS create these MLPs - they will be shared between chunked and standard modes + # ===== Algorithm 4 步骤3-4: 位置和距离嵌入器 / Position and Distance Embedders ===== + # 总是创建这些MLP - 在分块和标准模式间共享 / Always create these MLPs - shared between modes + # Algorithm 13: SinusoidalDistEmbed - Motif位置的正弦距离嵌入 self.motif_pos_embedder = SinusoidalDistEmbed(c_atompair=c_atompair) + # Algorithm 14: PositionPairDistEmbedder - 参考位置的配对距离嵌入 self.ref_pos_embedder = PositionPairDistEmbedder(c_atompair, embed_frame=False) + # 步骤7: 深度MLP用于混合配对特征 / Deep MLP for mixing pairwise features self.pair_mlp = nn.Sequential( nn.ReLU(), linearNoBias(c_atompair, c_atompair), @@ -126,23 +169,28 @@ def __init__( linearNoBias(c_atompair, c_atompair), ) - # Atom pair feature processing + # ===== 分块P_LL计算(内存优化) / Chunked P_LL Computation (Memory Optimization) ===== + # Atom配对特征处理 - 支持标准模式和分块模式 if self.use_chunked_pll: - # Initialize chunked embedders and share the trained MLPs! + # 初始化分块嵌入器并共享已训练的MLP! / Initialize chunked embedders and share trained MLPs! self.chunked_pairwise_embedder = ChunkedPairwiseEmbedder( c_atompair=c_atompair, motif_pos_embedder=ChunkedSinusoidalDistEmbed(c_atompair=c_atompair), ref_pos_embedder=ChunkedPositionPairDistEmbedder( c_atompair, embed_frame=False ), - process_single_l=self.process_single_l, # Share trained parameters! - process_single_m=self.process_single_m, # Share trained parameters! - process_z=self.process_z, # Share trained parameters! - pair_mlp=self.pair_mlp, # Share trained parameters! + process_single_l=self.process_single_l, # 共享训练参数! / Share trained parameters! + process_single_m=self.process_single_m, # 共享训练参数! + process_z=self.process_z, # 共享训练参数! + pair_mlp=self.pair_mlp, # 共享训练参数! ) + # 池化P_LL到token级别 / Pool P_LL to token level self.process_pll = linearNoBias(c_atompair, c_atompair) self.project_pll = linearNoBias(c_atompair, c_z) + # ===== 可选的Atom Transformer / Optional Atom Transformer ===== + # 使用序列局部注意力混合atom条件特征 + # Mix atom conditioning features via sequence-local attention if atom_transformer["n_blocks"] > 0: self.atom_transformer = LocalAtomTransformer( c_atom=c_atom, c_s=None, c_atompair=c_atompair, **atom_transformer @@ -162,36 +210,68 @@ def __init__( def forward(self, f): """ - Provides initial representation for atom and token representations + 生成初始atom和token表示 (Algorithm 3 & 4) + Generate initial atom and token representations + + 给定输入特征字典,生成token级和atom级的初始化表示。 + Given input feature dictionary, generate initial token-level and atom-level representations. + + 参数 / Args: + f: 特征字典,包含: + - atom_to_token_map: [L] atom到token的映射 + - restype: [I] 残基类型 + - ref_pos: [L, 3] 参考坐标 + - motif_pos: [L, 3] Motif坐标 + - 其他1D特征... + + 返回 / Returns: + 包含以下键的字典: + - Q_L_init: [L, c_atom] 初始atom查询特征 + - C_L: [L, c_atom] Atom条件特征 + - P_LL: [L, L, c_atompair] 或 分块嵌入器 - Atom配对特征 + - S_I: [I, c_s] Token单特征 + - Z_II: [I, I, c_z] Token配对特征 """ - tok_idx = f["atom_to_token_map"] - L = len(tok_idx) + tok_idx = f["atom_to_token_map"] # [L] atom到token的映射 + L = len(tok_idx) # 原子总数 / Total number of atoms f["ref_atom_name_chars"] = f["ref_atom_name_chars"].reshape(L, -1) - I = len(f["restype"]) + I = len(f["restype"]) # Token总数 / Total number of tokens def init_tokens(): - # Embed token features - S_I = self.token_1d_embedder(f, I) + """ + Algorithm 3: Token初始化器 / Token initializer + 生成初始的token级单特征(S_I)和配对特征(Z_II) + """ + # ===== 步骤1-4: 嵌入1D特征 / Step 1-4: Embed 1D features ===== + # Algorithm 15: 嵌入token级1D特征(残基类型等) + S_I = self.token_1d_embedder(f, I) # [I, c_s] + # 步骤5: Transition层混合特征 S_I = S_I + self.transition_post_token(S_I) - # Embed atom features and downcast to token features + # 嵌入atom级1D特征并下采样到token级 + # Algorithm 9: Downcast - 从atom池化到token S_I = self.downcast_atom( Q_L=self.atom_1d_embedder_1(f, L), A_I=S_I, tok_idx=tok_idx ) S_I = S_I + self.transition_post_atom(S_I) - S_I = self.process_s_init(S_I) + S_I = self.process_s_init(S_I) # [I, c_s] - # Embed Z_II + # ===== 步骤6-8: 初始化配对特征Z_II / Step 6-8: Initialize pair features Z_II ===== + # 步骤6-7: 从单特征生成配对特征 (outer sum: S_I_i + S_I_j) Z_init_II = self.to_z_init_i(S_I).unsqueeze(-3) + self.to_z_init_j( S_I - ).unsqueeze(-2) + ).unsqueeze(-2) # [I, I, c_z] + # 步骤8: 添加相对位置编码 Z_init_II = Z_init_II + self.relative_position_encoding(f) + # 添加token间的化学键信息 Z_init_II = Z_init_II + self.process_token_bonds( f["token_bonds"].unsqueeze(-1).float() ) - # Embed reference coordinates of ligands - token_id = f["ref_space_uid"][f["is_ca"]] + # ===== 步骤9: 嵌入配体的参考坐标 / Step 9: Embed reference coordinates of ligands ===== + # Algorithm 14: PositionPairDistEmbedder + token_id = f["ref_space_uid"][f["is_ca"]] # C-alpha的token ID + # 创建mask:仅对同一token内的原子对计算距离 valid_mask = (token_id.unsqueeze(-1) == token_id.unsqueeze(-2)).unsqueeze( -1 ) @@ -199,19 +279,22 @@ def init_tokens(): f["ref_pos"][f["is_ca"]], valid_mask ) - # Run a small transformer to provide position encodings to single. + # ===== 步骤10-12: Pairformer transformer栈 / Step 10-12: Pairformer transformer stack ===== + # Algorithm 7: TransformerBlock - 使用全注意力处理配对特征 for block in self.transformer_stack: S_I, Z_init_II = block(S_I, Z_init_II) - # Also cat the relative position encoding and mix + # ===== 步骤13-19: 配对特征后处理 / Step 13-19: Post-process pair features ===== + # 拼接第二个相对位置编码并混合 Z_init_II = torch.cat( [ Z_init_II, self.relative_position_encoding2(f), ], dim=-1, - ) - Z_init_II = self.process_z_init(Z_init_II) + ) # [I, I, c_z * 2] + Z_init_II = self.process_z_init(Z_init_II) # [I, I, c_z] + # 两个Transition层进一步混合 for b in range(2): Z_init_II = Z_init_II + self.transition_1[b](Z_init_II) @@ -219,11 +302,19 @@ def init_tokens(): @activation_checkpointing def init_atoms(S_init_I, Z_init_II): - Q_L_init = self.atom_1d_embedder_2(f, L) - C_L = Q_L_init + self.process_s_trunk(S_init_I)[..., tok_idx, :] + """ + Algorithm 4: Atom初始化器 / Atom initializer + 生成atom级特征: Q_L_init, C_L, P_LL + """ + # ===== 步骤1: 嵌入atom级1D特征 / Step 1: Embed atom-level 1D features ===== + # Algorithm 15: OneDFeatureEmbedder for atom features + Q_L_init = self.atom_1d_embedder_2(f, L) # [L, c_atom] + + # ===== 步骤2: 从token特征投影 / Step 2: Project from token features ===== + C_L = Q_L_init + self.process_s_trunk(S_init_I)[..., tok_idx, :] # [L, c_atom] if self.use_chunked_pll: - # Chunked mode: return embedder for later sparse computation + # ===== 分块模式:返回嵌入器供后续稀疏计算 / Chunked mode: return embedder for later sparse computation ===== return { "Q_L_init": Q_L_init, "C_L": C_L, @@ -232,22 +323,26 @@ def init_atoms(S_init_I, Z_init_II): "Z_II": Z_init_II, } else: - # Original full P_LL computation - ################################################################################## - # Embed motif coordinates + # ===== 标准模式:完整P_LL计算 / Standard mode: full P_LL computation ===== + + # ===== 步骤3: 嵌入Motif坐标 / Step 3: Embed motif coordinates ===== + # Algorithm 13: SinusoidalDistEmbed - 正弦距离嵌入 + # 仅对固定坐标的motif原子对计算距离 valid_mask = ( f["is_motif_atom_with_fixed_coord"].unsqueeze(-1) & f["is_motif_atom_with_fixed_coord"].unsqueeze(-2) - ).unsqueeze(-1) + ).unsqueeze(-1) # [L, L, 1] P_LL = self.motif_pos_embedder( f["motif_pos"], valid_mask - ) # (L, L, c_atompair) + ) # [L, L, c_atompair] - # Embed ref pos + # ===== 步骤4: 嵌入参考位置 / Step 4: Embed reference positions ===== + # Algorithm 14: PositionPairDistEmbedder + # 仅对同一token内的原子对计算距离 atoms_in_same_token = ( f["ref_space_uid"].unsqueeze(-1) == f["ref_space_uid"].unsqueeze(-2) ).unsqueeze(-1) - # Only consider ref_pos for atoms given seq (otherwise ref_pos is 0, doesn't make sense to compute) + # 仅对给定序列的原子考虑ref_pos (否则ref_pos为0,计算无意义) atoms_has_seq = ( f["is_motif_atom_with_fixed_seq"].unsqueeze(-1) & f["is_motif_atom_with_fixed_seq"].unsqueeze(-2) @@ -255,40 +350,45 @@ def init_atoms(S_init_I, Z_init_II): valid_mask = atoms_in_same_token & atoms_has_seq P_LL = P_LL + self.ref_pos_embedder(f["ref_pos"], valid_mask) - ################################################################################## - + # ===== 步骤5-7: Atom配对特征的MLP处理 / Step 5-7: MLP processing for atom pairwise features ===== + # 步骤5: 添加单atom特征的外积 (outer sum: C_L_l + C_L_m) P_LL = P_LL + ( self.process_single_l(C_L).unsqueeze(-2) + self.process_single_m(C_L).unsqueeze(-3) ) + # 步骤6: 添加从token配对特征投影的信息 + # 将Z_II [I, I, c_z] 通过tok_idx索引到atom级 [L, L, c_z] P_LL = ( P_LL + self.process_z(Z_init_II)[..., tok_idx, :, :][..., tok_idx, :] ) + # 步骤7: 深度MLP混合所有配对特征 P_LL = P_LL + self.pair_mlp(P_LL) - P_LL = P_LL.contiguous() + P_LL = P_LL.contiguous() # [L, L, c_atompair] - # Pool P_LL to token level to provide atom-level resolution for token track + # ===== 池化P_LL回token级以提供atom级分辨率 / Pool P_LL to token level ===== + # 将atom配对特征池化为token配对特征,增强Z_II pooled_atom_level_features = pairwise_mean_pool( pairwise_atom_features=self.process_pll(P_LL).unsqueeze(0), atom_to_token_map=tok_idx, I=int(tok_idx.max().item()) + 1, dtype=P_LL.dtype, - ).squeeze(0) + ).squeeze(0) # [I, I, c_atompair] Z_init_II = Z_init_II + self.project_pll(pooled_atom_level_features) - # Mix atom conditioning features via sequence-local attention + # ===== 可选: Atom transformer混合条件特征 / Optional: Atom transformer ===== + # 使用序列局部注意力进一步混合atom特征 if exists(self.atom_transformer): C_L = self.atom_transformer( C_L.unsqueeze(0), None, P_LL, indices=None, f=f, X_L=None ).squeeze(0) return { - "Q_L_init": Q_L_init, - "C_L": C_L, - "P_LL": P_LL, - "S_I": S_init_I, - "Z_II": Z_init_II, + "Q_L_init": Q_L_init, # [L, c_atom] 初始atom查询特征 + "C_L": C_L, # [L, c_atom] Atom条件特征 + "P_LL": P_LL, # [L, L, c_atompair] Atom配对特征 + "S_I": S_init_I, # [I, c_s] Token单特征 + "Z_II": Z_init_II, # [I, I, c_z] Token配对特征(增强后) } tokens = init_tokens() @@ -296,6 +396,25 @@ def init_atoms(S_init_I, Z_init_II): class DiffusionTokenEncoder(nn.Module): + """ + 扩散Token编码器 (Algorithm 12: Diffusion token encoder) + Diffusion Token Encoder + + 在每次扩散循环中嵌入噪声尺度和循环distogram。 + Embeds noise scale and recycled distogram at each diffusion cycle. + + 主要功能 / Main Functions: + - 将当前噪声坐标的distogram编码到Z_II中 + - 将前一次循环的self-conditioning distogram编码到Z_II中 + - 通过Pairformer块混合token单特征和配对特征 + + 参数 / Parameters: + c_s: Token单轨迹特征维度 / Token single track dimension + c_z: Token配对特征维度 / Token pair dimension + sigma_data: EDM数据方差 / EDM data variance + use_distogram: 是否使用当前distogram / Whether to use current distogram + use_self: 是否使用self-conditioning distogram / Whether to use self-conditioning distogram + """ def __init__( self, c_s, @@ -312,7 +431,7 @@ def __init__( ): super().__init__() - # Sequence processing + # ===== Algorithm 12 步骤1-3: Token单特征处理 / Step 1-3: Token single feature processing ===== self.transition_1 = nn.ModuleList( [ Transition(c=c_s, n=2), @@ -320,29 +439,37 @@ def __init__( ] ) - # Post-processing of z - self.n_bins_distogram = 65 # n bins for both self distogram and distogram + # ===== Algorithm 12 步骤4-8: Distogram嵌入和配对特征处理 / Step 4-8: Distogram embedding and pair feature processing ===== + self.n_bins_distogram = 65 # Distogram离散化bin数 (1-30Å) / Number of distogram bins n_bins_noise = self.n_bins_distogram - self.use_self = use_self - self.use_distogram = use_distogram + self.use_self = use_self # 是否使用self-conditioning distogram + self.use_distogram = use_distogram # 是否使用当前噪声distogram self.use_sinusoidal_distogram_embedder = use_sinusoidal_distogram_embedder + + # 步骤4: 离散化或嵌入distogram / Bucketize or embed distogram if self.use_distogram: if self.use_sinusoidal_distogram_embedder: + # Algorithm 13: 使用正弦嵌入 / Use sinusoidal embedding self.dist_embedder = SinusoidalDistEmbed(c_atompair=c_z) n_bins_noise = c_z else: + # 使用离散化 / Use bucketization self.bucketize_fn = functools.partial( bucketize_scaled_distogram, - min_dist=1, - max_dist=30, + min_dist=1, # 最小距离1Å + max_dist=30, # 最大距离30Å sigma_data=sigma_data, n_bins=self.n_bins_distogram, ) + + # 计算拼接后的配对特征维度 / Calculate concatenated pair feature dimension + # Z_II + distogram + self_distogram cat_c_z = ( c_z - + int(self.use_distogram) * n_bins_noise - + int(self.use_self) * self.n_bins_distogram + + int(self.use_distogram) * n_bins_noise # 当前distogram + + int(self.use_self) * self.n_bins_distogram # Self-conditioning distogram ) + # 步骤5-8: 混合拼接的配对特征 / Mix concatenated pair features self.process_z = nn.Sequential( RMSNorm(cat_c_z), linearNoBias(cat_c_z, c_z), @@ -355,7 +482,8 @@ def __init__( ] ) - # Pairformer without triangle updates + # ===== Algorithm 12 步骤9-11: Pairformer块 / Step 9-11: Pairformer blocks ===== + # Algorithm 7: TransformerBlock - 混合单特征和配对特征 self.pairformer_stack = nn.ModuleList( [ PairformerBlock(c_s=c_s, c_z=c_z, **pairformer_block) @@ -364,51 +492,86 @@ def __init__( ) def forward(self, f, R_L, S_init_I, Z_init_II, C_L, P_LL, **kwargs): - B = R_L.shape[0] """ - Pools atom-level features to token-level features and encodes them into Z_II, S_I and prepares A_I. + 扩散Token编码器前向传播 (Algorithm 12) + Diffusion token encoder forward pass + + 嵌入噪声尺度和循环distogram到token配对特征中。 + Embeds noise scale and recycled distogram into token pair features. + + 参数 / Args: + f: 特征字典 / Feature dictionary + R_L: [B, L, 3] 当前噪声坐标 / Current noisy coordinates + S_init_I: [I, c_s] 初始token单特征 / Initial token single features + Z_init_II: [I, I, c_z] 初始token配对特征 / Initial token pair features + C_L: [B, L, c_atom] Atom条件特征 / Atom conditioning features + P_LL: [L, L, c_atompair] Atom配对特征 / Atom pairwise features + **kwargs: 包含D_II_self (前一次循环的distogram) / Contains D_II_self (previous cycle's distogram) + + 返回 / Returns: + S_I: [B, I, c_s] 更新后的token单特征 / Updated token single features + Z_II: [B, I, I, c_z] 更新后的token配对特征 / Updated token pair features """ + B = R_L.shape[0] @activation_checkpointing def token_embed(S_init_I, Z_init_II): + """ + Algorithm 12的核心实现 / Core implementation of Algorithm 12 + """ + # ===== 步骤1-3: 处理token单特征 / Step 1-3: Process token single features ===== S_I = S_init_I for b in range(2): S_I = S_I + self.transition_1[b](S_I) - Z_II = Z_init_II.unsqueeze(0).expand(B, -1, -1, -1) # B, I, I, c_z + # ===== 步骤4-8: 准备配对特征 / Step 4-8: Prepare pair features ===== + # 扩展到batch维度 / Expand to batch dimension + Z_II = Z_init_II.unsqueeze(0).expand(B, -1, -1, -1) # [B, I, I, c_z] + # 收集要拼接的配对特征 / Collect pair features to concatenate Z_II_list = [Z_II] + + # 步骤4: 嵌入当前噪声坐标的distogram / Step 4: Embed current noisy coordinate distogram if self.use_distogram: - # Noise / self conditioning pair if self.use_sinusoidal_distogram_embedder: + # Algorithm 13: 正弦距离嵌入 / Sinusoidal distance embedding mask = f["is_motif_atom_with_fixed_coord"][f["is_ca"]] - mask = (mask[None, :] != mask[:, None]).unsqueeze( - -1 - ) # remove off-diagonals where distances don't make sense across time + # 移除对角线外不同时间的距离(无意义) + # Remove off-diagonal distances across time (meaningless) + mask = (mask[None, :] != mask[:, None]).unsqueeze(-1) D_LL = self.dist_embedder(R_L[..., f["is_ca"], :], ~mask) else: + # 离散化distogram / Bucketize distogram D_LL = self.bucketize_fn( R_L[..., f["is_ca"], :] - ) # [B, L, I, n_bins] + ) # [B, I, I, n_bins] Z_II_list.append(D_LL) + + # 步骤4: 添加self-conditioning distogram (前一次循环的输出) + # Add self-conditioning distogram (previous cycle's output) if self.use_self: D_II_self = kwargs.get("D_II_self") if D_II_self is None: + # 第一次循环,使用零初始化 / First cycle, use zeros D_II_self = torch.zeros( Z_II.shape[:-1] + (self.n_bins_distogram,), device=Z_II.device, dtype=Z_II.dtype, ) Z_II_list.append(D_II_self) - Z_II = torch.cat(Z_II_list, dim=-1) - # Flatten concatenated dims - Z_II = self.process_z(Z_II) + # 步骤5-8: 拼接并混合配对特征 / Step 5-8: Concatenate and mix pair features + Z_II = torch.cat(Z_II_list, dim=-1) # [B, I, I, c_z + n_bins + n_bins] + + # 投影回c_z维度 / Project back to c_z dimension + Z_II = self.process_z(Z_II) # [B, I, I, c_z] + # 两个Transition层进一步混合 / Two Transition layers for further mixing for b in range(2): Z_II = Z_II + self.transition_2[b](Z_II) - # Pairformer to mix + # ===== 步骤9-11: Pairformer混合单特征和配对特征 / Step 9-11: Pairformer to mix single and pair features ===== + # Algorithm 7: TransformerBlock for block in self.pairformer_stack: S_I, Z_II = block(S_I, Z_II)