diff --git a/auraloss/freq.py b/auraloss/freq.py index 0baf179..9977998 100644 --- a/auraloss/freq.py +++ b/auraloss/freq.py @@ -5,7 +5,6 @@ from .utils import apply_reduction, get_window from .perceptual import SumAndDifference, FIRFilter - class SpectralConvergenceLoss(torch.nn.Module): """Spectral convergence loss module. @@ -16,7 +15,7 @@ def __init__(self): super(SpectralConvergenceLoss, self).__init__() def forward(self, x_mag, y_mag): - return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") + return torch.norm(y_mag - x_mag, p="fro", dim=(-1, -2), keepdim=True) / torch.norm(y_mag, p="fro", dim=(-1, -2), keepdim=True) class STFTMagnitudeLoss(torch.nn.Module):