-
Notifications
You must be signed in to change notification settings - Fork 93
refactor: move _remap_outputs function to rfd3.py and remove from rf3.py #110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: production
Are you sure you want to change the base?
refactor: move _remap_outputs function to rfd3.py and remove from rf3.py #110
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR refactors the _remap_outputs helper function by moving it from rf3.py (where it was unused) to rfd3.py (where it is needed). The function is used to remap network outputs according to a mapping tensor for alignment with ground truth data.
- Removed unused
_remap_outputsfunction frommodels/rf3/src/rf3/trainers/rf3.py - Added
_remap_outputsfunction tomodels/rfd3/src/rfd3/trainer/rfd3.pyalong with requiredjaxtypingimports
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| models/rfd3/src/rfd3/trainer/rfd3.py | Added _remap_outputs function and imported Float and Int from jaxtyping to support the function's type annotations |
| models/rf3/src/rf3/trainers/rf3.py | Removed unused _remap_outputs function definition |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| ) -> Float[torch.Tensor, "D L 3"]: | ||
| """Helper function to remap outputs using a mapping tensor.""" | ||
| for i in range(xyz.shape[0]): | ||
| xyz[i, mapping[i]] = xyz[i].clone() |
Copilot
AI
Dec 16, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The remapping logic is incorrect. The current implementation assigns the cloned original values to positions specified by the mapping, which doesn't perform the intended remapping operation. The line should be xyz[i] = xyz[i, mapping[i]] or xyz[i] = xyz[i].clone()[mapping[i]] to correctly reorder the elements according to the mapping indices. With the current implementation, if mapping contains indices [2, 0, 1], the function won't reorder the tensor elements correctly.
| xyz[i, mapping[i]] = xyz[i].clone() | |
| xyz[i] = xyz[i, mapping[i]] |
| def _remap_outputs( | ||
| xyz: Float[torch.Tensor, "D L 3"], mapping: Int[torch.Tensor, "D L"] | ||
| ) -> Float[torch.Tensor, "D L 3"]: | ||
| """Helper function to remap outputs using a mapping tensor.""" | ||
| for i in range(xyz.shape[0]): | ||
| xyz[i, mapping[i]] = xyz[i].clone() | ||
| return xyz |
Copilot
AI
Dec 16, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function is duplicated from trainer_utils.py (lines 33-39) where an identical implementation already exists. Consider importing the function from trainer_utils instead to avoid code duplication and maintain a single source of truth. If there's a specific reason for the duplication, it should be documented.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot open a new pull request to apply changes based on this feedback
|
Was this causing issues? |
* refactor: delete files * fix: add back files that included imported functions in chiral code * chore: add back files to archive, update pyproject.yaml and gitignore
This function is missing in rfd3, but it exists in rf3 and is not used there.