-
Notifications
You must be signed in to change notification settings - Fork 93
refactor: move _remap_outputs function to rfd3.py and remove from rf3.py #110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: production
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| import numpy as np | ||
| import torch | ||
| from beartype.typing import Any, List, Union | ||
| from jaxtyping import Float, Int | ||
| from biotite.structure import AtomArray, AtomArrayStack | ||
| from biotite.structure.residues import get_residue_starts | ||
| from einops import repeat | ||
|
|
@@ -29,6 +30,15 @@ | |
| global_logger = RankedLogger(__name__, rank_zero_only=False) | ||
|
|
||
|
|
||
| def _remap_outputs( | ||
| xyz: Float[torch.Tensor, "D L 3"], mapping: Int[torch.Tensor, "D L"] | ||
| ) -> Float[torch.Tensor, "D L 3"]: | ||
| """Helper function to remap outputs using a mapping tensor.""" | ||
| for i in range(xyz.shape[0]): | ||
| xyz[i, mapping[i]] = xyz[i].clone() | ||
| return xyz | ||
|
Comment on lines
+33
to
+39
|
||
|
|
||
|
|
||
| class AADesignTrainer(FabricTrainer): | ||
| """Mostly for unique things like saving outputs and parsing inputs | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The remapping logic is incorrect. The current implementation assigns the cloned original values to positions specified by the mapping, which doesn't perform the intended remapping operation. The line should be
xyz[i] = xyz[i, mapping[i]]orxyz[i] = xyz[i].clone()[mapping[i]]to correctly reorder the elements according to the mapping indices. With the current implementation, if mapping contains indices [2, 0, 1], the function won't reorder the tensor elements correctly.