diff --git a/models/rf3/src/rf3/trainers/rf3.py b/models/rf3/src/rf3/trainers/rf3.py index 2d1be4f7..63db8e27 100644 --- a/models/rf3/src/rf3/trainers/rf3.py +++ b/models/rf3/src/rf3/trainers/rf3.py @@ -24,15 +24,6 @@ ranked_logger = RankedLogger(__name__, rank_zero_only=True) -def _remap_outputs( - xyz: Float[torch.Tensor, "D L 3"], mapping: Int[torch.Tensor, "D L"] -) -> Float[torch.Tensor, "D L 3"]: - """Helper function to remap outputs using a mapping tensor.""" - for i in range(xyz.shape[0]): - xyz[i, mapping[i]] = xyz[i].clone() - return xyz - - class RF3Trainer(FabricTrainer): """Standard Trainer for AF3-style models""" diff --git a/models/rfd3/src/rfd3/trainer/rfd3.py b/models/rfd3/src/rfd3/trainer/rfd3.py index 2af7685f..af459708 100644 --- a/models/rfd3/src/rfd3/trainer/rfd3.py +++ b/models/rfd3/src/rfd3/trainer/rfd3.py @@ -1,6 +1,7 @@ import numpy as np import torch from beartype.typing import Any, List, Union +from jaxtyping import Float, Int from biotite.structure import AtomArray, AtomArrayStack from biotite.structure.residues import get_residue_starts from einops import repeat @@ -29,6 +30,15 @@ global_logger = RankedLogger(__name__, rank_zero_only=False) +def _remap_outputs( + xyz: Float[torch.Tensor, "D L 3"], mapping: Int[torch.Tensor, "D L"] +) -> Float[torch.Tensor, "D L 3"]: + """Helper function to remap outputs using a mapping tensor.""" + for i in range(xyz.shape[0]): + xyz[i, mapping[i]] = xyz[i].clone() + return xyz + + class AADesignTrainer(FabricTrainer): """Mostly for unique things like saving outputs and parsing inputs