From fdaec80cbe3c256ac7dc9dffe086b4e9356d8a0f Mon Sep 17 00:00:00 2001 From: csteinmetz1 Date: Tue, 12 Dec 2023 08:37:56 +0000 Subject: [PATCH 01/15] fix stray merge characters --- mst/modules.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mst/modules.py b/mst/modules.py index f5bcf2a..627255f 100644 --- a/mst/modules.py +++ b/mst/modules.py @@ -669,9 +669,6 @@ def forward(self, x: torch.Tensor): class SpectrogramEncoder(torch.nn.Module): - def __init__( - self, -======= def __init__( self, n_inputs=1, From f5875bd4644386bdfd4cc12d796fab796e08ea2b Mon Sep 17 00:00:00 2001 From: csteinmetz1 Date: Wed, 13 Dec 2023 10:03:05 +0000 Subject: [PATCH 02/15] update jamendo configs --- configs/data/medley+cambridge+jamendo-4.yaml | 26 ++++++++++++++++++++ configs/data/medley+cambridge+jamendo-8.yaml | 4 +-- 2 files changed, 28 insertions(+), 2 deletions(-) create mode 100644 configs/data/medley+cambridge+jamendo-4.yaml diff --git a/configs/data/medley+cambridge+jamendo-4.yaml b/configs/data/medley+cambridge+jamendo-4.yaml new file mode 100644 index 0000000..a324a5a --- /dev/null +++ b/configs/data/medley+cambridge+jamendo-4.yaml @@ -0,0 +1,26 @@ +data: + class_path: mst.dataloader.MultitrackDataModule + init_args: + track_root_dirs: + - /import/c4dm-datasets-ext/mixing-secrets/ + - /import/c4dm-datasets/ + + mix_root_dirs: + - /import/c4dm-datasets-ext/mtg-jamendo + + metadata_files: + - ./data/cambridge.yaml + - ./data/medley.yaml + length: 262144 + + min_tracks: 4 + max_tracks: 4 + batch_size: 4 + num_workers: 4 + num_train_passes: 4 + num_val_passes: 1 + train_buffer_size_gb: 4.0 + val_buffer_size_gb: 0.5 + target_track_lufs_db: -48.0 + randomize_ref_mix_gain: False + diff --git a/configs/data/medley+cambridge+jamendo-8.yaml b/configs/data/medley+cambridge+jamendo-8.yaml index e27be31..0728ba2 100644 --- a/configs/data/medley+cambridge+jamendo-8.yaml +++ b/configs/data/medley+cambridge+jamendo-8.yaml @@ -13,8 +13,8 @@ data: - ./data/medley.yaml length: 262144 - 
min_tracks: 4 - max_tracks: 4 + min_tracks: 8 + max_tracks: 8 batch_size: 2 num_workers: 4 num_train_passes: 4 From 4a0813197cea2806915e2e1734005ee22bf4dbbd Mon Sep 17 00:00:00 2001 From: csteinmetz1 Date: Wed, 13 Dec 2023 10:03:27 +0000 Subject: [PATCH 03/15] make sure we search for mp3 files when loading from jamendo --- mst/dataloader.py | 40 ++++++++++++++++------------------------ 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/mst/dataloader.py b/mst/dataloader.py index 2fdca93..d31a41c 100644 --- a/mst/dataloader.py +++ b/mst/dataloader.py @@ -21,12 +21,12 @@ def __init__(self, root_dir: str, length: int = 524288): self.root_dir = root_dir self.length = length - self.mix_filepaths = glob.glob( os.path.join(root_dir, "**", "*.wav"), recursive=True + ) - #self.mix_filepaths = glob.glob( - #os.path.join(root_dir, "**", "*.mp3"), recursive=True) + # self.mix_filepaths = glob.glob( + # os.path.join(root_dir, "**", "*.mp3"), recursive=True) print(f"Located {len(self.mix_filepaths)} mixes.") self.meter = pyln.Meter(44100) @@ -47,8 +47,6 @@ def __getitem__(self, _): offset = np.random.randint(0, num_frames - self.length - 1) offset = 0 # always use the same offset - - mix, _ = torchaudio.load( mix_filepath, frame_offset=offset, @@ -176,12 +174,15 @@ def __init__( for mix_dir in mix_root_dirs: # find all mixes in directory recursively - - mix_files = glob.glob(os.path.join(mix_dir, "**", "*.wav"), recursive=True) - - + ext = "mp3" if "jamendo" in mix_dir else "wav" + mix_files = glob.glob( + os.path.join(mix_dir, "**", f"*.{ext}"), recursive=True + ) self.mixes.extend(mix_files) + if len(mix_root_dirs) > 0 and len(self.mixes) == 0: + raise ValueError("No mixes found in mix_root_dirs.") + print(f"Located {len(self.mixes)} mixes.") # call reload buffer to load initial buffer @@ -214,16 +215,14 @@ def reload_mix_buffer(self): ) if mix.shape[0] == 1: - continue + mix = mix.repeat(2, 1) if mix.shape[-1] != self.length: continue if mix.size()[0] > 2: - 
continue + mix = mix[0:2, :] mix_lufs_db = self.meter.integrated_loudness(mix.permute(1, 0).numpy()) - - if mix_lufs_db < -48.0 or mix_lufs_db == float("-inf"): continue @@ -293,12 +292,10 @@ def reload_track_buffer(self): if track.size()[0] > 2: continue - track_lufs_db = self.meter.integrated_loudness( track.permute(1, 0).numpy() ) - if track_lufs_db < -48.0 or track_lufs_db == float("-inf"): continue @@ -346,13 +343,14 @@ def reload_track_buffer(self): # convert to tensor tracks = torch.cat(tracks) - # if tracks[...,0:middle_idx].sum() == 0 or tracks[...,middle_idx:].sum() == 0: # continue tracks = tracks.reshape(self.max_tracks, self.length) - #create a sum mix of the tracks + # create a sum mix of the tracks mix_check = tracks.sum(0) - if torch.any(mix_check[...,0:middle_idx] == False) or torch.any(mix_check[...,middle_idx:] == False): + if torch.any(mix_check[..., 0:middle_idx] == False) or torch.any( + mix_check[..., middle_idx:] == False + ): continue track_metadata = torch.tensor(track_metadata) @@ -361,9 +359,7 @@ def reload_track_buffer(self): # add to buffer self.track_examples.append( - (tracks, stereo_info, track_metadata, track_padding, song_name) - ) nbytes_loaded += tracks.element_size() * tracks.nelement() @@ -382,13 +378,11 @@ def __getitem__(self, idx): tracks = track_example[0] - stereo_info = track_example[1] track_metadata = track_example[2] track_padding = track_example[3] song_name = track_example[4] - # ------------ get example from mix buffer ------------ # optional if len(self.mix_examples) > 0: @@ -402,11 +396,9 @@ def __getitem__(self, idx): else: mix = torch.empty(1) - return tracks, stereo_info, track_metadata, track_padding, mix, song_name - class MultitrackDataModule(pl.LightningDataModule): def __init__( self, From 778cf0bd27517cc347b84003413bc9c0e91677f9 Mon Sep 17 00:00:00 2001 From: csteinmetz1 Date: Wed, 13 Dec 2023 10:04:03 +0000 Subject: [PATCH 04/15] clean up --- mst/modules.py | 10 ++-------- 1 file changed, 2 
insertions(+), 8 deletions(-) diff --git a/mst/modules.py b/mst/modules.py index 627255f..6c4cbbc 100644 --- a/mst/modules.py +++ b/mst/modules.py @@ -228,7 +228,6 @@ def forward_mix_console( # apply effects in series but all tracks at once if use_track_input_fader: - tracks = gain(tracks, self.sample_rate, **track_param_dict["input_fader"]) if tracks.sum() == 0: print("gain is 0") @@ -251,7 +250,7 @@ def forward_mix_console( lookahead_samples=2048, ) if tracks.sum() == 0: - print("compressor is 0") + print("compressor is 0") print(tracks) # restore tracks to original shape @@ -286,13 +285,11 @@ def forward_mix_console( # process Left channel master_bus = gain( master_bus, self.sample_rate, **master_bus_param_dict["input_fader"] - ) master_bus = parametric_eq( master_bus, self.sample_rate, **master_bus_param_dict["parametric_eq"] ) - # apply compressor to both channels master_bus = compressor( master_bus, @@ -587,7 +584,6 @@ def forward(self, z: torch.Tensor): class WaveformEncoder(torch.nn.Module): - def __init__( self, n_inputs=1, @@ -752,7 +748,6 @@ def forward(self, x: torch.Tensor): class SpectrogramEncoder(torch.nn.Module): def __init__( self, - embed_dim: int = 128, n_inputs: int = 1, n_fft: int = 2048, @@ -813,7 +808,7 @@ def forward(self, x: torch.torch.Tensor) -> torch.torch.Tensor: # process with CNN embeds = self.model(X) - #print(embeds.shape) + # print(embeds.shape) return embeds @@ -923,4 +918,3 @@ def forward( ) return pred_track_params, pred_fx_bus_params, pred_master_bus_params - From 7d8de3c9fb6b4e72265ba113b81bb55211da5f6b Mon Sep 17 00:00:00 2001 From: csteinmetz1 Date: Wed, 13 Dec 2023 10:04:33 +0000 Subject: [PATCH 05/15] add a fix for clap loss when logging --- mst/system.py | 118 ++++++-------------------------------------------- 1 file changed, 13 insertions(+), 105 deletions(-) diff --git a/mst/system.py b/mst/system.py index 33f0be0..e814cf4 100644 --- a/mst/system.py +++ b/mst/system.py @@ -50,12 +50,10 @@ def __init__( 
self.active_master_bus_epoch = active_master_bus_epoch self.meter = pyln.Meter(44100) - #self.warmup = warmup - + # self.warmup = warmup self.save_hyperparameters(ignore=["model", "mix_console", "mix_fn", "loss"]) - # losses for evaluation self.sisdr = auraloss.time.SISDRLoss() self.mrstft = auraloss.freq.MultiResolutionSTFTLoss( @@ -115,7 +113,6 @@ def common_step( """ tracks, instrument_id, stereo_info, track_padding, ref_mix, song_name = batch - #print("song_names from this batch: ", song_name) # split into A and B sections middle_idx = tracks.shape[-1] // 2 @@ -141,11 +138,6 @@ def common_step( ref_fx_bus_param_dict = None ref_master_bus_param_dict = None - # if tracks[...,middle_idx:].sum() == 0: - # print("tracks are zero") - # print(tracks[...,middle_idx:]) - # raise ValueError("input tracks are zero") - # --------- create a random mix (on GPU, if applicable) --------- if self.generate_mix: ( @@ -154,9 +146,9 @@ def common_step( ref_track_param_dict, ref_fx_bus_param_dict, ref_master_bus_param_dict, - ref_mix_params, - ref_fx_bus_params, - ref_master_bus_params + ref_mix_params, + ref_fx_bus_params, + ref_master_bus_params, ) = self.mix_fn( tracks, self.mix_console, @@ -177,13 +169,7 @@ def common_step( ref_mix = batch_stereo_peak_normalize(ref_mix) if torch.isnan(ref_mix).any(): - #print(ref_track_param_dict) raise ValueError("Found nan in ref_mix") - - - # if torch.count_nonzero(ref_mix[...,0:middle_idx])< 1: - # print("ref_mix is zero") - # raise ValueError("ref_mix is zero") ref_mix_a = ref_mix[..., :middle_idx] # this is passed to the model ref_mix_b = ref_mix[..., middle_idx:] # this is used for loss computation @@ -192,75 +178,10 @@ def common_step( # when using a real mix, pass the same mix to model and loss ref_mix_a = ref_mix ref_mix_b = ref_mix - - - - - # tracks_a = tracks[..., :input_middle_idx] # not used currently - - #print("input tracks: ", tracks[...,middle_idx:]) - #print("ref_mix: ", ref_mix_a) - - - if self.current_epoch >= 
self.active_compressor_epoch: - self.use_track_compressor = True - - if self.current_epoch >= self.active_fx_bus_epoch: - self.use_fx_bus = True - - if self.current_epoch >= self.active_master_bus_epoch: - self.use_master_bus = True - - bs, num_tracks, seq_len = tracks.shape - - # apply random gain to input tracks - # tracks *= 10 ** ((torch.rand(bs, num_tracks, 1).type_as(tracks) * -12.0) / 20.0) - ref_track_param_dict = None - ref_fx_bus_param_dict = None - ref_master_bus_param_dict = None - - # --------- create a random mix (on GPU, if applicable) --------- - if self.generate_mix: - ( - ref_mix_tracks, - ref_mix, - ref_track_param_dict, - ref_fx_bus_param_dict, - ref_master_bus_param_dict, - ) = self.mix_fn( - tracks, - self.mix_console, - use_track_input_fader=False, # do not use track input fader for training - use_track_panner=self.use_track_panner, - use_track_eq=self.use_track_eq, - use_track_compressor=self.use_track_compressor, - use_fx_bus=self.use_fx_bus, - use_master_bus=self.use_master_bus, - use_output_fader=False, # not used because we normalize output mixes - instrument_id=instrument_id, - stereo_id=stereo_info, - instrument_number_file=self.instrument_number_lookup, - ke_dict=self.knowledge_engineering_dict, - ) - - # normalize the reference mix - ref_mix = batch_stereo_peak_normalize(ref_mix) - - if torch.isnan(ref_mix).any(): - print(ref_track_param_dict) - raise ValueError("Found nan in ref_mix") - - ref_mix_a = ref_mix[..., :middle_idx] # this is passed to the model - ref_mix_b = ref_mix[..., middle_idx:] # this is used for loss computation - else: - # when using a real mix, pass the same mix to model and loss - ref_mix_a = ref_mix - ref_mix_b = ref_mix # tracks_a = tracks[..., :input_middle_idx] # not used currently tracks_b = tracks[..., middle_idx:] # this is passed to the model - # ---- run model with tracks from section A using reference mix from section B ---- ( pred_track_params, @@ -292,26 +213,15 @@ def common_step( # normalize the 
predicted mix before computing the loss # pred_mix_b = batch_stereo_peak_normalize(pred_mix_b) - if ref_track_param_dict is None: ref_track_param_dict = pred_track_param_dict ref_fx_bus_param_dict = pred_fx_bus_param_dict ref_master_bus_param_dict = pred_master_bus_param_dict # ---------------------------- compute and log loss ------------------------------ - - - #print("pred_mix: ", pred_mix_b) - # if pred_mix_b.sum() == 0: - - #print("pred_track_params: ", pred_track_params) - #print("pred_fx_bus_params: ", pred_fx_bus_params) - #print("pred_master_bus_params: ", pred_master_bus_params) - #print("ref_mix: ",ref_mix_b) - loss = 0 - #if parameter_loss is being used to train model, no need to generate mix + # if parameter_loss is being used to train model, no need to generate mix if self.use_param_loss: track_param_loss = self.loss(pred_track_params, ref_mix_params) loss += track_param_loss @@ -319,10 +229,11 @@ def common_step( fx_bus_param_loss = self.loss(pred_fx_bus_params, ref_fx_bus_params) loss += fx_bus_param_loss if self.use_master_bus: - master_bus_param_loss = self.loss(pred_master_bus_params, ref_master_bus_params) + master_bus_param_loss = self.loss( + pred_master_bus_params, ref_master_bus_params + ) loss += master_bus_param_loss - # ---------------------------- compute and log loss ------------------------------ loss = 0 @@ -335,20 +246,18 @@ def common_step( else: loss += mix_loss - if type(mix_loss) == dict: for key, value in mix_loss.items(): self.log( ("train" if train else "val") + "/" + key, - value, + value.mean(), on_step=True, on_epoch=True, prog_bar=False, logger=True, sync_dist=True, ) - #print(loss) - + # print(loss) # log the losses self.log( @@ -361,7 +270,6 @@ def common_step( sync_dist=True, ) - # sisdr_error = -self.sisdr(pred_mix_b, ref_mix_b) # log the SI-SDR error # self.log( @@ -401,13 +309,13 @@ def common_step( "ref_master_bus_param_dict": ref_master_bus_param_dict, "pred_master_bus_param_dict": pred_master_bus_param_dict, } - + 
return loss, data_dict def training_step(self, batch, batch_idx, optimizer_idx=0): loss, data_dict = self.common_step(batch, batch_idx, train=True) - #print(loss) + # print(loss) return loss def validation_step(self, batch, batch_idx): @@ -434,7 +342,7 @@ def configure_optimizers(self): ], ) else: - #print(optimizer) + # print(optimizer) return optimizer lr_schedulers = {"scheduler": scheduler, "interval": "epoch", "frequency": 1} From 7056cbca055b8256ec3711b1080b616fcb5eff06 Mon Sep 17 00:00:00 2001 From: csteinmetz1 Date: Wed, 13 Dec 2023 10:05:02 +0000 Subject: [PATCH 06/15] reset my log directory and add back the validation examples --- configs/config_cjs.yaml | 36 +++++++++++++++++------------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/configs/config_cjs.yaml b/configs/config_cjs.yaml index d88609a..f4c2a1c 100644 --- a/configs/config_cjs.yaml +++ b/configs/config_cjs.yaml @@ -6,36 +6,34 @@ trainer: class_path: pytorch_lightning.loggers.WandbLogger init_args: project: DiffMST - save_dir: /import/c4dm-datasets-ext/diffmst_logs_soum + save_dir: /import/c4dm-datasets-ext/Diff-MST enable_checkpointing: true callbacks: - class_path: mst.callbacks.audio.LogAudioCallback - class_path: pytorch_lightning.callbacks.ModelSummary init_args: max_depth: 2 - #- class_path: mst.callbacks.mix.LogReferenceMix - # init_args: - # root_dirs: - # - /import/c4dm-datasets-ext/diffmst_validation/validation set/song2/Kat Wright_By My Side/ - # - /import/c4dm-datasets-ext/diffmst_validation/validation set/song1/BenFlowers_Ecstasy_Full - # - /import/c4dm-datasets-ext/diffmst_validation/validation set/song3/Titanium_HauntedAge_Full/ - # - /import/c4dm-datasets-ext/diffmst_validation/validation set/song1/Soren_ALittleLate_Full - # - /import/c4dm-datasets-ext/diffmst_validation/validation set/song2/MR0903_Moosmusic_Full - # - /import/c4dm-datasets-ext/diffmst_validation/validation set/song3/SaturnSyndicate_CatchTheWave_Full - # ref_mixes: - # - 
/import/c4dm-datasets-ext/diffmst_validation/listening/diffmst-examples_wavref/The Dip - Paddle To The Stars (Lyric Video).wav - # - /import/c4dm-datasets-ext/diffmst_validation/listening/diffmst-examples_wavref/Feel it all Around by Washed Out (Portlandia Theme).wav - # - /import/c4dm-datasets-ext/diffmst_validation/listening/diffmst-examples_wavref/Architects - Doomsday.wav - # - /import/c4dm-datasets-ext/diffmst_validation/validation set/song1/ref/Harry Styles - Late Night Talking (Official Video).wav - # - /import/c4dm-datasets-ext/diffmst_validation/validation set/song2/ref/Justin Timberlake - Can't Stop The Feeling! [Lyrics].wav - # - /import/c4dm-datasets-ext/diffmst_validation/validation set/song2/ref/Taylor Swift - Shake It Off.wav + - class_path: mst.callbacks.mix.LogReferenceMix + init_args: + root_dirs: + - /import/c4dm-datasets-ext/diffmst_validation/validation set/song2/Kat Wright_By My Side/ + - /import/c4dm-datasets-ext/diffmst_validation/validation set/song1/BenFlowers_Ecstasy_Full + - /import/c4dm-datasets-ext/diffmst_validation/validation set/song3/Titanium_HauntedAge_Full/ + - /import/c4dm-datasets-ext/diffmst_validation/validation set/song1/Soren_ALittleLate_Full + - /import/c4dm-datasets-ext/diffmst_validation/validation set/song2/MR0903_Moosmusic_Full + - /import/c4dm-datasets-ext/diffmst_validation/validation set/song3/SaturnSyndicate_CatchTheWave_Full + ref_mixes: + - /import/c4dm-datasets-ext/diffmst_validation/listening/diffmst-examples_wavref/The Dip - Paddle To The Stars (Lyric Video).wav + - /import/c4dm-datasets-ext/diffmst_validation/listening/diffmst-examples_wavref/Feel it all Around by Washed Out (Portlandia Theme).wav + - /import/c4dm-datasets-ext/diffmst_validation/listening/diffmst-examples_wavref/Architects - Doomsday.wav + - /import/c4dm-datasets-ext/diffmst_validation/validation set/song1/ref/Harry Styles - Late Night Talking (Official Video).wav + - /import/c4dm-datasets-ext/diffmst_validation/validation 
set/song2/ref/Justin Timberlake - Can't Stop The Feeling! [Lyrics].wav + - /import/c4dm-datasets-ext/diffmst_validation/validation set/song2/ref/Taylor Swift - Shake It Off.wav default_root_dir: null gradient_clip_val: 10.0 devices: 1 check_val_every_n_epoch: 1 - max_epochs: 800 - log_every_n_steps: 50 accelerator: gpu strategy: ddp_find_unused_parameters_true From 91963cc6978be522e0a09b261b9e81d5d17a1376 Mon Sep 17 00:00:00 2001 From: csteinmetz1 Date: Mon, 29 Jan 2024 09:43:02 +0000 Subject: [PATCH 07/15] adding tests --- tests/test_embed.py | 10 ++++++++++ tests/test_quality_loss.py | 34 ++++++++++++++++++++++++++++++++++ tests/test_remix.py | 24 ++++++++++-------------- 3 files changed, 54 insertions(+), 14 deletions(-) create mode 100644 tests/test_embed.py create mode 100644 tests/test_quality_loss.py diff --git a/tests/test_embed.py b/tests/test_embed.py new file mode 100644 index 0000000..c8115f5 --- /dev/null +++ b/tests/test_embed.py @@ -0,0 +1,10 @@ +import torch + +from mst.modules import SpectrogramEncoder + +if __name__ == "__main__": + encoder = SpectrogramEncoder(embed_dim=1024, l2_norm=True) + + mix = torch.randn(4, 2, 262144) + z = encoder(mix) + print(z.shape) diff --git a/tests/test_quality_loss.py b/tests/test_quality_loss.py new file mode 100644 index 0000000..50a9b9f --- /dev/null +++ b/tests/test_quality_loss.py @@ -0,0 +1,34 @@ +import torch +import torchaudio + +from mst.loss import QualityLoss + +sample_rate = 44100 + +# loss = AudioFeatureLoss(weights, sample_rate, stem_separation=False) + +ckpt_path = "/import/c4dm-datasets-ext/Diff-MST/DiffMST-Param/0ymfi1pp/checkpoints/epoch=5-step=10842.ckpt" +loss = QualityLoss() + +# test with audio examples +input, _ = torchaudio.load( + "outputs/Kat Wright_By My Side-->The Dip - Paddle To The Stars (Lyric Video)/mono_mix_section.wav" +) +target, _ = torchaudio.load( + "outputs/Kat Wright_By My Side-->The Dip - Paddle To The Stars (Lyric Video)/ref_mix_section.wav" +) + + +input = 
input.unsqueeze(0) +target = target.unsqueeze(0) + +# input = input.repeat(4, 1, 1) +# target = target.repeat(4, 1, 1) + +# input[0, ...] = 0.0001 * torch.randn_like(input[0, ...]) +# target[0, ...] = 0.0001 * torch.randn_like(input[0, ...]) + +target_loss_val = loss(target) +input_loss_val = loss(input) + +print(f"target loss: {target_loss_val.mean()} input loss: {input_loss_val.mean()}") diff --git a/tests/test_remix.py b/tests/test_remix.py index d112b0a..ed4511b 100644 --- a/tests/test_remix.py +++ b/tests/test_remix.py @@ -1,25 +1,21 @@ import torch +import torchaudio from mst.modules import Remixer, AdvancedMixConsole from mst.dataloader import MixDataset if __name__ == "__main__": root_dir = "/import/c4dm-datasets-ext/mtg-jamendo" mix_dataset = MixDataset(root_dir, length=262144) - mix_dataloader = torch.utils.data.DataLoader(mix_dataset, batch_size=4) - - mix_console = AdvancedMixConsole(44100) - remixer = Remixer(44100) - - remixer.cuda() + mix_dataloader = torch.utils.data.DataLoader( + mix_dataset, batch_size=4, num_workers=4 + ) for batch_idx, batch in enumerate(mix_dataloader): - mix = batch - - mix = mix.cuda() + mix, label = batch - # create remix - remix, track_params, fx_bus_params, master_bus_params = remixer( - mix, mix_console - ) + for i in range(mix.shape[0]): + torchaudio.save(f"debug/{batch_idx}-{i}-{label[i]}.wav", mix[i], 44100) - print(batch_idx, mix.abs().max(), remix.abs().max()) + print(mix.shape, label) + if batch_idx > 10: + break From 1e3f18b215ab33c5ea0dbace2ebe9271d8e68741 Mon Sep 17 00:00:00 2001 From: csteinmetz1 Date: Mon, 29 Jan 2024 09:43:37 +0000 Subject: [PATCH 08/15] adding quality system and org of configs --- configs/config_cjs.yaml | 38 +++--- configs/config_quality.yaml | 29 ++++ configs/data/jamendo.yaml | 4 +- configs/models/gain+eq+comp-feat+quality.yaml | 68 ++++++++++ configs/models/gain+eq+comp-feat.yaml | 12 +- configs/models/gain+eq+comp-quality.yaml | 58 ++++++++ configs/models/quality-estim.yaml | 11 ++ 
mst/quality_system.py | 126 ++++++++++++++++++ 8 files changed, 322 insertions(+), 24 deletions(-) create mode 100644 configs/config_quality.yaml create mode 100644 configs/models/gain+eq+comp-feat+quality.yaml create mode 100644 configs/models/gain+eq+comp-quality.yaml create mode 100644 configs/models/quality-estim.yaml create mode 100644 mst/quality_system.py diff --git a/configs/config_cjs.yaml b/configs/config_cjs.yaml index f4c2a1c..713d308 100644 --- a/configs/config_cjs.yaml +++ b/configs/config_cjs.yaml @@ -13,26 +13,26 @@ trainer: - class_path: pytorch_lightning.callbacks.ModelSummary init_args: max_depth: 2 - - class_path: mst.callbacks.mix.LogReferenceMix - init_args: - root_dirs: - - /import/c4dm-datasets-ext/diffmst_validation/validation set/song2/Kat Wright_By My Side/ - - /import/c4dm-datasets-ext/diffmst_validation/validation set/song1/BenFlowers_Ecstasy_Full - - /import/c4dm-datasets-ext/diffmst_validation/validation set/song3/Titanium_HauntedAge_Full/ - - /import/c4dm-datasets-ext/diffmst_validation/validation set/song1/Soren_ALittleLate_Full - - /import/c4dm-datasets-ext/diffmst_validation/validation set/song2/MR0903_Moosmusic_Full - - /import/c4dm-datasets-ext/diffmst_validation/validation set/song3/SaturnSyndicate_CatchTheWave_Full - ref_mixes: - - /import/c4dm-datasets-ext/diffmst_validation/listening/diffmst-examples_wavref/The Dip - Paddle To The Stars (Lyric Video).wav - - /import/c4dm-datasets-ext/diffmst_validation/listening/diffmst-examples_wavref/Feel it all Around by Washed Out (Portlandia Theme).wav - - /import/c4dm-datasets-ext/diffmst_validation/listening/diffmst-examples_wavref/Architects - Doomsday.wav - - /import/c4dm-datasets-ext/diffmst_validation/validation set/song1/ref/Harry Styles - Late Night Talking (Official Video).wav - - /import/c4dm-datasets-ext/diffmst_validation/validation set/song2/ref/Justin Timberlake - Can't Stop The Feeling! 
[Lyrics].wav - - /import/c4dm-datasets-ext/diffmst_validation/validation set/song2/ref/Taylor Swift - Shake It Off.wav + #- class_path: mst.callbacks.mix.LogReferenceMix + # init_args: + # root_dirs: + # - /import/c4dm-datasets-ext/diffmst_validation/validation set/song2/Kat Wright_By My Side/ + # - /import/c4dm-datasets-ext/diffmst_validation/validation set/song1/BenFlowers_Ecstasy_Full + # - /import/c4dm-datasets-ext/diffmst_validation/validation set/song3/Titanium_HauntedAge_Full/ + # - /import/c4dm-datasets-ext/diffmst_validation/validation set/song1/Soren_ALittleLate_Full + # - /import/c4dm-datasets-ext/diffmst_validation/validation set/song2/MR0903_Moosmusic_Full + # - /import/c4dm-datasets-ext/diffmst_validation/validation set/song3/SaturnSyndicate_CatchTheWave_Full + # ref_mixes: + # - /import/c4dm-datasets-ext/diffmst_validation/listening/diffmst-examples_wavref/The Dip - Paddle To The Stars (Lyric Video).wav + # - /import/c4dm-datasets-ext/diffmst_validation/listening/diffmst-examples_wavref/Feel it all Around by Washed Out (Portlandia Theme).wav + # - /import/c4dm-datasets-ext/diffmst_validation/listening/diffmst-examples_wavref/Architects - Doomsday.wav + # - /import/c4dm-datasets-ext/diffmst_validation/validation set/song1/ref/Harry Styles - Late Night Talking (Official Video).wav + # - /import/c4dm-datasets-ext/diffmst_validation/validation set/song2/ref/Justin Timberlake - Can't Stop The Feeling! 
[Lyrics].wav + # - /import/c4dm-datasets-ext/diffmst_validation/validation set/song2/ref/Taylor Swift - Shake It Off.wav default_root_dir: null gradient_clip_val: 10.0 - devices: 1 - check_val_every_n_epoch: 1 + devices: 2 + check_val_every_n_epoch: 5 max_epochs: 800 log_every_n_steps: 50 accelerator: gpu @@ -42,6 +42,6 @@ trainer: enable_model_summary: true num_sanity_val_steps: 2 benchmark: true - accumulate_grad_batches: 1 + accumulate_grad_batches: 2 reload_dataloaders_every_n_epochs: 1 diff --git a/configs/config_quality.yaml b/configs/config_quality.yaml new file mode 100644 index 0000000..f3bd493 --- /dev/null +++ b/configs/config_quality.yaml @@ -0,0 +1,29 @@ +seed_everything: 42 +#ckpt_path: /import/c4dm-datasets-ext/Diff-MST/DiffMST/4bjbp29c/checkpoints/epoch=118-step=148750.ckpt + +trainer: + logger: + class_path: pytorch_lightning.loggers.WandbLogger + init_args: + project: DiffMST-Quality + save_dir: /import/c4dm-datasets-ext/Diff-MST + enable_checkpointing: true + callbacks: + - class_path: pytorch_lightning.callbacks.ModelSummary + init_args: + max_depth: 2 + default_root_dir: null + gradient_clip_val: 10.0 + devices: 1 + check_val_every_n_epoch: 1 + max_epochs: 500 + log_every_n_steps: 50 + accelerator: gpu + strategy: ddp_find_unused_parameters_true + sync_batchnorm: true + precision: 32 + enable_model_summary: true + num_sanity_val_steps: 2 + benchmark: true + accumulate_grad_batches: 1 + diff --git a/configs/data/jamendo.yaml b/configs/data/jamendo.yaml index a99273b..d675cff 100644 --- a/configs/data/jamendo.yaml +++ b/configs/data/jamendo.yaml @@ -3,5 +3,5 @@ data: init_args: root_dir: /import/c4dm-datasets-ext/mtg-jamendo length: 262144 - batch_size: 4 - num_workers: 4 \ No newline at end of file + batch_size: 8 + num_workers: 8 \ No newline at end of file diff --git a/configs/models/gain+eq+comp-feat+quality.yaml b/configs/models/gain+eq+comp-feat+quality.yaml new file mode 100644 index 0000000..358865e --- /dev/null +++ 
b/configs/models/gain+eq+comp-feat+quality.yaml @@ -0,0 +1,68 @@ +model: + class_path: mst.system.System + init_args: + generate_mix: false + active_eq_epoch: 0 + active_compressor_epoch: 0 + active_fx_bus_epoch: 1000 + active_master_bus_epoch: 0 + mix_fn: mst.mixing.naive_random_mix + mix_console: + class_path: mst.modules.AdvancedMixConsole + init_args: + sample_rate: 44100 + input_min_gain_db: -48.0 + input_max_gain_db: 48.0 + output_min_gain_db: -48.0 + output_max_gain_db: 48.0 + eq_min_gain_db: -12.0 + eq_max_gain_db: 12.0 + min_pan: 0.0 + max_pan: 1.0 + model: + class_path: mst.modules.MixStyleTransferModel + init_args: + track_encoder: + class_path: mst.modules.SpectrogramEncoder + init_args: + n_inputs: 1 + embed_dim: 256 + n_fft: 2048 + hop_length: 512 + input_batchnorm: false + encoder_batchnorm: false + model_size: small + mix_encoder: + class_path: mst.modules.SpectrogramEncoder + init_args: + n_inputs: 1 + embed_dim: 256 + n_fft: 2048 + hop_length: 512 + input_batchnorm: false + encoder_batchnorm: false + model_size: small + controller: + class_path: mst.modules.TransformerController + init_args: + embed_dim: 256 + num_track_control_params: 27 + num_fx_bus_control_params: 25 + num_master_bus_control_params: 26 + num_layers: 12 + nhead: 8 + + loss: + class_path: mst.loss.FeatureAndQualityLoss + init_args: + sample_rate: 44100 + stem_separation: false + use_clap: false + weights: + - 0.1 # rms + - 0.001 # crest factor + - 1.0 # stereo width + - 1.0 # stereo imbalance + - 0.1 # bark spectrum + quality_ckpt_path: /import/c4dm-datasets-ext/Diff-MST/DiffMST-Quality/q60vbm8l/checkpoints/epoch=499-step=903500.ckpt + quality_weight: 0.001 diff --git a/configs/models/gain+eq+comp-feat.yaml b/configs/models/gain+eq+comp-feat.yaml index 7adf2f1..96b8e58 100644 --- a/configs/models/gain+eq+comp-feat.yaml +++ b/configs/models/gain+eq+comp-feat.yaml @@ -25,21 +25,27 @@ model: track_encoder: class_path: mst.modules.SpectrogramEncoder init_args: - embed_dim: 512 + 
n_inputs: 1 + embed_dim: 256 n_fft: 2048 hop_length: 512 input_batchnorm: false + encoder_batchnorm: false + model_size: small mix_encoder: class_path: mst.modules.SpectrogramEncoder init_args: - embed_dim: 512 + n_inputs: 1 + embed_dim: 256 n_fft: 2048 hop_length: 512 input_batchnorm: false + encoder_batchnorm: false + model_size: small controller: class_path: mst.modules.TransformerController init_args: - embed_dim: 512 + embed_dim: 256 num_track_control_params: 27 num_fx_bus_control_params: 25 num_master_bus_control_params: 26 diff --git a/configs/models/gain+eq+comp-quality.yaml b/configs/models/gain+eq+comp-quality.yaml new file mode 100644 index 0000000..4b228ee --- /dev/null +++ b/configs/models/gain+eq+comp-quality.yaml @@ -0,0 +1,58 @@ +model: + class_path: mst.system.System + init_args: + generate_mix: false + active_eq_epoch: 0 + active_compressor_epoch: 0 + active_fx_bus_epoch: 1000 + active_master_bus_epoch: 0 + mix_fn: mst.mixing.naive_random_mix + mix_console: + class_path: mst.modules.AdvancedMixConsole + init_args: + sample_rate: 44100 + input_min_gain_db: -48.0 + input_max_gain_db: 48.0 + output_min_gain_db: -48.0 + output_max_gain_db: 48.0 + eq_min_gain_db: -12.0 + eq_max_gain_db: 12.0 + min_pan: 0.0 + max_pan: 1.0 + model: + class_path: mst.modules.MixStyleTransferModel + init_args: + track_encoder: + class_path: mst.modules.SpectrogramEncoder + init_args: + n_inputs: 1 + embed_dim: 256 + n_fft: 2048 + hop_length: 512 + input_batchnorm: false + encoder_batchnorm: false + model_size: small + mix_encoder: + class_path: mst.modules.SpectrogramEncoder + init_args: + n_inputs: 1 + embed_dim: 256 + n_fft: 2048 + hop_length: 512 + input_batchnorm: false + encoder_batchnorm: false + model_size: small + controller: + class_path: mst.modules.TransformerController + init_args: + embed_dim: 256 + num_track_control_params: 27 + num_fx_bus_control_params: 25 + num_master_bus_control_params: 26 + num_layers: 12 + nhead: 8 + + loss: + class_path: 
mst.loss.QualityLoss + init_args: + ckpt_path: /import/c4dm-datasets-ext/Diff-MST/DiffMST-Quality/q60vbm8l/checkpoints/epoch=351-step=636064.ckpt diff --git a/configs/models/quality-estim.yaml b/configs/models/quality-estim.yaml new file mode 100644 index 0000000..a3ca1cb --- /dev/null +++ b/configs/models/quality-estim.yaml @@ -0,0 +1,11 @@ +model: + class_path: mst.quality_system.QualityEstimationSystem + init_args: + encoder: + class_path: mst.modules.SpectrogramEncoder + init_args: + embed_dim: 512 + n_inputs: 1 + l2_norm: true + input_batchnorm: false + encoder_batchnorm: false \ No newline at end of file diff --git a/mst/quality_system.py b/mst/quality_system.py new file mode 100644 index 0000000..92a437b --- /dev/null +++ b/mst/quality_system.py @@ -0,0 +1,126 @@ +import os +import torch +import itertools +import pytorch_lightning as pl + +from typing import Callable +from mst.utils import batch_stereo_peak_normalize + +import warnings + +warnings.filterwarnings( + "ignore" +) # fix this later to catch warnings about reading mp3 files + + +class QualityEstimationSystem(pl.LightningModule): + def __init__( + self, + encoder: torch.nn.Module, + schedule: str = "step", + lr: float = 3e-4, + max_epochs: int = 500, + **kwargs, + ) -> None: + super().__init__() + self.encoder = encoder + self.projector = torch.nn.Sequential( + torch.nn.Linear(encoder.embed_dim, 2 * encoder.embed_dim), + torch.nn.ReLU(), + torch.nn.Linear(2 * encoder.embed_dim, 1), + ) + self.save_hyperparameters(ignore=["encoder"]) + + def forward( + self, + mix: torch.Tensor, + ) -> torch.Tensor: + # could consider masking different parts of input and output + # so the model cannot rely on perfectly aligned inputs + + z = self.encoder(mix) + + # project to parameter space + pred = self.projector(z) + + return pred + + def common_step( + self, + batch: tuple, + batch_idx: int, + optimizer_idx: int = 0, + train: bool = False, + ): + """Model step used for validation and training. 
+ Args: + batch (Tuple[Tensor, Tensor]): Batch items containing rmix, stems and orig mix + batch_idx (int): Index of the batch within the current epoch. + optimizer_idx (int): Index of the optimizer, this step is called once for each optimizer. + train (bool): Whether step is called during training (True) or validation (False). + """ + mix, label = batch + + pred_label = self(mix).squeeze() + + loss = torch.nn.functional.binary_cross_entropy_with_logits(pred_label, label) + + # log the losses + self.log( + ("train" if train else "val") + "/loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + + # compute accuracy + pred_label = torch.sigmoid(pred_label) + pred_label = torch.round(pred_label) + acc = torch.sum(pred_label == label) / label.numel() + self.log( + ("train" if train else "val") + "/acc", + acc, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + + return loss + + def training_step(self, batch, batch_idx, optimizer_idx=0): + loss = self.common_step(batch, batch_idx, train=True) + return loss + + def validation_step(self, batch, batch_idx): + loss = self.common_step(batch, batch_idx, train=False) + return loss + + def configure_optimizers(self): + optimizer = torch.optim.Adam( + itertools.chain(self.encoder.parameters(), self.projector.parameters()), + lr=self.hparams.lr, + betas=(0.9, 0.999), + ) + + if self.hparams.schedule == "cosine": + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=self.hparams.max_epochs + ) + elif self.hparams.schedule == "step": + scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + [ + int(self.hparams.max_epochs * 0.85), + int(self.hparams.max_epochs * 0.95), + ], + ) + else: + return optimizer + lr_schedulers = {"scheduler": scheduler, "interval": "epoch", "frequency": 1} + + return [optimizer], lr_schedulers From 8db2b0ec4704ec82fa71f7cbb4509b7e90daf793 Mon Sep 17 00:00:00 2001 From: csteinmetz1 Date: 
Mon, 29 Jan 2024 09:43:49 +0000 Subject: [PATCH 09/15] improved online mixing script --- scripts/online.py | 31 +++++++++++++++++++------------ scripts/online.sh | 5 +++-- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/scripts/online.py b/scripts/online.py index 72b3ee6..177ffa2 100644 --- a/scripts/online.py +++ b/scripts/online.py @@ -20,6 +20,8 @@ def optimize( init_scale: float = 0.001, lr: float = 1e-3, n_iters: int = 100, + use_fx_bus: bool = False, + use_master_bus: bool = False, ): """Create a mix from the tracks that is as close as possible to the reference mixture. @@ -29,6 +31,8 @@ def optimize( mix_console (torch.nn.Module): Mix console instance. (e.g. AdvancedMixConsole) loss_function (torch.nn.Module): Loss function instance. (e.g. AudioFeatureLoss) n_iters (int): Number of iterations for the optimization. + use_fx_bus (bool): Whether to use the fx bus in the mix console. + use_master_bus (bool): Whether to use the master bus in the mix console. Returns: torch.Tensor: Tensor of shape (2, n_samples) that is as close as possible to the reference mixture. 
@@ -78,7 +82,8 @@ def optimize( torch.sigmoid(track_params), torch.sigmoid(fx_bus_params), torch.sigmoid(master_bus_params), - use_fx_bus=False, + use_fx_bus=use_fx_bus, + use_master_bus=use_master_bus, ) mix = result[1] track_param_dict = result[2] @@ -273,12 +278,11 @@ def optimize( mix_console = AdvancedMixConsole(args.sample_rate) weights = [ - 0.1, # rms - 0.001, # crest factor - 1.0, # stereo width - 1.0, # stereo imbalance - 1.00, # bark spectrum - 100.0, # clap + 10.0, # rms + 0.01, # crest factor + 10.0, # stereo width + 100.0, # stereo imbalance + 0.01, # bark spectrum ] if args.loss == "feat": @@ -286,6 +290,7 @@ def optimize( weights, args.sample_rate, stem_separation=args.stem_separation, + use_clap=False, ) elif args.loss == "clap": loss_function = StereoCLAPLoss() @@ -368,13 +373,15 @@ def optimize( plt.savefig(os.path.join(output_dir, "plots", "bark_specta.png")) plt.close("all") + fig, ax = plt.subplots(1, 1) + for idx, (loss_name, loss_vals) in enumerate(loss_history.items()): - fig, ax = plt.subplots(1, 1) ax.plot(loss_vals, label=loss_name) - ax.set_xlabel("Iteration") - ax.set_ylabel(f"{loss_name}") - plt.savefig(os.path.join(output_dir, "plots", f"{loss_name}.png")) - plt.close("all") + + plt.legend() + ax.set_xlabel("Iteration") + plt.savefig(os.path.join(output_dir, "plots", f"{loss_name}.png")) + plt.close("all") # -------------------------- save results -------------------------- # # save mix diff --git a/scripts/online.sh b/scripts/online.sh index 3e44053..95d5d52 100755 --- a/scripts/online.sh +++ b/scripts/online.sh @@ -1,8 +1,9 @@ -CUDA_VISIBLE_DEVICES=4 python scripts/online.py \ +CUDA_VISIBLE_DEVICES=5 python scripts/online.py \ --track_dir "/import/c4dm-datasets-ext/test-multitracks/Kat Wright_By My Side" \ --ref_mix "/import/c4dm-datasets-ext/diffmst_validation/listening/diffmst-examples_wavref/The Dip - Paddle To The Stars (Lyric Video).wav" \ --use_gpu \ ---n_iters 1000 \ +--n_iters 10000 \ --loss "feat" \ +--lr 0.001 \ 
#--stem_separation \ \ No newline at end of file From bce4eaf27fe1b66ff08d65d589774f369d659712 Mon Sep 17 00:00:00 2001 From: csteinmetz1 Date: Mon, 29 Jan 2024 09:45:10 +0000 Subject: [PATCH 10/15] add batch norm to audio before loss (not sure this helps) --- mst/system.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mst/system.py b/mst/system.py index e814cf4..762a913 100644 --- a/mst/system.py +++ b/mst/system.py @@ -211,7 +211,8 @@ def common_step( ) # normalize the predicted mix before computing the loss - # pred_mix_b = batch_stereo_peak_normalize(pred_mix_b) + pred_mix_b = batch_stereo_peak_normalize(pred_mix_b) + ref_mix_b = batch_stereo_peak_normalize(ref_mix_b) if ref_track_param_dict is None: ref_track_param_dict = pred_track_param_dict From 566816087581cf802e4c47f2aac0fa3f87c5075c Mon Sep 17 00:00:00 2001 From: csteinmetz1 Date: Mon, 29 Jan 2024 09:45:28 +0000 Subject: [PATCH 11/15] simplying panns --- mst/panns.py | 118 +++++++++++++++++++++++++++++---------------------- 1 file changed, 67 insertions(+), 51 deletions(-) diff --git a/mst/panns.py b/mst/panns.py index 723e79e..bb20c96 100644 --- a/mst/panns.py +++ b/mst/panns.py @@ -125,41 +125,76 @@ def __init__( num_classes: int, n_inputs: int = 1, use_batchnorm: bool = True, + model_size: str = "large", ): super(Cnn14, self).__init__() - self.conv_block1 = ConvBlock( - in_channels=n_inputs, - out_channels=64, - use_batchnorm=use_batchnorm, - ) - self.conv_block2 = ConvBlock( - in_channels=64, - out_channels=128, - use_batchnorm=use_batchnorm, - ) - self.conv_block3 = ConvBlock( - in_channels=128, - out_channels=256, - use_batchnorm=use_batchnorm, - ) - self.conv_block4 = ConvBlock( - in_channels=256, - out_channels=512, - use_batchnorm=use_batchnorm, - ) - self.conv_block5 = ConvBlock( - in_channels=512, - out_channels=1024, - use_batchnorm=use_batchnorm, - ) - self.conv_block6 = ConvBlock( - in_channels=1024, - out_channels=2048, - use_batchnorm=use_batchnorm, - ) + if 
model_size == "large": + self.conv_block1 = ConvBlock( + in_channels=n_inputs, + out_channels=64, + use_batchnorm=use_batchnorm, + ) + self.conv_block2 = ConvBlock( + in_channels=64, + out_channels=128, + use_batchnorm=use_batchnorm, + ) + self.conv_block3 = ConvBlock( + in_channels=128, + out_channels=256, + use_batchnorm=use_batchnorm, + ) + self.conv_block4 = ConvBlock( + in_channels=256, + out_channels=512, + use_batchnorm=use_batchnorm, + ) + self.conv_block5 = ConvBlock( + in_channels=512, + out_channels=1024, + use_batchnorm=use_batchnorm, + ) + self.conv_block6 = ConvBlock( + in_channels=1024, + out_channels=2048, + use_batchnorm=use_batchnorm, + ) + out_channels = 2048 + elif model_size == "small": + self.conv_block1 = ConvBlock( + in_channels=n_inputs, + out_channels=64, + use_batchnorm=use_batchnorm, + ) + self.conv_block2 = ConvBlock( + in_channels=64, + out_channels=128, + use_batchnorm=use_batchnorm, + ) + self.conv_block3 = ConvBlock( + in_channels=128, + out_channels=256, + use_batchnorm=use_batchnorm, + ) + self.conv_block4 = ConvBlock( + in_channels=256, + out_channels=512, + use_batchnorm=use_batchnorm, + ) + self.conv_block5 = ConvBlock5x5( + in_channels=512, + out_channels=512, + ) + self.conv_block6 = ConvBlock5x5( + in_channels=512, + out_channels=512, + ) + out_channels = 512 + else: + raise Exception(f"Invalid model_size: {model_size}") - self.fc = nn.Linear(2048, num_classes, bias=True) + self.fc = nn.Linear(out_channels, num_classes, bias=True) self.init_weight() def init_weight(self): @@ -172,33 +207,14 @@ def forward(self, x: torch.Tensor): """ batch_size, chs, bins, frames = x.size() - # x = x.view(batch_size, -1) - # x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins) - # x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) - # x = x.transpose(1, 3) - # x = self.bn0(x) - # x = x.transpose(1, 3) - # if self.training: - # x = self.spec_augmenter(x) - x = self.conv_block1(x, pool_size=(2, 2), 
pool_type="avg") - # x = F.dropout(x, p=0.2, training=self.training) x = self.conv_block2(x, pool_size=(4, 4), pool_type="avg") - # x = F.dropout(x, p=0.2, training=self.training) x = self.conv_block3(x, pool_size=(4, 2), pool_type="avg") - # x = F.dropout(x, p=0.2, training=self.training) x = self.conv_block4(x, pool_size=(4, 2), pool_type="avg") - # x = F.dropout(x, p=0.2, training=self.training) x = self.conv_block5(x, pool_size=(4, 2), pool_type="avg") - # x = F.dropout(x, p=0.2, training=self.training) x = self.conv_block6(x, pool_size=(2, 2), pool_type="avg") - # x = F.dropout(x, p=0.2, training=self.training) x = torch.mean(x, dim=2) # mean across stft bins - - (x1, _) = torch.max(x, dim=2) - x2 = torch.mean(x, dim=2) - x = x1 + x2 - # x = F.dropout(x, p=0.5, training=self.training) + x = x.permute(0, 2, 1) x_out = self.fc(x) clipwise_output = x_out From a2ce1499c56ca9adeda4e2f6052a5fc60d3ff4ad Mon Sep 17 00:00:00 2001 From: csteinmetz1 Date: Mon, 29 Jan 2024 09:45:56 +0000 Subject: [PATCH 12/15] adding dataloader for quality net --- mst/dataloader.py | 74 ++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 64 insertions(+), 10 deletions(-) diff --git a/mst/dataloader.py b/mst/dataloader.py index d31a41c..4e1fd75 100644 --- a/mst/dataloader.py +++ b/mst/dataloader.py @@ -5,6 +5,7 @@ import yaml import random import itertools +import pedalboard import torchaudio import numpy as np import pyloudnorm as pyln @@ -22,14 +23,15 @@ def __init__(self, root_dir: str, length: int = 524288): self.length = length self.mix_filepaths = glob.glob( - os.path.join(root_dir, "**", "*.wav"), recursive=True + os.path.join(root_dir, "**", "*.mp3"), recursive=True ) # self.mix_filepaths = glob.glob( # os.path.join(root_dir, "**", "*.mp3"), recursive=True) print(f"Located {len(self.mix_filepaths)} mixes.") - self.meter = pyln.Meter(44100) + self.sample_rate = 44100 + self.meter = pyln.Meter(self.sample_rate) def __len__(self): return len(self.mix_filepaths) @@ -39,14 
+41,12 @@ def __getitem__(self, _): while not valid: # get random file idx = np.random.randint(0, len(self.mix_filepaths)) - # idx = 42 # always use the same mix for debug mix_filepath = self.mix_filepaths[idx] num_frames = torchaudio.info(mix_filepath).num_frames # find random non-silent region of the mix offset = np.random.randint(0, num_frames - self.length - 1) - offset = 0 # always use the same offset mix, _ = torchaudio.load( mix_filepath, frame_offset=offset, @@ -66,13 +66,67 @@ def __getitem__(self, _): if mix_lufs_db > -48.0: valid = True - # random gain of the target mixes - target_lufs_db = np.random.randint(-48, 0) - target_lufs_db = -14.0 # always use same target - delta_lufs_db = torch.tensor([target_lufs_db - mix_lufs_db]).float() - mix = 10.0 ** (delta_lufs_db / 20.0) * mix + # now apply some random processing to the mix + if np.random.rand() > 0.5: + quality_label = False + mix = mix.numpy() # convert to numpy for pedalboard + if np.random.rand() < 0.8: # reduce stereo width + width = np.random.uniform(0.0, 0.6) + sqrt2 = np.sqrt(2) + mid = (mix[0, :] + mix[1, :]) / sqrt2 + side = (mix[0, :] - mix[1, :]) / sqrt2 + # amplify mid and side signal separately: + mid *= 2 * (1 - width) + side *= 2 * width + # covert back to stereo + left = (mid + side) / sqrt2 + right = (mid - side) / sqrt2 + # replace original mix with processed mix + mix[0, :] = left + mix[1, :] = right + if np.random.rand() < 0.3: # downmix to mono + mono_mix = mix.mean(0, keepdims=True) + mix[0, :] = mono_mix + mix[1, :] = mono_mix + if np.random.rand() < 0.3: # stereo imbalance + mix[0, :] *= np.random.uniform(0.0, 1.0) + mix[1, :] *= np.random.uniform(0.0, 1.0) + if np.random.rand() < 0.2: # apply distortion + mix = pedalboard.Distortion( + drive_db=np.random.uniform(0, 20.0) + ).process(mix, self.sample_rate) + if np.random.rand() < 0.2: # apply reverb + mix = pedalboard.Reverb( + room_size=np.random.uniform(0, 1.0), + wet_level=0.5, + dry_level=np.random.uniform(0.0, 1.0), + 
).process(mix, self.sample_rate) + if np.random.rand() < 0.2: # apply compression + mix = pedalboard.Compressor( + threshold_db=np.random.uniform(-24.0, 0.0), + ratio=np.random.uniform(2.0, 10.0), + ).process(mix, self.sample_rate) + if np.random.rand() < 0.4: # apply lowpass + mix = pedalboard.LowpassFilter( + cutoff_frequency_hz=np.random.uniform(1000.0, 20000.0) + ).process(mix, self.sample_rate) + if np.random.rand() < 0.4: # apply highpass + mix = pedalboard.HighpassFilter( + cutoff_frequency_hz=np.random.uniform(20.0, 4000.0) + ).process(mix, self.sample_rate) + if np.random.rand() < 0.2: # add white noise + mix = mix + np.random.normal(0, 0.01, mix.shape).astype(np.float32) + # convert back to torch tensor + mix = torch.from_numpy(mix) + else: + quality_label = True + + # random gain of the target mixes + target_lufs_db = -14.0 # np.random.randint(-48, 0) + delta_lufs_db = torch.tensor([target_lufs_db - mix_lufs_db]).float() + mix = 10.0 ** (delta_lufs_db / 20.0) * mix - return mix + return mix, torch.tensor(quality_label).float() class MixDataModule(pl.LightningDataModule): From 186841b5cef3b45ea1437fbc61230d3a526d7f70 Mon Sep 17 00:00:00 2001 From: csteinmetz1 Date: Mon, 29 Jan 2024 09:46:17 +0000 Subject: [PATCH 13/15] adding feature and quality loss module --- mst/loss.py | 115 ++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 90 insertions(+), 25 deletions(-) diff --git a/mst/loss.py b/mst/loss.py index 01b7a37..e38678d 100644 --- a/mst/loss.py +++ b/mst/loss.py @@ -9,6 +9,7 @@ from mst.fx_encoder import FXencoder from mst.modules import SpectrogramEncoder +from mst.quality_system import QualityEstimationSystem def compute_mid_side(x: torch.Tensor): @@ -414,7 +415,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor): class FX_encoder_loss(torch.nn.Module): - def __init__(self, distance: Callable = torch.nn.functional.mse_loss, audiofeatures = True, weights: list[float]= [1.0],): + def __init__( + self, + distance: 
Callable = torch.nn.functional.mse_loss, + audiofeatures=True, + weights: list[float] = [1.0], + ): super().__init__() self.distance = distance config_path = "/homes/ssv02/Diff-MST/configs/models/fx_encoder_mst.yaml" @@ -423,42 +429,42 @@ def __init__(self, distance: Callable = torch.nn.functional.mse_loss, audiofeatu self.config = yaml.safe_load(f) checkpoint_path = "/homes/ssv02/Diff-MST/data/FXencoder_ps.pt" self.ddp = True - #self.embed_distance = torch.nn.CosineEmbeddingLoss(reduction = 'mean') + # self.embed_distance = torch.nn.CosineEmbeddingLoss(reduction = 'mean') self.embed_similarity = torch.nn.CosineSimilarity(dim=1, eps=1e-6) - - # load model - self.model = FXencoder(self.config["Effects_Encoder"]['default']) + # load model + self.model = FXencoder(self.config["Effects_Encoder"]["default"]) # load checkpoint checkpoint = torch.load(checkpoint_path) from collections import OrderedDict + new_state_dict = OrderedDict() for k, v in checkpoint["model"].items(): # remove `module.` if the model was trained with DDP name = k[7:] if self.ddp else k new_state_dict[name] = v - + # load params self.model.load_state_dict(new_state_dict) self.model.eval() - + # freeze all parameters in model for param in self.model.parameters(): param.requires_grad = False - + def compute_fx_embeds(x: torch.Tensor): embed = self.model(x) return embed - - #weights = [0.1,0.001,1.0,1.0,0.1,100.0] + + # weights = [0.1,0.001,1.0,1.0,0.1,100.0] self.weights = weights self.transforms = [] - + if audiofeatures: - self.audiofeatures = audiofeatures - + self.audiofeatures = audiofeatures + self.transforms = [ compute_rms, compute_crest_factor, @@ -468,8 +474,7 @@ def compute_fx_embeds(x: torch.Tensor): ] self.transforms.append(compute_fx_embeds) - - + assert len(self.transforms) == len(self.weights) def forward(self, input: torch.Tensor, target: torch.Tensor): @@ -482,28 +487,89 @@ def forward(self, input: torch.Tensor, target: torch.Tensor): # loss = self.distance(input_embed, target_embed) 
# return loss - + for transform, weight in zip(self.transforms, self.weights): transform_name = "_".join(transform.__name__.split("_")[1:]) - #print(transform_name) + # print(transform_name) input_transform = transform(input) target_transform = transform(target) if transform_name == "fx_embeds": - val = 1-self.embed_similarity(input_transform, target_transform).mean().clamp(min=1e-8) - #print(val) + val = 1 - self.embed_similarity( + input_transform, target_transform + ).mean().clamp(min=1e-8) + # print(val) else: val = torch.nn.functional.mse_loss(input_transform, target_transform) - #print(val) + # print(val) losses[transform_name] = weight * val - + return losses - + + +class QualityLoss(torch.nn.Module): + def __init__( + self, + ckpt_path: str, + ) -> None: + super().__init__() + # hard-coded model configuration + encoder = SpectrogramEncoder( + embed_dim=512, + n_inputs=1, + input_batchnorm=False, + encoder_batchnorm=False, + l2_norm=True, + ) + + self.model = QualityEstimationSystem.load_from_checkpoint( + ckpt_path, encoder=encoder + ) + self.model.eval() + self.model.freeze() + + def forward(self, input: torch.Tensor, *args, **kwargs): + """Compute loss on stereo mixes using featues from quality model. 
+ + Args: + input: (bs, 2, seq_len) + """ + logits = self.model(input) # higher is better (high quality) + return -logits + + +class FeatureAndQualityLoss(torch.nn.Module): + def __init__( + self, + weights: List[float], + sample_rate: int, + quality_ckpt_path: str, + quality_weight: float = 1.0, + stem_separation: bool = False, + use_clap: bool = False, + ): + super().__init__() + self.feature_loss = AudioFeatureLoss( + weights=weights, + sample_rate=sample_rate, + stem_separation=stem_separation, + use_clap=use_clap, + ) + self.quality_loss = QualityLoss(quality_ckpt_path) + self.quality_weight = quality_weight + + def forward(self, input: torch.Tensor, target: torch.Tensor): + feature_losses = self.feature_loss(input, target) + quality_loss = self.quality_loss(input) + feature_losses["quality"] = quality_loss * self.quality_weight + return feature_losses + + # if __name__ == "__main__": # import torchaudio # path = "/import/c4dm-datasets-ext/mtg-jamendo_wav/02/1012002.wav" - + # #input1, sr = torchaudio.load(path, channels_first = True, num_frames = 44100*10) - + # input1= torch.zeros(2,44100*10) # input2 = input1 # #input2, sr = torchaudio.load(path, channels_first = True, num_frames = 44100*10, frame_offset = 44100*10) @@ -516,4 +582,3 @@ def forward(self, input: torch.Tensor, target: torch.Tensor): # losses = loss(input1, input2) # print(losses) # print(sum(losses.values())) - From 561a33b88529ba6a9a6187d78a7a496f7d314c9e Mon Sep 17 00:00:00 2001 From: csteinmetz1 Date: Mon, 29 Jan 2024 09:46:43 +0000 Subject: [PATCH 14/15] changes to accomadate updated panns --- mst/modules.py | 74 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 45 insertions(+), 29 deletions(-) diff --git a/mst/modules.py b/mst/modules.py index 6c4cbbc..9b1b058 100644 --- a/mst/modules.py +++ b/mst/modules.py @@ -44,15 +44,13 @@ def forward( if self.sum_and_diff: ref_mix_mid = ref_mix.sum(dim=1) ref_mix_side = ref_mix[..., 0:1, :] - ref_mix[..., 1:2, :] + mix_mid_size = 
torch.stack((ref_mix_mid, ref_mix_side), dim=1) # process the reference mix - - mid_embeds = self.mix_encoder(ref_mix_mid) - side_embeds = self.mix_encoder(ref_mix_side) - mix_embeds = torch.stack((mid_embeds, side_embeds), dim=1) + mix_embeds = self.mix_encoder(mix_mid_size) else: - mix_embeds = self.mix_encoder(ref_mix.view(bs * 2, 1, -1)) - mix_embeds = mix_embeds.view(bs, 2, -1) # restore + mix_embeds = self.mix_encoder(ref_mix) + mix_embeds = mix_embeds.unsqueeze(1) # controller will predict mix parameters for each stem based on embeds track_params, fx_bus_params, master_bus_params = self.controller( @@ -664,23 +662,6 @@ def forward(self, x: torch.Tensor): return z[:, 0, :] -class SpectrogramEncoder(torch.nn.Module): - def __init__( - self, - n_inputs=1, - embed_dim: int = 1024, - encoder_batchnorm: bool = True, - ): - super().__init__() - self.n_inputs = n_inputs - self.embed_dim = embed_dim - self.encoder_batchnorm = encoder_batchnorm - self.model = TCN(n_inputs, embed_dim) - - def forward(self, x: torch.Tensor): - return self.model(x) - - class PositionalEncoding(torch.nn.Module): def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1024): super().__init__() @@ -754,6 +735,8 @@ def __init__( hop_length: int = 512, input_batchnorm: bool = False, encoder_batchnorm: bool = True, + l2_norm: bool = False, + model_size: str = "large", ) -> None: super().__init__() self.embed_dim = embed_dim @@ -761,6 +744,7 @@ def __init__( self.n_fft = n_fft self.hop_length = hop_length self.input_batchnorm = input_batchnorm + self.l2_norm = l2_norm window_length = int(n_fft) self.register_buffer("window", torch.hann_window(window_length=window_length)) @@ -769,6 +753,14 @@ def __init__( n_inputs=n_inputs, num_classes=embed_dim, use_batchnorm=encoder_batchnorm, + model_size=model_size, + ) + + self.attention = torch.nn.MultiheadAttention( + embed_dim=embed_dim, + num_heads=8, + dropout=0.0, + batch_first=True, ) if self.input_batchnorm: @@ -806,10 +798,34 @@ def 
forward(self, x: torch.torch.Tensor) -> torch.torch.Tensor: if self.input_batchnorm: X = self.bn(X) + # move channels to batch dim + X = X.view(-1, 1, X.shape[-2], X.shape[-1]) + # process with CNN - embeds = self.model(X) - # print(embeds.shape) - return embeds + embeds = self.model(X) # bs x chs, embed_dim, seq_len + embeds = embeds.view(-1, chs, self.embed_dim) + + # project down to embedding dim via attention + embeds, _ = self.attention(embeds, embeds, embeds) + # bs x seq_len, chs, embed_dim + embeds = embeds.mean(dim=1) # mean across channels + + # move seq dim back + embeds = embeds.view(bs, -1, self.embed_dim) + # bs, seq_len, embed_dim + + # compute statistics (mean, std, max, min) across time + (x1, _) = torch.max(embeds, dim=1) + # (x2, _) = torch.min(embeds, dim=1) + x3 = torch.mean(embeds, dim=1) + # x4 = torch.std(embeds, dim=1) + x = x1 + x3 + + # apply l2 norm + if self.l2_norm: + x = torch.nn.functional.normalize(x, p=2, dim=1) + + return x class TransformerController(torch.nn.Module): @@ -845,7 +861,7 @@ def __init__( self.use_master_bus = use_master_bus self.track_embedding = torch.nn.Parameter(torch.randn(1, 1, embed_dim)) - self.mix_embedding = torch.nn.Parameter(torch.randn(1, 2, embed_dim)) + self.mix_embedding = torch.nn.Parameter(torch.randn(1, 1, embed_dim)) self.fx_bus_embedding = torch.nn.Parameter(torch.randn(1, 1, embed_dim)) self.master_bus_embedding = torch.nn.Parameter(torch.randn(1, 1, embed_dim)) @@ -873,7 +889,7 @@ def forward( Args: track_embeds (torch.torch.Tensor): Embeddings for each track with shape (bs, num_tracks, embed_dim) - mix_embeds (torch.torch.Tensor): Embeddings for the reference mix with shape (bs, 2, embed_dim) + mix_embeds (torch.torch.Tensor): Embeddings for the reference mix with shape (bs, 1, embed_dim) track_padding_mask (Optional[torch.Tensor]): Mask for the track embeddings with shape (bs, num_tracks) Returns: @@ -897,7 +913,7 @@ def forward( track_padding_mask = torch.cat( ( track_padding_mask, - 
torch.zeros(bs, 4).bool().type_as(track_padding_mask), + torch.zeros(bs, 3).bool().type_as(track_padding_mask), ), dim=1, ) From a4a34e5c985f3631600d38939739399e2314457c Mon Sep 17 00:00:00 2001 From: csteinmetz1 Date: Mon, 29 Jan 2024 09:47:01 +0000 Subject: [PATCH 15/15] example of how to train quality net --- README.md | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 09c3cf2..c805012 100644 --- a/README.md +++ b/README.md @@ -60,20 +60,30 @@ First update the paths in the configuration file for both the logger and the dat Then call the `main.py` script passing in the configuration file. ``` # new model configuration with audio feature loss -CUDA_VISIBLE_DEVICES=0 python main.py fit \ +CUDA_VISIBLE_DEVICES=2,4 python main.py fit \ -c configs/config_cjs.yaml \ -c configs/optimizer.yaml \ --c configs/data/medley+cambridge+jamendo-8.yaml \ +-c configs/data/medley+cambridge+jamendo-16.yaml \ -c configs/models/gain+eq+comp-feat.yaml # new model configuration with CLAP loss -CUDA_VISIBLE_DEVICES=0 python main.py fit \ +CUDA_VISIBLE_DEVICES=7 python main.py fit \ -c configs/config_cjs.yaml \ -c configs/optimizer.yaml \ --c configs/data/medley+cambridge+jamendo-8.yaml \ +-c configs/data/medley+cambridge+jamendo-4.yaml \ -c configs/models/gain+eq+comp-clap.yaml ``` +``` +CUDA_VISIBLE_DEVICES=5 python main.py fit \ +-c configs/config_quality.yaml \ +-c configs/optimizer.yaml \ +-c configs/data/jamendo.yaml \ +-c configs/models/quality-estim.yaml +``` + +``` +``` # Stability (ignore) ```