From 90b328c0a0cfe5d3646c494b4126dab892238950 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 25 Dec 2025 15:36:29 +0000 Subject: [PATCH] Add detailed algorithm citations from RFdiffusion3 SI Add line-by-line citations mapping code to algorithms in ./docs/rf3_si.pdf: - Algorithm 1 (inference_sampler.py): Main inference diffusion loop - Lines 1-12 annotated for the denoising diffusion process - Includes gamma modulation, noise injection, and coordinate updates - Algorithm 2 (inference_sampler.py): Symmetric inference diffusion loop - Lines 1-13 annotated for homo-oligomer design - Key difference: symmetry operations applied during diffusion - Algorithm 3 (encoders.py): Token initializer - Lines 1-12 annotated for token-level feature generation - Includes 1D embedding, pair features, and Pairformer processing - Algorithm 4 (encoders.py): Atom initializer - Lines 1-8 annotated for atom-level feature generation - Includes motif/reference embeddings and MLP processing - Algorithm 5 (RFD3_diffusion_module.py): Diffusion forward pass with recycling - Lines 1-18 annotated for the main diffusion module - Includes encoder, transformer, decoder, and recycling loop All annotations follow the format "Algorithm X - line Y" to match the user's requested style for easy cross-referencing with the paper. --- .../src/rfd3/model/RFD3_diffusion_module.py | 24 ++++++- .../rfd3/src/rfd3/model/inference_sampler.py | 63 ++++++++++++++++--- models/rfd3/src/rfd3/model/layers/encoders.py | 32 +++++++++- 3 files changed, 106 insertions(+), 13 deletions(-) diff --git a/models/rfd3/src/rfd3/model/RFD3_diffusion_module.py b/models/rfd3/src/rfd3/model/RFD3_diffusion_module.py index 0ee0ce69..b9bd8230 100644 --- a/models/rfd3/src/rfd3/model/RFD3_diffusion_module.py +++ b/models/rfd3/src/rfd3/model/RFD3_diffusion_module.py @@ -297,7 +297,8 @@ def forward( **kwargs: Any, ) -> Dict[str, torch.Tensor]: """ - 扩散前向传播 (Algorithm 5: Diffusion forward pass with recycling) + Algorithm 5: Diffusion forward pass with recycling (./docs/rf3_si.pdf) + 扩散前向传播 Diffusion forward pass with recycling 给定噪声坐标和编码特征,计算去噪后的位置。 @@ -317,6 +318,7 @@ def forward( 返回 / Returns: outputs: 包含去噪坐标和序列预测的字典 / Dictionary with denoised coordinates and sequence predictions """ + # Algorithm 5 - line 1: Collect inputs and create attention indices # ===== 步骤1: 收集输入和创建注意力索引 / Step 1: Collect inputs and create attention indices ===== tok_idx = f["atom_to_token_map"] # [L] atom到token的映射 / Atom to token mapping L = len(tok_idx) # 原子总数 / Total number of atoms @@ -331,6 +333,7 @@ def forward( n_attn_seq_neighbours=self.n_attn_seq_neighbours, # 序列局部邻居 (默认32) ) + # Algorithm 5 - line 2: Expand time tensors and mask fixed regions # ===== 步骤2-3: 扩展时间张量并屏蔽固定区域 / Step 2-3: Expand time tensors and mask fixed regions ===== # t_L: [B, L] 每个atom的噪声水平,motif区域为0 t_L = t.unsqueeze(-1).expand(-1, L) * ( @@ -341,22 +344,28 @@ def forward( ~f["is_motif_token_with_fully_fixed_coord"] ).float().unsqueeze(0) + # Algorithm 5 - line 3: Scale positions (EDM preconditioning) # ===== 步骤4: 坐标缩放 (EDM预条件化) / Step 4: Scale positions (EDM preconditioning) ===== R_L_uniform = self.scale_positions_in(X_noisy_L, t) # [B, L, 3] 均匀缩放用于distogram R_noisy_L = self.scale_positions_in(X_noisy_L, t_L) # [B, L, 3] 每atom缩放用于特征 + # Algorithm 5 - line 4: Pool initial representation to token level (Downcast) # ===== 步骤5: 池化初始表示到token级 (Algorithm 9: Downcast) / Step 5: Pool initial representation to token level ===== A_I = self.process_a(R_noisy_L, tok_idx=tok_idx) # [B, I, c_token] 从坐标池化的token特征 S_I = self.downcast_c(C_L, S_I, tok_idx=tok_idx) # [I, c_s] 从atom特征池化的token特征 + # Algorithm 5 - line 5: Add position and time embeddings # ===== 步骤6-7: 添加批次级特征 (时间条件化) / Step 6-7: Add batch-wise features (time conditioning) ===== # Algorithm 5 步骤1: 坐标投影 + 初始化特征 Q_L = Q_L_init.unsqueeze(0) + self.process_r(R_noisy_L) # [B, L, c_atom] + + # Algorithm 5 - line 6: Add time conditioning (Algorithm 17) # Algorithm 17: 添加时间条件化特征 C_L = C_L.unsqueeze(0) + self.process_time_(t_L, i=0) # [B, L, c_atom] atom级 S_I = S_I.unsqueeze(0) + self.process_time_(t_I, i=1) # [B, I, c_s] token级 C_L = C_L + self.process_c(C_L) # [B, L, c_atom] 额外的MLP处理 + # Algorithm 5 - line 7: Local-atom self-attention encoder # ===== 步骤8: Local-Atom Self Attention (编码器) / Step 8: Local-Atom Self Attention (encoder) ===== # Algorithm 5 步骤8: 局部atom transformer if chunked_pairwise_embedder is not None: @@ -374,10 +383,12 @@ def forward( # 标准模式:使用完整的P_LL / Standard mode: use full P_LL Q_L = self.encoder(Q_L, C_L, P_LL, indices=f["attn_indices"]) + # Algorithm 5 - line 8: Pool atom features to token level (Downcast) # ===== 步骤9: 池化到token级准备transformer / Step 9: Pool to token level for transformer ===== # Algorithm 9: Downcast - 将atom特征池化为token特征 A_I = self.downcast_q(Q_L, A_I=A_I, S_I=S_I, tok_idx=tok_idx) # [B, I, c_token] + # Algorithm 5 - line 9: for r ∈ [1, ..., n_recycle] do (Recycling loop) # ===== 步骤10-17: 循环处理 (Recycling Loop) / Step 10-17: Recycling loop ===== # Algorithm 5 步骤10-17: 带distogram循环的迭代细化 recycled_features = self.forward_with_recycle( @@ -396,6 +407,8 @@ def forward( initializer_outputs=initializer_outputs, ) + # Algorithm 5 - line 17: end for + # Algorithm 5 - line 18: return x̂0 # ===== 收集输出 / Collect outputs ===== outputs = { "X_L": recycled_features["X_L"], # [B, L, 3] 去噪后的坐标 / Denoised positions @@ -496,6 +509,7 @@ def process_( 返回 / Returns: 包含更新坐标、distogram和序列预测的字典 / Dictionary with updated coordinates, distogram, and sequence predictions """ + # Algorithm 5 - line 10: DiffusionTokenEncoder (Algorithm 12) # ===== 步骤12: DiffusionTokenEncoder - 嵌入噪声尺度和循环distogram ===== # Step 12: DiffusionTokenEncoder - Embed noise scale and recycled distogram # Algorithm 12: 将当前坐标的distogram和前一次循环的distogram嵌入到Z_II中 @@ -509,6 +523,7 @@ def process_( P_LL=P_LL, # [L, L, c_atompair] Atom配对特征 ) + # Algorithm 5 - line 11: DiffusionTransformer (Algorithm 6: LocalTokenTransformer) # ===== 步骤13: DiffusionTransformer - Token级稀疏注意力 ===== # Step 13: DiffusionTransformer - Token-level sparse attention # Algorithm 6: LocalTokenTransformer with SL2 sparse attention @@ -526,6 +541,7 @@ def process_( full=not (os.environ.get("RFD3_LOW_MEMORY_MODE", None) == "1"), # 低内存模式标志 ) + # Algorithm 5 - line 12: Decoder (CompactStreamingDecoder with Upcast) # ===== 步骤14: Decoder - 上投影并解码为结构 ===== # Step 14: Decoder - Up-projection and decode to structure # CompactStreamingDecoder: Token -> Atom特征,包含Algorithm 10 (Upcast) @@ -558,17 +574,23 @@ def process_( indices=f["attn_indices"], ) + # Algorithm 5 - line 13: Project atom features to coordinate update # ===== 步骤15-16: 坐标更新和去预条件化 ===== # Step 15-16: Coordinate update and de-preconditioning # Algorithm 5 步骤15: 投影atom特征到3D坐标更新 R_update_L = self.to_r_update(Q_L) # [B, L, 3] 预测的坐标更新 + + # Algorithm 5 - line 14: De-preconditioning (EDM scaling inverse) # 步骤16: EDM去预条件化,得到去噪后的坐标 X_out_L = self.scale_positions_out(R_update_L, X_noisy_L, t_L) # [B, L, 3] + # Algorithm 5 - line 15: Compute sequence logits and distogram for recycling # ===== 辅助输出:序列预测和distogram ===== # Auxiliary outputs: sequence prediction and distogram # 序列预测头:从token特征预测残基类型 sequence_logits_I, sequence_indices_I = self.sequence_head(A_I=A_I) + + # Algorithm 5 - line 16: Bucketize distogram for next recycling iteration # 计算distogram用于下一次循环的self-conditioning # 使用detach()防止梯度回传到前一次循环 D_II_self = self.bucketize_fn(X_out_L[..., f["is_ca"], :].detach()) # [B, I, I, n_bins] diff --git a/models/rfd3/src/rfd3/model/inference_sampler.py b/models/rfd3/src/rfd3/model/inference_sampler.py index 6ed957e8..b0cbdcdb 100644 --- a/models/rfd3/src/rfd3/model/inference_sampler.py +++ b/models/rfd3/src/rfd3/model/inference_sampler.py @@ -147,10 +147,19 @@ def sample_diffusion_like_af3( ref_initializer_outputs: dict[str, Any] | None, f_ref: dict[str, Any] | None, ) -> dict[str, Any]: + """ + Algorithm 1: Inference diffusion loop (./docs/rf3_si.pdf) + + Main inference loop implementing the denoising diffusion process. + Corresponds to Algorithm 1 in RFdiffusion3 SI. + + Note: Token and atom embedding (Algorithm 1 - lines 1-2) are performed + outside this function and passed via initializer_outputs. + """ # Motif setup to recenter the motif at every step is_motif_atom_with_fixed_coord = f["is_motif_atom_with_fixed_coord"] - # Book-keeping + # Book-keeping: Construct noise schedule noise_schedule = self._construct_inference_noise_schedule( device=coord_atom_lvl_to_be_noised.device, partial_t=f.get("partial_t", None), @@ -159,6 +168,7 @@ def sample_diffusion_like_af3( L = f["ref_element"].shape[0] D = diffusion_batch_size + # Initialize structure at highest noise level X_L = self._get_initial_structure( c0=noise_schedule[0], D=D, @@ -182,6 +192,7 @@ def sample_diffusion_like_af3( threshold_step = (len(noise_schedule) - 1) * self.fraction_of_steps_to_fix_motif + # Algorithm 1 - line 3: for στ ∈ [σ1, ..., σT] do for step_num, (c_t_minus_1, c_t) in enumerate( zip(noise_schedule, noise_schedule[1:]) ): @@ -203,14 +214,16 @@ def sample_diffusion_like_af3( s_trans=self.s_trans if step_num >= threshold_step else 0.0, ) - # Update gamma & step scale + # Algorithm 1 - line 4: Modulation of ODE/SDE + # γ ← γ0 if στ > γmin else 0 gamma = self.gamma_0 if c_t > self.gamma_min else 0 step_scale = self.step_scale - # Compute the value of t_hat + # Algorithm 1 - line 5: σ̂ ← στ-1 (γ + 1) t_hat = c_t_minus_1 * (gamma + 1) - # Noise the coordinates with scaled Gaussian noise + # Algorithm 1 - line 6: Noise injection to diffused components + # ϵl ← λ√(σ̂² - σ²τ-1) · N(0, I3) · f^is_diffused epsilon_L = ( self.noise_scale * torch.sqrt(torch.square(t_hat) - torch.square(c_t_minus_1)) @@ -219,9 +232,11 @@ def sample_diffusion_like_af3( epsilon_L[..., is_motif_atom_with_fixed_coord, :] = ( 0 # No noise injection for fixed atoms ) + + # Algorithm 1 - line 7: x^noisy_l ← xl + ϵl X_noisy_L = X_L + epsilon_L - # Denoise the coordinates + # Algorithm 1 - line 8: {x̂0} ← DiffusionModule({x^noisy_l}, σ̂, {f*}, qinit_l, cl, plm, sinit_i, zinit_ij) # Handle chunked mode vs standard mode if "chunked_pairwise_embedder" in initializer_outputs: # Chunked mode: explicitly provide P_LL=None @@ -260,6 +275,8 @@ def sample_diffusion_like_af3( delta_L = ( X_noisy_L - X_denoised_L ) / t_hat # gradient of x wrt. t at x_t_hat + + # Algorithm 1 - line 9: dσ ← στ - σ̂ d_t = c_t - t_hat if self.use_classifier_free_guidance and ( @@ -304,6 +321,7 @@ def sample_diffusion_like_af3( ) # shape (D, L,) sequence_entropy_traj.append(seq_entropy) + # Algorithm 1 - line 10: xl ← x^noisy_l + η · dσ · (xl - x̂0)/σ̂ # Update the coordinates, scaled by the step size X_L = X_noisy_L + step_scale * d_t * delta_L @@ -315,6 +333,8 @@ def sample_diffusion_like_af3( X_denoised_L_traj.append(X_denoised_L) t_hats.append(t_hat) + # Algorithm 1 - line 11: end for + if torch.any(is_motif_atom_with_fixed_coord) and self.allow_realignment: # Insert the gt motif at the end X_L, _ = centre_random_augment_around_motif( @@ -331,6 +351,7 @@ def sample_diffusion_like_af3( X_exists_L=is_motif_atom_with_fixed_coord, ) + # Algorithm 1 - line 12: return {xl} return dict( X_L=X_L, # (D, L, 3) X_noisy_L_traj=X_noisy_L_traj, # list[Tensor[D, L, 3]] @@ -344,8 +365,10 @@ def sample_diffusion_like_af3( class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif): """ + Algorithm 2: Symmetric inference diffusion loop (./docs/rf3_si.pdf) + This class is a wrapper around the SampleDiffusionWithMotif class. - It is used to sample diffusion with symmetry. + It is used to sample diffusion with symmetry for homo-oligomers. """ def __init__(self, sym_step_frac: float = 0.9, **kwargs): @@ -382,6 +405,12 @@ def sample_diffusion_like_af3( f_ref: dict[str, Any] | None, **_, ) -> dict[str, Any]: + """ + Algorithm 2: Symmetric inference diffusion loop (./docs/rf3_si.pdf) + + Symmetric variant of the inference loop for homo-oligomer design. + Applies symmetry operations to maintain symmetry throughout diffusion. + """ # Motif setup to recenter the motif at every step is_motif_atom_with_fixed_coord = f["is_motif_atom_with_fixed_coord"] # Book-keeping @@ -413,6 +442,8 @@ def sample_diffusion_like_af3( ranked_logger.info(f"gamma_min_sym: {gamma_min_sym}") ranked_logger.info(f"gamma_min: {self.gamma_min}") + + # Algorithm 2 - line 3: for στ ∈ [σ1, ..., σT] do for step_num, (c_t_minus_1, c_t) in enumerate( zip(noise_schedule, noise_schedule[1:]) ): @@ -428,14 +459,16 @@ def sample_diffusion_like_af3( is_motif_atom_with_fixed_coord, ) - # Update gamma & step scale + # Algorithm 2 - line 4: Modulation of ODE/SDE + # γ ← γ0 if στ > γmin else 0 gamma = self.gamma_0 if c_t > self.gamma_min else 0 step_scale = self.step_scale - # Compute the value of t_hat + # Algorithm 2 - line 5: σ̂ ← στ-1 (γ + 1) t_hat = c_t_minus_1 * (gamma + 1) - # Noise the coordinates with scaled Gaussian noise + # Algorithm 2 - line 6: Noise injection to diffused components + # ϵl ← λ√(σ̂² - σ²τ-1) · N(0, I3) · f^is_diffused epsilon_L = ( self.noise_scale * torch.sqrt(torch.square(t_hat) - torch.square(c_t_minus_1)) @@ -445,9 +478,11 @@ def sample_diffusion_like_af3( 0 # No noise injection for fixed atoms ) + # Algorithm 2 - line 7: x^noisy_l ← xl + ϵl # NOTE: no symmetry applied to the noisy structure X_noisy_L = X_L + epsilon_L + # Algorithm 2 - line 8: {x̂0} ← DiffusionModule({x^noisy_l}, σ̂, {f*}, qinit_l, cl, plm, sinit_i, zinit_ij) # Denoise the coordinates # Handle chunked mode vs standard mode (same as default sampler) if "chunked_pairwise_embedder" in initializer_outputs: @@ -480,7 +515,9 @@ def sample_diffusion_like_af3( f=f, **initializer_outputs, ) - # apply symmetry to X_denoised_L + + # Algorithm 2 - line 9: Apply symmetry operation (key difference from Algorithm 1) + # x̂0 ← ApplySymmetry(x̂0) for στ > σsym if "X_L" in outs and c_t > gamma_min_sym: # outs["original_X_L"] = outs["X_L"].clone() outs["X_L"] = self.apply_symmetry_to_X_L(outs["X_L"], f) @@ -491,6 +528,8 @@ def sample_diffusion_like_af3( delta_L = ( X_noisy_L - X_denoised_L ) / t_hat # gradient of x wrt. t at x_t_hat + + # Algorithm 2 - line 10: dσ ← στ - σ̂ d_t = c_t - t_hat # NOTE: no classifier-free guidance for symmetry @@ -505,6 +544,7 @@ def sample_diffusion_like_af3( ) # shape (D, L,) sequence_entropy_traj.append(seq_entropy) + # Algorithm 2 - line 11: xl ← x^noisy_l + η · dσ · (xl - x̂0)/σ̂ # Update the coordinates, scaled by the step size # delta_L should be symmetric X_L = X_noisy_L + step_scale * d_t * delta_L @@ -517,6 +557,8 @@ def sample_diffusion_like_af3( X_denoised_L_traj.append(X_denoised_L) t_hats.append(t_hat) + # Algorithm 2 - line 12: end for + if torch.any(is_motif_atom_with_fixed_coord) and self.allow_realignment: # Insert the gt motif at the end X_L, R = centre_random_augment_around_motif( @@ -536,6 +578,7 @@ def sample_diffusion_like_af3( X_exists_L=is_motif_atom_with_fixed_coord, ) + # Algorithm 2 - line 13: return {xl} return dict( X_L=X_L, # (D, L, 3) X_noisy_L_traj=X_noisy_L_traj, # list[Tensor[D, L, 3]] diff --git a/models/rfd3/src/rfd3/model/layers/encoders.py b/models/rfd3/src/rfd3/model/layers/encoders.py index a926e834..1749beb1 100644 --- a/models/rfd3/src/rfd3/model/layers/encoders.py +++ b/models/rfd3/src/rfd3/model/layers/encoders.py @@ -239,35 +239,47 @@ def forward(self, f): def init_tokens(): """ - Algorithm 3: Token初始化器 / Token initializer + Algorithm 3: Token初始化器 / Token initializer (./docs/rf3_si.pdf) 生成初始的token级单特征(S_I)和配对特征(Z_II) + Generate initial token-level single features (S_I) and pair features (Z_II) """ + # Algorithm 3 - line 1: Embed token 1D features # ===== 步骤1-4: 嵌入1D特征 / Step 1-4: Embed 1D features ===== # Algorithm 15: 嵌入token级1D特征(残基类型等) S_I = self.token_1d_embedder(f, I) # [I, c_s] + + # Algorithm 3 - line 2: Transition layer for feature mixing # 步骤5: Transition层混合特征 S_I = S_I + self.transition_post_token(S_I) + # Algorithm 3 - line 3: Embed atom 1D features and downcast to token level # 嵌入atom级1D特征并下采样到token级 # Algorithm 9: Downcast - 从atom池化到token S_I = self.downcast_atom( Q_L=self.atom_1d_embedder_1(f, L), A_I=S_I, tok_idx=tok_idx ) + + # Algorithm 3 - line 4: Transition layer and normalization S_I = S_I + self.transition_post_atom(S_I) S_I = self.process_s_init(S_I) # [I, c_s] + # Algorithm 3 - line 5: Initialize pair features Z_II with outer sum # ===== 步骤6-8: 初始化配对特征Z_II / Step 6-8: Initialize pair features Z_II ===== # 步骤6-7: 从单特征生成配对特征 (outer sum: S_I_i + S_I_j) Z_init_II = self.to_z_init_i(S_I).unsqueeze(-3) + self.to_z_init_j( S_I ).unsqueeze(-2) # [I, I, c_z] + + # Algorithm 3 - line 6: Add relative position encoding # 步骤8: 添加相对位置编码 Z_init_II = Z_init_II + self.relative_position_encoding(f) + # Algorithm 3 - line 7: Add token bond information # 添加token间的化学键信息 Z_init_II = Z_init_II + self.process_token_bonds( f["token_bonds"].unsqueeze(-1).float() ) + # Algorithm 3 - line 8: Embed reference coordinates # ===== 步骤9: 嵌入配体的参考坐标 / Step 9: Embed reference coordinates of ligands ===== # Algorithm 14: PositionPairDistEmbedder token_id = f["ref_space_uid"][f["is_ca"]] # C-alpha的token ID @@ -279,11 +291,13 @@ def init_tokens(): f["ref_pos"][f["is_ca"]], valid_mask ) + # Algorithm 3 - line 9: Pairformer transformer stack # ===== 步骤10-12: Pairformer transformer栈 / Step 10-12: Pairformer transformer stack ===== # Algorithm 7: TransformerBlock - 使用全注意力处理配对特征 for block in self.transformer_stack: S_I, Z_init_II = block(S_I, Z_init_II) + # Algorithm 3 - line 10: Concatenate second relative position encoding and process # ===== 步骤13-19: 配对特征后处理 / Step 13-19: Post-process pair features ===== # 拼接第二个相对位置编码并混合 Z_init_II = torch.cat( @@ -294,22 +308,28 @@ def init_tokens(): dim=-1, ) # [I, I, c_z * 2] Z_init_II = self.process_z_init(Z_init_II) # [I, I, c_z] + + # Algorithm 3 - line 11: Apply transition layers for final mixing # 两个Transition层进一步混合 for b in range(2): Z_init_II = Z_init_II + self.transition_1[b](Z_init_II) + # Algorithm 3 - line 12: return S_I, Z_II return {"S_init_I": S_I, "Z_init_II": Z_init_II} @activation_checkpointing def init_atoms(S_init_I, Z_init_II): """ - Algorithm 4: Atom初始化器 / Atom initializer + Algorithm 4: Atom初始化器 / Atom initializer (./docs/rf3_si.pdf) 生成atom级特征: Q_L_init, C_L, P_LL + Generate atom-level features: Q_L_init, C_L, P_LL """ + # Algorithm 4 - line 1: Embed atom-level 1D features # ===== 步骤1: 嵌入atom级1D特征 / Step 1: Embed atom-level 1D features ===== # Algorithm 15: OneDFeatureEmbedder for atom features Q_L_init = self.atom_1d_embedder_2(f, L) # [L, c_atom] + # Algorithm 4 - line 2: Project from token features and add to atom features # ===== 步骤2: 从token特征投影 / Step 2: Project from token features ===== C_L = Q_L_init + self.process_s_trunk(S_init_I)[..., tok_idx, :] # [L, c_atom] @@ -325,6 +345,7 @@ def init_atoms(S_init_I, Z_init_II): else: # ===== 标准模式:完整P_LL计算 / Standard mode: full P_LL computation ===== + # Algorithm 4 - line 3: Embed motif coordinates # ===== 步骤3: 嵌入Motif坐标 / Step 3: Embed motif coordinates ===== # Algorithm 13: SinusoidalDistEmbed - 正弦距离嵌入 # 仅对固定坐标的motif原子对计算距离 @@ -336,6 +357,7 @@ def init_atoms(S_init_I, Z_init_II): f["motif_pos"], valid_mask ) # [L, L, c_atompair] + # Algorithm 4 - line 4: Embed reference positions # ===== 步骤4: 嵌入参考位置 / Step 4: Embed reference positions ===== # Algorithm 14: PositionPairDistEmbedder # 仅对同一token内的原子对计算距离 @@ -350,18 +372,23 @@ def init_atoms(S_init_I, Z_init_II): valid_mask = atoms_in_same_token & atoms_has_seq P_LL = P_LL + self.ref_pos_embedder(f["ref_pos"], valid_mask) + # Algorithm 4 - line 5: Add outer sum of single atom features # ===== 步骤5-7: Atom配对特征的MLP处理 / Step 5-7: MLP processing for atom pairwise features ===== # 步骤5: 添加单atom特征的外积 (outer sum: C_L_l + C_L_m) P_LL = P_LL + ( self.process_single_l(C_L).unsqueeze(-2) + self.process_single_m(C_L).unsqueeze(-3) ) + + # Algorithm 4 - line 6: Add projected token pair features # 步骤6: 添加从token配对特征投影的信息 # 将Z_II [I, I, c_z] 通过tok_idx索引到atom级 [L, L, c_z] P_LL = ( P_LL + self.process_z(Z_init_II)[..., tok_idx, :, :][..., tok_idx, :] ) + + # Algorithm 4 - line 7: Deep MLP to mix all pairwise features # 步骤7: 深度MLP混合所有配对特征 P_LL = P_LL + self.pair_mlp(P_LL) P_LL = P_LL.contiguous() # [L, L, c_atompair] @@ -383,6 +410,7 @@ def init_atoms(S_init_I, Z_init_II): C_L.unsqueeze(0), None, P_LL, indices=None, f=f, X_L=None ).squeeze(0) + # Algorithm 4 - line 8: return Q_L, C_L, P_LL, S_I, Z_II return { "Q_L_init": Q_L_init, # [L, c_atom] 初始atom查询特征 "C_L": C_L, # [L, c_atom] Atom条件特征