Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions models/rf3/src/rf3/trainers/rf3.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,6 @@
ranked_logger = RankedLogger(__name__, rank_zero_only=True)


def _remap_outputs(
xyz: Float[torch.Tensor, "D L 3"], mapping: Int[torch.Tensor, "D L"]
) -> Float[torch.Tensor, "D L 3"]:
"""Helper function to remap outputs using a mapping tensor."""
for i in range(xyz.shape[0]):
xyz[i, mapping[i]] = xyz[i].clone()
return xyz


class RF3Trainer(FabricTrainer):
"""Standard Trainer for AF3-style models"""

Expand Down
10 changes: 10 additions & 0 deletions models/rfd3/src/rfd3/trainer/rfd3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import torch
from beartype.typing import Any, List, Union
from jaxtyping import Float, Int
from biotite.structure import AtomArray, AtomArrayStack
from biotite.structure.residues import get_residue_starts
from einops import repeat
Expand Down Expand Up @@ -29,6 +30,15 @@
global_logger = RankedLogger(__name__, rank_zero_only=False)


def _remap_outputs(
xyz: Float[torch.Tensor, "D L 3"], mapping: Int[torch.Tensor, "D L"]
) -> Float[torch.Tensor, "D L 3"]:
"""Helper function to remap outputs using a mapping tensor."""
for i in range(xyz.shape[0]):
xyz[i, mapping[i]] = xyz[i].clone()
Copy link

Copilot AI Dec 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The remapping logic is incorrect. The current implementation assigns the cloned original values to positions specified by the mapping, which doesn't perform the intended remapping operation. The line should be xyz[i] = xyz[i, mapping[i]] or xyz[i] = xyz[i].clone()[mapping[i]] to correctly reorder the elements according to the mapping indices. With the current implementation, if mapping contains indices [2, 0, 1], the function won't reorder the tensor elements correctly.

Suggested change
xyz[i, mapping[i]] = xyz[i].clone()
xyz[i] = xyz[i, mapping[i]]

Copilot uses AI. Check for mistakes.
return xyz
Comment on lines +33 to +39
Copy link

Copilot AI Dec 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is duplicated from trainer_utils.py (lines 33-39) where an identical implementation already exists. Consider importing the function from trainer_utils instead to avoid code duplication and maintain a single source of truth. If there's a specific reason for the duplication, it should be documented.

Copilot uses AI. Check for mistakes.
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot open a new pull request to apply changes based on this feedback



class AADesignTrainer(FabricTrainer):
"""Mostly for unique things like saving outputs and parsing inputs
Expand Down