From 530cd4f6b415a721a83ff3d64f0ca31a3b6a72d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=92=D0=B5=D1=80=D0=B0?= Date: Mon, 18 Nov 2024 13:49:22 +0300 Subject: [PATCH 1/2] add peak normalization for inference --- src/configs/inference.yaml | 2 +- src/configs/transforms/inference.yaml | 7 +++++++ .../transforms/instance_transforms/mel_spec.yaml | 12 +++++++++++- src/transforms/wav_augs/__init__.py | 3 ++- src/transforms/wav_augs/peak_normalize.py | 12 ++++++++++++ 5 files changed, 33 insertions(+), 3 deletions(-) create mode 100644 src/configs/transforms/inference.yaml create mode 100644 src/transforms/wav_augs/peak_normalize.py diff --git a/src/configs/inference.yaml b/src/configs/inference.yaml index c70d374..cc4487e 100644 --- a/src/configs/inference.yaml +++ b/src/configs/inference.yaml @@ -3,7 +3,7 @@ defaults: - metrics: ss - datasets: ss_dataset - dataloader: example - - transforms: ss_full_spec + - transforms: inference - _self_ inferencer: device_tensors: ["mix_spectrogram", "s1_spectrogram", "s2_spectrogram", "mix", "s1", "s2", "s2_video", "s1_embedding", "s1_embedding", "s2_embedding"] # which tensors should be on device (ex. 
GPU) diff --git a/src/configs/transforms/inference.yaml b/src/configs/transforms/inference.yaml new file mode 100644 index 0000000..f3c0d14 --- /dev/null +++ b/src/configs/transforms/inference.yaml @@ -0,0 +1,7 @@ +defaults: + - instance_transforms: mel_spec + - _self_ + +batch_transforms: + train: null + inference: null \ No newline at end of file diff --git a/src/configs/transforms/instance_transforms/mel_spec.yaml b/src/configs/transforms/instance_transforms/mel_spec.yaml index 3187995..71c2856 100644 --- a/src/configs/transforms/instance_transforms/mel_spec.yaml +++ b/src/configs/transforms/instance_transforms/mel_spec.yaml @@ -25,4 +25,14 @@ train: inference: get_spectrogram: _target_: torchaudio.transforms.MelSpectrogram - sample_rate: 16000 \ No newline at end of file + sample_rate: 16000 + s1_pred: + _target_: torchvision.transforms.v2.Compose + transforms: + - _target_: src.transforms.wav_augs.PeakNormalize + p: 1.0 + s2_pred: + _target_: torchvision.transforms.v2.Compose + transforms: + - _target_: src.transforms.wav_augs.PeakNormalize + p: 1.0 \ No newline at end of file diff --git a/src/transforms/wav_augs/__init__.py b/src/transforms/wav_augs/__init__.py index d49cb81..5a2c3e1 100644 --- a/src/transforms/wav_augs/__init__.py +++ b/src/transforms/wav_augs/__init__.py @@ -1,5 +1,6 @@ from src.transforms.wav_augs.gain import Gain from src.transforms.wav_augs.noise import BackGroundNoise, ColoredNoise from src.transforms.wav_augs.shift import PitchShift, Shift +from src.transforms.wav_augs.peak_normalize import PeakNormalize -__all__ = ["Gain", "ColoredNoise", "BackGroundNoise", "Shift", "PitchShift"] +__all__ = ["Gain", "ColoredNoise", "BackGroundNoise", "Shift", "PitchShift", "PeakNormalize"] diff --git a/src/transforms/wav_augs/peak_normalize.py b/src/transforms/wav_augs/peak_normalize.py new file mode 100644 index 0000000..8154493 --- /dev/null +++ b/src/transforms/wav_augs/peak_normalize.py @@ -0,0 +1,12 @@ +import torch_audiomentations +from torch 
import Tensor, nn + + +class PeakNormalize(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + self._aug = torch_audiomentations.PeakNormalization(*args, **kwargs) + + def __call__(self, data: Tensor): + x = data.unsqueeze(1) + return self._aug(x).squeeze(1) \ No newline at end of file From 38495c88ad12379fe95ef5748580a193d7381c08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=92=D0=B5=D1=80=D0=B0?= Date: Tue, 19 Nov 2024 12:30:26 +0300 Subject: [PATCH 2/2] add embedding calculation --- src/datasets/base_dataset.py | 45 ++++++++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/src/datasets/base_dataset.py b/src/datasets/base_dataset.py index ea2ef61..2bee0a9 100644 --- a/src/datasets/base_dataset.py +++ b/src/datasets/base_dataset.py @@ -7,6 +7,9 @@ import torchaudio from torch.utils.data import Dataset +from src.lipreader.lipreading.dataloaders import get_preprocessing_pipelines +from src.utils.init_utils import init_lipreader + logger = logging.getLogger(__name__) @@ -92,21 +95,12 @@ def __getitem__(self, ind): s2_video_path = data_dict["s2_video_path"] s2_video = self.load_video(s2_video_path) - if data_dict["s1_embedding_path"] is not None: - s1_embedding_path = data_dict["s1_embedding_path"] - s1_embedding = self.load_object(s1_embedding_path) - - s2_embedding_path = data_dict["s2_embedding_path"] - s2_embedding = self.load_object(s2_embedding_path) - instance_data = { "mix": mix_audio, "s1": s1_audio, "s2": s2_audio, "s1_video": s1_video, "s2_video": s2_video, - "s1_embedding": s1_embedding, - "s2_embedding": s2_embedding, "audio_path": mix_wav_path, } # apply WAV augs before getting spec @@ -132,9 +126,15 @@ def __getitem__(self, ind): s2_spectrogram = self.get_spectrogram(s2_audio) instance_data.update({"s2_spectrogram": s2_spectrogram}) + s1_embedding = self.get_embedding(s1_video) + instance_data.update({"s1_embedding": s1_embedding}) + + s2_embedding = self.get_embedding(s2_video) + 
instance_data.update({"s2_embedding": s2_embedding}) + # exclude WAV augs for prevending double augmentations instance_data = self.preprocess_data( - instance_data, special_keys=["get_spectrogram", "mix"] + instance_data, special_keys=["get_spectrogram", "get_embedding", "mix"] ) return instance_data @@ -168,6 +168,31 @@ def get_spectrogram(self, audio): spectrogram (Tensor): spectrogram for the audio. """ return torch.log(self.instance_transforms["get_spectrogram"](audio).clamp(1e-5)) + + def get_embedding(self, video): + """ + Special instance transform to get an embedding from video. + + Args: + video (Tensor): original video. + Returns: + embedding (Tensor): embedding for the video. + """ + cfg_path = "/src/lipreader/configs/lrw_resnet18_mstcn.json" + lipreader_path = "/lrw_resnet18_mstcn_video.pth" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + lipreader = init_lipreader(cfg_path, lipreader_path).to(device) + lipreader.eval() + + preprocessing_func = get_preprocessing_pipelines(modality="video")["test"] + s_data = preprocessing_func(video) + s_data = s_data.unsqueeze(0).unsqueeze(1).to(device) + + with torch.no_grad(): + embed = lipreader(s_data, lengths=[50]).squeeze(0).transpose(0, 1) + + return embed def get_magnitude(self, audio): stft = torch.stft(