fix reduction mistake in SpectralConvergenceLoss#75
Open

renared wants to merge 1 commit into csteinmetz1:main
Conversation
The denominator was averaged over all dimensions, including the batch dimension; see the comment by @egaznep in csteinmetz1#69.
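To make the issue concrete, here is a minimal numpy sketch (shapes, seed, and helper names are illustrative, not auraloss code) showing that taking the Frobenius norm over *all* dimensions at once makes the loss depend on how the data is batched, while reducing per item and then averaging does not:

```python
import numpy as np

rng = np.random.default_rng(0)
# Two "batches" of magnitude spectrograms, shape (batch, freq, time).
x = rng.random((4, 8, 8))
y = rng.random((4, 8, 8))

def sc_global(x_mag, y_mag):
    # Current behavior: Frobenius norm over ALL dims, batch included.
    return np.linalg.norm(y_mag - x_mag) / np.linalg.norm(y_mag)

def sc_per_item(x_mag, y_mag):
    # Fixed behavior: norm per batch item, then mean of the per-item ratios.
    num = np.linalg.norm(y_mag - x_mag, axis=(-2, -1))
    den = np.linalg.norm(y_mag, axis=(-2, -1))
    return (num / den).mean()

# Same data evaluated as one batch of 4 vs. four batches of 1:
full = sc_global(x, y)
split = np.mean([sc_global(x[i:i + 1], y[i:i + 1]) for i in range(4)])
print("global norm, full vs split:", full, split)  # values differ

full2 = sc_per_item(x, y)
split2 = np.mean([sc_per_item(x[i:i + 1], y[i:i + 1]) for i in range(4)])
print("per-item norm, full vs split:", full2, split2)  # values agree
```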
I just stumbled on the exact same problem. Are there any plans to merge this fix? Ping @csteinmetz1?
cpvlordelo reviewed Feb 24, 2025
```diff
 def forward(self, x_mag, y_mag):
-    return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
+    return torch.norm(y_mag - x_mag, p="fro", dim=(-1, -2), keepdim=True) / torch.norm(y_mag, p="fro", dim=(-1, -2), keepdim=True)
```
Suggested change:

```diff
-    return torch.norm(y_mag - x_mag, p="fro", dim=(-1, -2), keepdim=True) / torch.norm(y_mag, p="fro", dim=(-1, -2), keepdim=True)
+    return (torch.norm(y_mag - x_mag, p="fro", dim=(-1, -2)) / torch.norm(y_mag, p="fro", dim=(-1, -2))).mean()
```
Since you removed the reduction, this now returns a multi-dimensional tensor. It still works with `STFTLoss`, because the reduction is done inside of it as you can see here, but if you instantiate `SpectralConvergenceLoss` on its own, the example code below will crash.
```python
import torch
from auraloss.freq import SpectralConvergenceLoss

batches = [(torch.randn(4, 1, 16384), torch.randn(4, 1, 16384)) for i in range(1024)]
batchall = tuple(torch.concat(u, dim=0) for u in zip(*batches))
loss = SpectralConvergenceLoss()
print("Shape of Spectral Convergence Loss over full dataset:", loss(*batchall).shape)
print("mean of losses:", torch.mean(torch.tensor(tuple(loss(*batch) for batch in batches))))
```

Before:

```
Shape of Spectral Convergence Loss over full dataset: torch.Size([])
mean of losses: tensor(1.4144)
```

After:

```
Shape of Spectral Convergence Loss over full dataset: torch.Size([4096, 1, 1])
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-45-952d11dbfe6a> in <cell line: 0>()
     23 print("Shape of Spectral Convergence Loss over full dataset:", loss(*batchall).shape)
---> 24 print("mean of losses:", torch.mean(torch.tensor(tuple(loss(*batch) for batch in batches))))

ValueError: only one element tensors can be converted to Python scalars
```

This is just a suggestion that always performs the reduction as a mean.
But an even better option, in my opinion, would be to add a new string argument `reduction` to `__init__` and call `apply_reduction` inside this `forward` method, in a similar way to what is done in the `STFTLoss` code.
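As a rough illustration of that suggestion, here is a hypothetical numpy stand-in (not auraloss code: the class name is invented, and the reduction is inlined rather than delegated to auraloss's `apply_reduction` helper) showing how a `reduction` argument on `__init__` could behave:

```python
import numpy as np

class SpectralConvergenceLossSketch:
    """Hypothetical stand-in for SpectralConvergenceLoss with a
    `reduction` argument, mirroring STFTLoss-style reduction."""

    def __init__(self, reduction="mean"):
        if reduction not in ("none", "mean", "sum"):
            raise ValueError(f"unsupported reduction: {reduction}")
        self.reduction = reduction

    def __call__(self, x_mag, y_mag):
        # Per-item spectral convergence: norm over the (freq, time) dims only.
        num = np.linalg.norm(y_mag - x_mag, axis=(-2, -1))
        den = np.linalg.norm(y_mag, axis=(-2, -1))
        losses = num / den  # one loss value per batch item
        if self.reduction == "mean":
            return losses.mean()
        if self.reduction == "sum":
            return losses.sum()
        return losses  # "none": leave unreduced

x = np.random.rand(4, 16, 32)
y = np.random.rand(4, 16, 32)
print(SpectralConvergenceLossSketch("mean")(x, y).shape)  # scalar: ()
print(SpectralConvergenceLossSketch("none")(x, y).shape)  # per item: (4,)
```

With `reduction="none"` the caller (e.g. `STFTLoss`) can still apply its own reduction later, which is the behavior this PR relies on.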
I noticed that when evaluating the STFT loss over my validation dataset, I obtained different results depending on the batch size. I isolated the cause to the spectral convergence term, then came across the comment by @egaznep in issue #69. It does not make sense to average the denominator over all dimensions including the batch dimension, so I believe their suggestion should be used instead.
This snippet shows the difference:
Before:
After: