Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion models/rfd3/src/rfd3/model/RFD3_diffusion_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ def forward(
**kwargs: Any,
) -> Dict[str, torch.Tensor]:
"""
扩散前向传播 (Algorithm 5: Diffusion forward pass with recycling)
Algorithm 5: Diffusion forward pass with recycling (./docs/rf3_si.pdf)
扩散前向传播
Diffusion forward pass with recycling

给定噪声坐标和编码特征,计算去噪后的位置。
Expand All @@ -317,6 +318,7 @@ def forward(
返回 / Returns:
outputs: 包含去噪坐标和序列预测的字典 / Dictionary with denoised coordinates and sequence predictions
"""
# Algorithm 5 - line 1: Collect inputs and create attention indices
# ===== 步骤1: 收集输入和创建注意力索引 / Step 1: Collect inputs and create attention indices =====
tok_idx = f["atom_to_token_map"] # [L] atom到token的映射 / Atom to token mapping
L = len(tok_idx) # 原子总数 / Total number of atoms
Expand All @@ -331,6 +333,7 @@ def forward(
n_attn_seq_neighbours=self.n_attn_seq_neighbours, # 序列局部邻居 (默认32)
)

# Algorithm 5 - line 2: Expand time tensors and mask fixed regions
# ===== 步骤2-3: 扩展时间张量并屏蔽固定区域 / Step 2-3: Expand time tensors and mask fixed regions =====
# t_L: [B, L] 每个atom的噪声水平,motif区域为0
t_L = t.unsqueeze(-1).expand(-1, L) * (
Expand All @@ -341,22 +344,28 @@ def forward(
~f["is_motif_token_with_fully_fixed_coord"]
).float().unsqueeze(0)

# Algorithm 5 - line 3: Scale positions (EDM preconditioning)
# ===== 步骤4: 坐标缩放 (EDM预条件化) / Step 4: Scale positions (EDM preconditioning) =====
R_L_uniform = self.scale_positions_in(X_noisy_L, t) # [B, L, 3] 均匀缩放用于distogram
R_noisy_L = self.scale_positions_in(X_noisy_L, t_L) # [B, L, 3] 每atom缩放用于特征

# Algorithm 5 - line 4: Pool initial representation to token level (Downcast)
# ===== 步骤5: 池化初始表示到token级 (Algorithm 9: Downcast) / Step 5: Pool initial representation to token level =====
A_I = self.process_a(R_noisy_L, tok_idx=tok_idx) # [B, I, c_token] 从坐标池化的token特征
S_I = self.downcast_c(C_L, S_I, tok_idx=tok_idx) # [I, c_s] 从atom特征池化的token特征

# Algorithm 5 - line 5: Add position and time embeddings
# ===== 步骤6-7: 添加批次级特征 (时间条件化) / Step 6-7: Add batch-wise features (time conditioning) =====
# Algorithm 5 步骤1: 坐标投影 + 初始化特征
Q_L = Q_L_init.unsqueeze(0) + self.process_r(R_noisy_L) # [B, L, c_atom]

# Algorithm 5 - line 6: Add time conditioning (Algorithm 17)
# Algorithm 17: 添加时间条件化特征
C_L = C_L.unsqueeze(0) + self.process_time_(t_L, i=0) # [B, L, c_atom] atom级
S_I = S_I.unsqueeze(0) + self.process_time_(t_I, i=1) # [B, I, c_s] token级
C_L = C_L + self.process_c(C_L) # [B, L, c_atom] 额外的MLP处理

# Algorithm 5 - line 7: Local-atom self-attention encoder
# ===== 步骤8: Local-Atom Self Attention (编码器) / Step 8: Local-Atom Self Attention (encoder) =====
# Algorithm 5 步骤8: 局部atom transformer
if chunked_pairwise_embedder is not None:
Expand All @@ -374,10 +383,12 @@ def forward(
# 标准模式:使用完整的P_LL / Standard mode: use full P_LL
Q_L = self.encoder(Q_L, C_L, P_LL, indices=f["attn_indices"])

# Algorithm 5 - line 8: Pool atom features to token level (Downcast)
# ===== 步骤9: 池化到token级准备transformer / Step 9: Pool to token level for transformer =====
# Algorithm 9: Downcast - 将atom特征池化为token特征
A_I = self.downcast_q(Q_L, A_I=A_I, S_I=S_I, tok_idx=tok_idx) # [B, I, c_token]

# Algorithm 5 - line 9: for r ∈ [1, ..., n_recycle] do (Recycling loop)
# ===== 步骤10-17: 循环处理 (Recycling Loop) / Step 10-17: Recycling loop =====
# Algorithm 5 步骤10-17: 带distogram循环的迭代细化
recycled_features = self.forward_with_recycle(
Expand All @@ -396,6 +407,8 @@ def forward(
initializer_outputs=initializer_outputs,
)

# Algorithm 5 - line 17: end for
# Algorithm 5 - line 18: return x̂0
# ===== 收集输出 / Collect outputs =====
outputs = {
"X_L": recycled_features["X_L"], # [B, L, 3] 去噪后的坐标 / Denoised positions
Expand Down Expand Up @@ -496,6 +509,7 @@ def process_(
返回 / Returns:
包含更新坐标、distogram和序列预测的字典 / Dictionary with updated coordinates, distogram, and sequence predictions
"""
# Algorithm 5 - line 10: DiffusionTokenEncoder (Algorithm 12)
# ===== 步骤12: DiffusionTokenEncoder - 嵌入噪声尺度和循环distogram =====
# Step 12: DiffusionTokenEncoder - Embed noise scale and recycled distogram
# Algorithm 12: 将当前坐标的distogram和前一次循环的distogram嵌入到Z_II中
Expand All @@ -509,6 +523,7 @@ def process_(
P_LL=P_LL, # [L, L, c_atompair] Atom配对特征
)

# Algorithm 5 - line 11: DiffusionTransformer (Algorithm 6: LocalTokenTransformer)
# ===== 步骤13: DiffusionTransformer - Token级稀疏注意力 =====
# Step 13: DiffusionTransformer - Token-level sparse attention
# Algorithm 6: LocalTokenTransformer with SL2 sparse attention
Expand All @@ -526,6 +541,7 @@ def process_(
full=not (os.environ.get("RFD3_LOW_MEMORY_MODE", None) == "1"), # 低内存模式标志
)

# Algorithm 5 - line 12: Decoder (CompactStreamingDecoder with Upcast)
# ===== 步骤14: Decoder - 上投影并解码为结构 =====
# Step 14: Decoder - Up-projection and decode to structure
# CompactStreamingDecoder: Token -> Atom特征,包含Algorithm 10 (Upcast)
Expand Down Expand Up @@ -558,17 +574,23 @@ def process_(
indices=f["attn_indices"],
)

# Algorithm 5 - line 13: Project atom features to coordinate update
# ===== 步骤15-16: 坐标更新和去预条件化 =====
# Step 15-16: Coordinate update and de-preconditioning
# Algorithm 5 步骤15: 投影atom特征到3D坐标更新
R_update_L = self.to_r_update(Q_L) # [B, L, 3] 预测的坐标更新

# Algorithm 5 - line 14: De-preconditioning (EDM scaling inverse)
# 步骤16: EDM去预条件化,得到去噪后的坐标
X_out_L = self.scale_positions_out(R_update_L, X_noisy_L, t_L) # [B, L, 3]

# Algorithm 5 - line 15: Compute sequence logits and distogram for recycling
# ===== 辅助输出:序列预测和distogram =====
# Auxiliary outputs: sequence prediction and distogram
# 序列预测头:从token特征预测残基类型
sequence_logits_I, sequence_indices_I = self.sequence_head(A_I=A_I)

# Algorithm 5 - line 16: Bucketize distogram for next recycling iteration
# 计算distogram用于下一次循环的self-conditioning
# 使用detach()防止梯度回传到前一次循环
D_II_self = self.bucketize_fn(X_out_L[..., f["is_ca"], :].detach()) # [B, I, I, n_bins]
Expand Down
63 changes: 53 additions & 10 deletions models/rfd3/src/rfd3/model/inference_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,19 @@ def sample_diffusion_like_af3(
ref_initializer_outputs: dict[str, Any] | None,
f_ref: dict[str, Any] | None,
) -> dict[str, Any]:
"""
Algorithm 1: Inference diffusion loop (./docs/rf3_si.pdf)

Main inference loop implementing the denoising diffusion process.
Corresponds to Algorithm 1 in RFdiffusion3 SI.

Note: Token and atom embedding (Algorithm 1 - lines 1-2) are performed
outside this function and passed via initializer_outputs.
"""
# Motif setup to recenter the motif at every step
is_motif_atom_with_fixed_coord = f["is_motif_atom_with_fixed_coord"]

# Book-keeping
# Book-keeping: Construct noise schedule
noise_schedule = self._construct_inference_noise_schedule(
device=coord_atom_lvl_to_be_noised.device,
partial_t=f.get("partial_t", None),
Expand All @@ -159,6 +168,7 @@ def sample_diffusion_like_af3(
L = f["ref_element"].shape[0]
D = diffusion_batch_size

# Initialize structure at highest noise level
X_L = self._get_initial_structure(
c0=noise_schedule[0],
D=D,
Expand All @@ -182,6 +192,7 @@ def sample_diffusion_like_af3(

threshold_step = (len(noise_schedule) - 1) * self.fraction_of_steps_to_fix_motif

# Algorithm 1 - line 3: for στ ∈ [σ1, ..., σT] do
for step_num, (c_t_minus_1, c_t) in enumerate(
zip(noise_schedule, noise_schedule[1:])
):
Expand All @@ -203,14 +214,16 @@ def sample_diffusion_like_af3(
s_trans=self.s_trans if step_num >= threshold_step else 0.0,
)

# Update gamma & step scale
# Algorithm 1 - line 4: Modulation of ODE/SDE
# γ ← γ0 if στ > γmin else 0
gamma = self.gamma_0 if c_t > self.gamma_min else 0
step_scale = self.step_scale

# Compute the value of t_hat
# Algorithm 1 - line 5: σ̂ ← στ-1 (γ + 1)
t_hat = c_t_minus_1 * (gamma + 1)

# Noise the coordinates with scaled Gaussian noise
# Algorithm 1 - line 6: Noise injection to diffused components
# ϵl ← λ√(σ̂² - σ²τ-1) · N(0, I3) · f^is_diffused
epsilon_L = (
self.noise_scale
* torch.sqrt(torch.square(t_hat) - torch.square(c_t_minus_1))
Expand All @@ -219,9 +232,11 @@ def sample_diffusion_like_af3(
epsilon_L[..., is_motif_atom_with_fixed_coord, :] = (
0 # No noise injection for fixed atoms
)

# Algorithm 1 - line 7: x^noisy_l ← xl + ϵl
X_noisy_L = X_L + epsilon_L

# Denoise the coordinates
# Algorithm 1 - line 8: {x̂0} ← DiffusionModule({x^noisy_l}, σ̂, {f*}, qinit_l, cl, plm, sinit_i, zinit_ij)
# Handle chunked mode vs standard mode
if "chunked_pairwise_embedder" in initializer_outputs:
# Chunked mode: explicitly provide P_LL=None
Expand Down Expand Up @@ -260,6 +275,8 @@ def sample_diffusion_like_af3(
delta_L = (
X_noisy_L - X_denoised_L
) / t_hat # gradient of x wrt. t at x_t_hat

# Algorithm 1 - line 9: dσ ← στ - σ̂
d_t = c_t - t_hat

if self.use_classifier_free_guidance and (
Expand Down Expand Up @@ -304,6 +321,7 @@ def sample_diffusion_like_af3(
) # shape (D, L,)
sequence_entropy_traj.append(seq_entropy)

# Algorithm 1 - line 10: xl ← x^noisy_l + η · dσ · (xl - x̂0)/σ̂
# Update the coordinates, scaled by the step size
X_L = X_noisy_L + step_scale * d_t * delta_L

Expand All @@ -315,6 +333,8 @@ def sample_diffusion_like_af3(
X_denoised_L_traj.append(X_denoised_L)
t_hats.append(t_hat)

# Algorithm 1 - line 11: end for

if torch.any(is_motif_atom_with_fixed_coord) and self.allow_realignment:
# Insert the gt motif at the end
X_L, _ = centre_random_augment_around_motif(
Expand All @@ -331,6 +351,7 @@ def sample_diffusion_like_af3(
X_exists_L=is_motif_atom_with_fixed_coord,
)

# Algorithm 1 - line 12: return {xl}
return dict(
X_L=X_L, # (D, L, 3)
X_noisy_L_traj=X_noisy_L_traj, # list[Tensor[D, L, 3]]
Expand All @@ -344,8 +365,10 @@ def sample_diffusion_like_af3(

class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
"""
Algorithm 2: Symmetric inference diffusion loop (./docs/rf3_si.pdf)

This class is a wrapper around the SampleDiffusionWithMotif class.
It is used to sample diffusion with symmetry.
It is used to sample diffusion with symmetry for homo-oligomers.
"""

def __init__(self, sym_step_frac: float = 0.9, **kwargs):
Expand Down Expand Up @@ -382,6 +405,12 @@ def sample_diffusion_like_af3(
f_ref: dict[str, Any] | None,
**_,
) -> dict[str, Any]:
"""
Algorithm 2: Symmetric inference diffusion loop (./docs/rf3_si.pdf)

Symmetric variant of the inference loop for homo-oligomer design.
Applies symmetry operations to maintain symmetry throughout diffusion.
"""
# Motif setup to recenter the motif at every step
is_motif_atom_with_fixed_coord = f["is_motif_atom_with_fixed_coord"]
# Book-keeping
Expand Down Expand Up @@ -413,6 +442,8 @@ def sample_diffusion_like_af3(

ranked_logger.info(f"gamma_min_sym: {gamma_min_sym}")
ranked_logger.info(f"gamma_min: {self.gamma_min}")

# Algorithm 2 - line 3: for στ ∈ [σ1, ..., σT] do
for step_num, (c_t_minus_1, c_t) in enumerate(
zip(noise_schedule, noise_schedule[1:])
):
Expand All @@ -428,14 +459,16 @@ def sample_diffusion_like_af3(
is_motif_atom_with_fixed_coord,
)

# Update gamma & step scale
# Algorithm 2 - line 4: Modulation of ODE/SDE
# γ ← γ0 if στ > γmin else 0
gamma = self.gamma_0 if c_t > self.gamma_min else 0
step_scale = self.step_scale

# Compute the value of t_hat
# Algorithm 2 - line 5: σ̂ ← στ-1 (γ + 1)
t_hat = c_t_minus_1 * (gamma + 1)

# Noise the coordinates with scaled Gaussian noise
# Algorithm 2 - line 6: Noise injection to diffused components
# ϵl ← λ√(σ̂² - σ²τ-1) · N(0, I3) · f^is_diffused
epsilon_L = (
self.noise_scale
* torch.sqrt(torch.square(t_hat) - torch.square(c_t_minus_1))
Expand All @@ -445,9 +478,11 @@ def sample_diffusion_like_af3(
0 # No noise injection for fixed atoms
)

# Algorithm 2 - line 7: x^noisy_l ← xl + ϵl
# NOTE: no symmetry applied to the noisy structure
X_noisy_L = X_L + epsilon_L

# Algorithm 2 - line 8: {x̂0} ← DiffusionModule({x^noisy_l}, σ̂, {f*}, qinit_l, cl, plm, sinit_i, zinit_ij)
# Denoise the coordinates
# Handle chunked mode vs standard mode (same as default sampler)
if "chunked_pairwise_embedder" in initializer_outputs:
Expand Down Expand Up @@ -480,7 +515,9 @@ def sample_diffusion_like_af3(
f=f,
**initializer_outputs,
)
# apply symmetry to X_denoised_L

# Algorithm 2 - line 9: Apply symmetry operation (key difference from Algorithm 1)
# x̂0 ← ApplySymmetry(x̂0) for στ > σsym
if "X_L" in outs and c_t > gamma_min_sym:
# outs["original_X_L"] = outs["X_L"].clone()
outs["X_L"] = self.apply_symmetry_to_X_L(outs["X_L"], f)
Expand All @@ -491,6 +528,8 @@ def sample_diffusion_like_af3(
delta_L = (
X_noisy_L - X_denoised_L
) / t_hat # gradient of x wrt. t at x_t_hat

# Algorithm 2 - line 10: dσ ← στ - σ̂
d_t = c_t - t_hat

# NOTE: no classifier-free guidance for symmetry
Expand All @@ -505,6 +544,7 @@ def sample_diffusion_like_af3(
) # shape (D, L,)
sequence_entropy_traj.append(seq_entropy)

# Algorithm 2 - line 11: xl ← x^noisy_l + η · dσ · (xl - x̂0)/σ̂
# Update the coordinates, scaled by the step size
# delta_L should be symmetric
X_L = X_noisy_L + step_scale * d_t * delta_L
Expand All @@ -517,6 +557,8 @@ def sample_diffusion_like_af3(
X_denoised_L_traj.append(X_denoised_L)
t_hats.append(t_hat)

# Algorithm 2 - line 12: end for

if torch.any(is_motif_atom_with_fixed_coord) and self.allow_realignment:
# Insert the gt motif at the end
X_L, R = centre_random_augment_around_motif(
Expand All @@ -536,6 +578,7 @@ def sample_diffusion_like_af3(
X_exists_L=is_motif_atom_with_fixed_coord,
)

# Algorithm 2 - line 13: return {xl}
return dict(
X_L=X_L, # (D, L, 3)
X_noisy_L_traj=X_noisy_L_traj, # list[Tensor[D, L, 3]]
Expand Down
Loading
Loading