Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion config/grafp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ T_max: 400
lambda: 0.0
error_threshold: 5

weight_decay: 0 #1.0e-6 # 0 is the same as no weight decay

# stem: 'drums'
# SampleID train hyperparameters
mix_prob: 0.95
Expand All @@ -66,9 +68,10 @@ min_beats_required: 32
mix_prob: 0.95
mix_gain_range: [0.1, 0.7] #[0.05, 0.55]
min_beats_required: 32 # minimum number of beats required in a sample to be included in the dataset
tempo_ratio_range: [0.5, 2.0] #[0.75, 1.5]

# Augmentation hyperparameters
n_frames: 128 #10 #32 # depends on the spectrogram parameters (10 is for Music2latent), old neuralFP was 32, now 128 apparently
n_frames: 128 #10 #32 # depends on the spectrogram parameters (10 is for Music2latent), old neuralFP was 32, now 128
overlap: 0.875 #0.5
tr_snr: [0, 20]
val_snr: [0, 10]
Expand Down
91 changes: 43 additions & 48 deletions modules/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,9 +433,11 @@ def __getitem__(self, idx):
audio_resampled = resampler(audio_mono)

clip_frames = int(self.sample_rate * self.dur)
offset_frames = int(self.sample_rate * self.offset)

segment_length = clip_frames + offset_frames

if len(audio_resampled) <= clip_frames:
# self.ignore_idx.append(idx)
if len(audio_resampled) < segment_length:
return self[idx + 1]

key = self.get_key_for_file(datapath)
Expand All @@ -449,63 +451,56 @@ def __getitem__(self, idx):
"beats": beats,
}

# For training pipeline, output a random frame of the audio
if self.train:
a_i = audio_resampled
a_j = a_i.clone()
a_i = audio_resampled

offset_mod = int(self.sample_rate * (self.offset) + clip_frames)
if len(audio_resampled) < offset_mod:
print(
"Audio too short (offset_mod > len(audio resampled)). Skipping..."
)
return self[idx + 1]
r = np.random.randint(0, len(audio_resampled) - offset_mod)
ri = np.random.randint(0, offset_mod - clip_frames)
rj = np.random.randint(0, offset_mod - clip_frames)
start_idx = np.random.randint(0, len(audio_resampled) - segment_length + 1)
a_i = a_i[start_idx : start_idx + segment_length]

# Add timestamps to metadata
metadata.update(
{"start_i": r + ri, "start_j": r + rj, "clip_length": clip_frames}
)
a_j = a_i.clone()

clip_i = a_i[r : r + offset_mod]
clip_j = a_j[r : r + offset_mod]
x_i = clip_i[ri : ri + clip_frames]
x_j = clip_j[rj : rj + clip_frames]
# Introduce offset by extracting a random dur-length segment
x_i_start = np.random.randint(0, offset_frames)
x_j_start = np.random.randint(0, offset_frames)

if x_i.abs().max() < self.silence or x_j.abs().max() < self.silence:
print("Silence detected. Skipping...")
return self[idx + 1]
x_i = a_i[x_i_start : x_i_start + clip_frames]
x_j = a_j[x_j_start : x_j_start + clip_frames]

if self.norm is not None:
norm_val = qtile_norm(audio_resampled, q=self.norm)
x_i = x_i / norm_val
x_j = x_j / norm_val
# Add timestamps to metadata
metadata.update(
{
"start_i": start_idx + x_i_start,
"start_j": start_idx + x_j_start,
"clip_length": clip_frames,
}
)

if self.transform is not None:
x_i, x_j, transform_metadata = self.transform(x_i, x_j, metadata)
if x_i.abs().max() < self.silence or x_j.abs().max() < self.silence:
print("Silence detected. Skipping...")
return self[idx + 1]

if x_i is None or x_j is None:
return self[idx + 1]
# if self.norm is not None:
# norm_val = qtile_norm(audio_resampled, q=self.norm)
# x_i = x_i / norm_val
# x_j = x_j / norm_val

# Pad or truncate to sample_rate * dur
if len(x_i) < clip_frames:
x_i = F.pad(x_i, (0, clip_frames - len(x_i)))
else:
x_i = x_i[:clip_frames]
if self.transform is not None:
x_i, x_j, transform_metadata = self.transform(x_i, x_j, metadata)

if len(x_j) < clip_frames:
x_j = F.pad(x_j, (0, clip_frames - len(x_j)))
else:
x_j = x_j[:clip_frames]
if x_i is None or x_j is None:
return self[idx + 1]

return x_i, x_j, metadata
# Pad or truncate to sample_rate * dur
if len(x_i) < clip_frames:
x_i = F.pad(x_i, (0, clip_frames - len(x_i)))
else:
x_i = x_i[:clip_frames]

# For validation / test, output consecutive (overlapping) frames
if len(x_j) < clip_frames:
x_j = F.pad(x_j, (0, clip_frames - len(x_j)))
else:
return audio_resampled, None, metadata
# return audio_resampled
x_j = x_j[:clip_frames]

return x_i, x_j, metadata

def __len__(self):
return len(self.filenames)
return len(self.filenames)
40 changes: 23 additions & 17 deletions modules/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def __init__(self, cfg, train=True, cpu=False):
self.mix_prob = float(cfg.get("mix_prob", 0.95))
self.mix_gain_range = cfg.get("mix_gain_range", [0.05, 0.5]) # Narrower range
self.mix_gain_range = [float(i) for i in self.mix_gain_range]
self.tempo_ratio_range = cfg.get("tempo_ratio_range", [0.5,2.0])

# Keep melspec transform
self.logmelspec = nn.Sequential(
Expand Down Expand Up @@ -232,10 +233,15 @@ def get_transpose_semitones(self, from_key, to_key):
to_key = (to_key - 7) % 12

# Calculate the smallest semitone difference needed
difference = (to_key - from_key) % 12
if difference > 6:
difference -= 12
return difference
# Calculate direct difference first
direct_diff = to_key - from_key

# Normalize to find shortest path
if direct_diff > 6:
direct_diff -= 12
elif direct_diff < -6:
direct_diff += 12
return direct_diff

def analyze_tempo(self, beats_data):
"""Calculate tempo and time between beats"""
Expand All @@ -261,9 +267,9 @@ def get_tempo_ratio(self, source_tempo, target_tempo):
raw_ratio = target_tempo / source_tempo

# Find the closest power of 2 multiple/divisor that keeps ratio between 0.5 and 2.0
while raw_ratio > 2.0:
while raw_ratio > self.tempo_ratio_range[1]: #1.5: #2.0:
raw_ratio /= 2.0
while raw_ratio < 0.5:
while raw_ratio < self.tempo_ratio_range[0]: #0.75: #0.5:
raw_ratio *= 2.0

return raw_ratio
Expand Down Expand Up @@ -411,26 +417,26 @@ def process_audio_batch(self, batch_audio, metadata):
# print("Offset", offset)

# Apply offset and padding/trimming to same length
target_length = len(audio)
target_length = len(other_audio)
if offset >= 0:
# Add offset zeros at the start
other_audio = np.pad(other_audio, (offset, 0))
audio = np.pad(audio, (offset, 0))
# Then trim/pad to target length
if len(other_audio) > target_length:
other_audio = other_audio[:target_length]
if len(audio) > target_length:
audio = audio[:target_length]
else:
other_audio = np.pad(
other_audio, (0, target_length - len(other_audio))
audio = np.pad(
audio, (0, target_length - len(audio))
)
elif offset < 0:
other_audio = other_audio[-offset:]
if len(other_audio) > target_length:
audio = audio[-offset:]
if len(audio) > target_length:
# If longer than target, trim the end
other_audio = other_audio[:target_length]
audio = audio[:target_length]
else:
# If shorter than target, pad the end
other_audio = np.pad(
other_audio, (0, target_length - len(other_audio))
audio = np.pad(
audio, (0, target_length - len(audio))
)

# Verify lengths match before mixing
Expand Down
15 changes: 14 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,18 @@ def main():
cfg = load_config(args.config)
writer = SummaryWriter(f'runs/{args.ckp}')

# log the configuration
print("Configuration parameters:")
for key, value in cfg.items():
print(f" {key}: {value}")

# Log all config parameters to TensorBoard
# Convert nested structures to strings for TensorBoard
config_flat = {}
for key, value in cfg.items():
config_flat[key] = str(value)
writer.add_text("Configuration", str(config_flat), 0)

additive = args.additive

if not additive:
Expand All @@ -209,6 +221,7 @@ def main():
# Hyperparameters
batch_size = cfg['bsz_train']
learning_rate = cfg['lr']
weight_decay = cfg['weight_decay']
num_epochs = override(cfg['n_epochs'], args.epochs)
model_name = args.ckp
random_seed = args.seed
Expand Down Expand Up @@ -309,7 +322,7 @@ def main():

print(count_parameters(model, args.encoder))

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = cfg['T_max'], eta_min = cfg['min_lr'])
# scaler = GradScaler(enabled=True)
scaler = DummyScaler()
Expand Down