3 changes: 1 addition & 2 deletions — auraloss/freq.py

```diff
@@ -5,7 +5,6 @@
 from .utils import apply_reduction, get_window
 from .perceptual import SumAndDifference, FIRFilter
 
-
 class SpectralConvergenceLoss(torch.nn.Module):
     """Spectral convergence loss module.
@@ -16,7 +15,7 @@ def __init__(self):
         super(SpectralConvergenceLoss, self).__init__()
 
     def forward(self, x_mag, y_mag):
-        return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
+        return torch.norm(y_mag - x_mag, p="fro", dim=(-1, -2), keepdim=True) / torch.norm(y_mag, p="fro", dim=(-1, -2), keepdim=True)
```
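For context, the edited `forward` now keeps the batch dimensions: instead of a single scalar, it returns one loss value per batch element, with shape `(batch, 1, 1)`. A minimal standalone check (the input shapes here are illustrative, not taken from the PR):

```python
import torch

# Hypothetical magnitude spectrograms, shape (batch, freq_bins, frames).
x_mag = torch.rand(4, 1025, 128)
y_mag = torch.rand(4, 1025, 128)

# Frobenius norm over the last two dims with keepdim=True yields
# one value per batch element, shape (4, 1, 1), rather than a scalar.
loss = torch.norm(y_mag - x_mag, p="fro", dim=(-1, -2), keepdim=True) / torch.norm(
    y_mag, p="fro", dim=(-1, -2), keepdim=True
)
print(loss.shape)  # torch.Size([4, 1, 1])
```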
@cpvlordelo commented on Feb 24, 2025:
Suggested change:

```diff
-        return torch.norm(y_mag - x_mag, p="fro", dim=(-1, -2), keepdim=True) / torch.norm(y_mag, p="fro", dim=(-1, -2), keepdim=True)
+        return (torch.norm(y_mag - x_mag, p="fro", dim=(-1, -2)) / torch.norm(y_mag, p="fro", dim=(-1, -2))).mean()
```

Since you removed the reduction, this now returns a multi-dimensional tensor. That still works with STFTLoss, because STFTLoss applies the reduction itself, but if you instantiate SpectralConvergenceLoss directly, the example code below crashes.

```python
import torch
from auraloss.freq import SpectralConvergenceLoss

# 1024 random "spectrogram" batches of shape (4, 1, 16384), plus a single
# concatenated batch covering the full dataset.
batches = [(torch.randn(4, 1, 16384), torch.randn(4, 1, 16384)) for i in range(1024)]
batchall = tuple(torch.concat(u, dim=0) for u in zip(*batches))

loss = SpectralConvergenceLoss()
print("Shape of Spectral Convergence Loss over full dataset:", loss(*batchall).shape)
# This mean-of-per-batch-losses requires each loss to be a scalar:
print("mean of losses:", torch.mean(torch.tensor(tuple(loss(*batch) for batch in batches))))
```

Before:

```text
Shape of Spectral Convergence Loss over full dataset: torch.Size([])
mean of losses: tensor(1.4144)
```

After:

```text
Shape of Spectral Convergence Loss over full dataset: torch.Size([4096, 1, 1])
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-45-952d11dbfe6a> in <cell line: 0>()
     23 print("Shape of Spectral Convergence Loss over full dataset:", loss(*batchall).shape)
---> 24 print("mean of losses:", torch.mean(torch.tensor(tuple(loss(*batch) for batch in batches))))

ValueError: only one element tensors can be converted to Python scalars
```

This is just a suggestion that always performs the reduction as a mean.

An even better option, in my opinion, would be to add a new string argument `reduction` to `__init__` and call `apply_reduction` inside this `forward` method, similar to what is done in the STFTLoss code.
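That second suggestion could be sketched as follows. This is only an illustration: `apply_reduction` here is a minimal local stand-in with assumed `"mean"`/`"sum"`/`"none"` semantics, not the actual auraloss implementation.

```python
import torch

def apply_reduction(losses, reduction="none"):
    # Minimal stand-in for auraloss.utils.apply_reduction
    # (assumed semantics: "mean" | "sum" | "none").
    if reduction == "mean":
        return losses.mean()
    if reduction == "sum":
        return losses.sum()
    return losses

class SpectralConvergenceLoss(torch.nn.Module):
    """Spectral convergence loss with a configurable reduction (sketch)."""

    def __init__(self, reduction="mean"):
        super().__init__()
        self.reduction = reduction

    def forward(self, x_mag, y_mag):
        # Per-batch-element losses, shape (batch, 1, 1).
        losses = torch.norm(y_mag - x_mag, p="fro", dim=(-1, -2), keepdim=True) / torch.norm(
            y_mag, p="fro", dim=(-1, -2), keepdim=True
        )
        return apply_reduction(losses, reduction=self.reduction)

# With reduction="mean" (the default), instantiating the loss directly
# yields a scalar again, so the mean-of-losses pattern above works.
loss_fn = SpectralConvergenceLoss(reduction="mean")
out = loss_fn(torch.rand(4, 1025, 128), torch.rand(4, 1025, 128))
print(out.shape)  # torch.Size([])
```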



class STFTMagnitudeLoss(torch.nn.Module):