-
Notifications
You must be signed in to change notification settings - Fork 7
Add perceptual loss #23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
themattinthehatt
merged 28 commits into
paninski-lab:main
from
Xinming-Dai:perceptual_loss2
Feb 27, 2026
Merged
Changes from all commits
Commits
Show all changes
28 commits
Select commit
Hold shift + click to select a range
898d8d4
turned off contrastive loss
Xinming-Dai 857c0b5
added perceptual loss
Xinming-Dai ffee13d
added perceptual loss to tensorboard
Xinming-Dai 149fac7
delete perceptual config and resume_beast.sh
Xinming-Dai efafbaa
delete combined.py
Xinming-Dai 651a16b
add test_alex_perceptual_integration method to test perceptual integr…
Xinming-Dai 8d4a52f
add test for perceptual.py
Xinming-Dai 4a202aa
move _log_step to beast.log_step with level param
Xinming-Dai 7d6e2dc
move import to the top of the file
Xinming-Dai dc0c0b0
import log_step from beast
Xinming-Dai aa95a7d
remove import time
Xinming-Dai d056f02
remove extra logging
Xinming-Dai 23c3077
use log_step
Xinming-Dai ea3db82
move log_step inside the try block
Xinming-Dai d13c5bb
change log_step level to "error"
Xinming-Dai f028040
modify the docstring
Xinming-Dai 02110be
add docstrings and types
Xinming-Dai cdc7b34
parameter extraction and logging infoNCE loss
Xinming-Dai ea88d7e
use self.device
Xinming-Dai 84b7356
remove fallback
Xinming-Dai c626b7c
assert 'perceptual_loss' in kwargs
Xinming-Dai dea4e47
revert all changes in __init__ of VisionTransformer
Xinming-Dai a4270da
set use_infoNCE: False
Xinming-Dai 15e7567
revert changes to __get_package_version
Xinming-Dai 25bd8e6
simplify VisionTransformer to match original style
Xinming-Dai e5e40d3
keep comments
Xinming-Dai 78cdf8e
keep the code for using a randomly initialized model
Xinming-Dai 7a4cecc
Fix formatting and lint issues
Xinming-Dai File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| import torch | ||
| import torchvision | ||
| from torch import nn | ||
| from typing import Any | ||
| # https://github.com/MLReproHub/SMAE/blob/main/src/loss/perceptual.py | ||
|
|
||
|
|
||
| class Perceptual(nn.Module): | ||
| def __init__(self, *, network: nn.Module, criterion: nn.Module): | ||
| """Initialize perceptual loss module. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| network: feature extractor that maps input images to feature tensors | ||
| criterion: loss function applied to extracted features (e.g. MSELoss) | ||
| """ | ||
| super(Perceptual, self).__init__() | ||
| self.net = network | ||
| self.criterion = criterion | ||
| self.sigmoid = nn.Sigmoid() | ||
|
|
||
| def forward(self, x_hat: torch.Tensor, x: torch.Tensor) -> torch.Tensor: | ||
| x_hat_features = self.sigmoid(self.net(x_hat)) | ||
| x_features = self.sigmoid(self.net(x)) | ||
| loss = self.criterion(x_hat_features, x_features) | ||
| return loss | ||
|
|
||
|
|
||
| class AlexPerceptual(Perceptual): | ||
| def __init__(self, *, device: str | torch.device, **kwargs: Any): | ||
| """Perceptual loss using pretrained AlexNet features [Pihlgren et al. 2020]. | ||
|
|
||
| Extracts features from the first five layers of AlexNet (pretrained on ImageNet) | ||
| and computes loss between reconstructed and target feature maps. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| device: device to run the feature extractor on (e.g. 'cuda', 'cpu') | ||
| **kwargs: passed to parent; must include criterion (e.g. nn.MSELoss()) | ||
| """ | ||
| # Load alex net pretrained on IN1k | ||
| alex_net = torchvision.models.alexnet(weights='IMAGENET1K_V1') | ||
| # Extract features after second relu activation | ||
| # Append sigmoid layer to normalize features | ||
| perceptual_net = alex_net.features[:5].to(device) | ||
| # Don't record gradients for the perceptual net, the gradients will still propagate through. | ||
| for parameter in perceptual_net.parameters(): | ||
| parameter.requires_grad = False | ||
|
|
||
| super(AlexPerceptual, self).__init__(network=perceptual_net, **kwargs) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.