Added support for modified sum and difference loss from https://arxiv.org/abs/2208.11428#71
Added support for modified sum and difference loss from https://arxiv.org/abs/2208.11428#71sai-soum wants to merge 4 commits intocsteinmetz1:mainfrom
Conversation
sai-soum
commented
Feb 8, 2024

csteinmetz1
left a comment
There was a problem hiding this comment.
Thanks, this will be nice to add. Main point is that we should avoid creating a new class and extend the existing class to support this behavior if possible. If that doesn't workout, at least make the new class inherit from the main class to avoid any repetitive code.
| elif self.output == "full": | ||
| return loss, sum_loss, diff_loss | ||
|
|
||
| class ModifiedSumAndDifferenceSTFTLoss(torch.nn.Module): |
There was a problem hiding this comment.
Instead of creating a new class for this we should add parameters to SumAndDifferenceSTFTLoss in order to support this behavior. It seems like the major modification is that application of the pre-emphasis filter.
| l1log_sum_loss = self.l1logstft(input_sum_mag, target_sum_mag) | ||
| l1log_diff_loss = self.l1logstft(input_diff_mag, target_diff_mag) | ||
|
|
||
| if self.loss_type == 'SClogL1': |
There was a problem hiding this comment.
Looks like the other difference in the distance measure. This should be able to be supported by the main class. However, if it seems easier, we could consider adding a new ModifiedSumAndDifferenceLoss class but where it inherits from the main class so that we don't get all this repeated code (e.g. stft, etc.)