From d6f861fa5050aa0f60d89971c5aa9b12065452e2 Mon Sep 17 00:00:00 2001 From: Yann Bourdin Date: Wed, 22 May 2024 19:20:51 +0200 Subject: [PATCH] fix reduction mistake in SpectralConvergenceLoss the denominator was averaged over all dimensions including the batch dimension, see comment by @egaznep in #69 --- auraloss/freq.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/auraloss/freq.py b/auraloss/freq.py index 0baf179..9977998 100644 --- a/auraloss/freq.py +++ b/auraloss/freq.py @@ -5,7 +5,6 @@ from .utils import apply_reduction, get_window from .perceptual import SumAndDifference, FIRFilter - class SpectralConvergenceLoss(torch.nn.Module): """Spectral convergence loss module. @@ -16,7 +15,7 @@ def __init__(self): super(SpectralConvergenceLoss, self).__init__() def forward(self, x_mag, y_mag): - return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") + return torch.norm(y_mag - x_mag, p="fro", dim=(-1, -2), keepdim=True) / torch.norm(y_mag, p="fro", dim=(-1, -2), keepdim=True) class STFTMagnitudeLoss(torch.nn.Module):