-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_mlx.py
More file actions
195 lines (159 loc) · 7.69 KB
/
run_mlx.py
File metadata and controls
195 lines (159 loc) · 7.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import os
import argparse
import time
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import mlx.utils
import yaml
import numpy as np
import scipy.io.wavfile as wavfile
# Resolve paths relative to this script's location
SCRIPT_DIR = Path(__file__).parent.resolve()
from .roformer import BSRoformer
from .utils import normalize_audio, denormalize_audio
def load_config(path):
    """Parse the YAML model config at *path* and return the resulting object.

    SECURITY NOTE(review): ``yaml.UnsafeLoader`` can construct arbitrary
    Python objects from YAML tags. It is deliberately kept here because the
    bundled model configs may rely on python-specific tags, but this function
    must only ever be pointed at trusted, bundled config files — never at
    untrusted input. Consider ``yaml.SafeLoader`` if the configs allow it.
    """
    # Explicit encoding avoids depending on the platform's default locale.
    with open(path, 'r', encoding='utf-8') as f:
        return yaml.load(f, Loader=yaml.UnsafeLoader)
def _resolve_model_paths(model_type):
    """Return ``(model.npz path, config.yaml path)`` for a bundled model variant.

    Note: the original code special-cased ``"v1_fast"`` but both branches
    produced the exact same paths, so they are collapsed into one.
    """
    model_dir = SCRIPT_DIR / "models" / "bs_roformer" / model_type
    return str(model_dir / "model.npz"), str(model_dir / "config.yaml")


def _build_model(model_cfg):
    """Construct a BSRoformer from the ``model`` section of the YAML config."""
    return BSRoformer(
        dim=model_cfg['dim'],
        depth=model_cfg['depth'],
        stereo=model_cfg['stereo'],
        num_stems=model_cfg['num_stems'],
        time_transformer_depth=model_cfg.get('time_transformer_depth', 2),
        freq_transformer_depth=model_cfg.get('freq_transformer_depth', 2),
        dim_head=model_cfg.get('dim_head', 64),
        heads=model_cfg.get('heads', 8),
        stft_n_fft=model_cfg.get('stft_n_fft', 2048),
        stft_hop_length=model_cfg.get('stft_hop_length', 441),
        stft_win_length=model_cfg.get('stft_win_length', 2048),
        ff_mult=model_cfg.get('ff_mult', 4),
        freqs_per_bands=model_cfg.get('freqs_per_bands'),
        linear_transformer_depth=model_cfg.get('linear_transformer_depth', 0),
    )


def _load_audio(audio_path, target_sr, start_offset, duration, stereo):
    """Load audio, resample to *target_sr*, slice the requested window, upmix mono.

    torchaudio is imported lazily: it is only needed for decoding (e.g. mp3
    via an ffmpeg backend), not for the MLX inference path itself.
    """
    import torchaudio
    wav, sr = torchaudio.load(audio_path)
    if sr != target_sr:
        wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
    if duration is not None:
        start_sample = int(start_offset * target_sr)
        num_samples = int(duration * target_sr)
        wav = wav[:, start_sample:start_sample + num_samples]
    if wav.shape[0] == 1 and stereo:
        # Duplicate the mono channel so stereo models receive two channels.
        wav = wav.repeat(2, 1)
    return wav


def _overlap_add_inference(fast_model, num_stems, wav_norm, chunk_size, overlap):
    """Run the compiled model over overlapping chunks and blend with a Hann window.

    Keeps a small pipeline of asynchronous MLX evaluations (depth 2) so the
    next chunk can be dispatched while the previous one is still computing.
    Returns a float32 array of shape ``(1, num_stems, channels, T)`` trimmed
    back to the un-padded input length.
    """
    step = chunk_size // overlap
    wav_norm_np = np.array(wav_norm)
    C, T = wav_norm_np.shape

    # Pad the tail so the chunk grid covers the whole signal exactly.
    pad_len = 0
    if T < chunk_size:
        pad_len = chunk_size - T
    else:
        rest = (T - chunk_size) % step
        if rest > 0:
            pad_len = step - rest
    wav_padded = np.pad(wav_norm_np, ((0, 0), (0, pad_len)), mode='constant')[None, :, :]
    T_pad = wav_padded.shape[-1]

    out_stems_sum = np.zeros((1, num_stems, C, T_pad), dtype=np.float32)
    normalization = np.zeros((1, 1, 1, T_pad), dtype=np.float32)
    window = np.hanning(chunk_size)[None, None, None, :].astype(np.float32)

    total_chunks = (T_pad - chunk_size) // step + 1
    futures = []

    def _accumulate(out_mx, start, end):
        # Converting the MLX array to numpy blocks until its async eval is done.
        out_chunk = np.array(out_mx)
        out_stems_sum[:, :, :, start:end] += out_chunk * window
        normalization[:, :, :, start:end] += window

    for i, start in enumerate(range(0, T_pad - chunk_size + 1, step)):
        end = start + chunk_size
        chunk_mx = mx.array(wav_padded[:, :, start:end])
        out_mx = fast_model(chunk_mx)
        mx.async_eval(out_mx)
        futures.append((out_mx, start, end))

        # Emit machine-readable progress for any wrapping UI.
        progress = (i + 1) / total_chunks * 100
        print(f"PROGRESS: {progress:.2f}", flush=True)

        # Drain the oldest pending chunk once the pipeline is 2 deep.
        if len(futures) >= 2:
            _accumulate(*futures.pop(0))

    for pending in futures:
        _accumulate(*pending)

    # Small epsilon guards against division by zero at un-covered samples.
    out = out_stems_sum / (normalization + 1e-7)
    return out[:, :, :, :T]


def main():
    """CLI entry point: separate one audio file into stems with a BS-RoFormer MLX model.

    Parses arguments, loads the bundled model weights and YAML config, runs
    chunked overlap-add inference, and writes one WAV file per separated stem
    into the output directory.
    """
    parser = argparse.ArgumentParser(description="BS-RoFormer MLX Inference (v1_fast)")
    parser.add_argument("--model_type", type=str, default="v1_fast", help="Model variant (default: v1_fast)")
    parser.add_argument("--audio_path", type=str, required=True, help="Path to input audio")
    parser.add_argument("--output_path", type=str, default="output", help="Output directory")
    parser.add_argument("--start_offset", type=float, default=0, help="Start offset in seconds")
    parser.add_argument("--duration", type=float, default=None, help="Duration to process in seconds")
    parser.add_argument("--chunk_size", type=int, default=None, help="Chunk size for inference")
    parser.add_argument("--num_overlap", type=int, default=None, help="Number of overlaps")
    args = parser.parse_args()

    model_path, config_path = _resolve_model_paths(args.model_type)
    if not os.path.exists(model_path) or not os.path.exists(config_path):
        print(f"❌ Error: Model or config not found at {model_path}")
        return
    print(f"Model Path: {model_path}")
    print(f"Config Path: {config_path}")

    SR = 44100  # presumably all bundled models operate at 44.1 kHz — confirm per config

    # 1. Load Config
    cfg = load_config(config_path)
    model_cfg = cfg['model']

    # 2. Initialize Model
    print("Initializing Model...")
    model = _build_model(model_cfg)

    # 3. Load Weights (strict=False: npz may omit non-parameter buffers)
    print(f"Loading MLX weights...")
    weights = mx.load(model_path)
    model.load_weights(list(weights.items()), strict=False)
    model.eval()
    # Compile for faster repeated chunk inference.
    fast_model = mx.compile(model)

    # 4. Load Audio
    print(f"Loading audio: {args.audio_path}")
    wav = _load_audio(args.audio_path, SR, args.start_offset, args.duration, model_cfg['stereo'])

    # 5. Normalize
    wav_mx = mx.array(wav.numpy())
    wav_norm, norm_params = normalize_audio(wav_mx)

    # 6. Inference (chunked overlap-add)
    print("Running Inference...")
    start_time = time.time()
    chunk_size = args.chunk_size if args.chunk_size else cfg.get('audio', {}).get('chunk_size', 485100)
    overlap = args.num_overlap if args.num_overlap else cfg.get('inference', {}).get('num_overlap', 2)
    out_stems_np = _overlap_add_inference(fast_model, model.num_stems, wav_norm, chunk_size, overlap)
    out_stems = mx.array(out_stems_np)

    os.makedirs(args.output_path, exist_ok=True)

    # 7. Denormalize and Save
    import torchaudio
    import torch
    instruments = cfg.get('training', {}).get('instruments', ["vocals", "other"])
    # Fix: name outputs after the input file. The previous version wrote a
    # literal "(unknown)" prefix and left the computed `filename` unused.
    filename = os.path.basename(args.audio_path).rsplit('.', 1)[0]
    for i, stem_name in enumerate(instruments):
        stem = out_stems[0, i, :, :]
        stem_denorm = denormalize_audio(stem, norm_params) if norm_params else stem
        final_np = np.array(stem_denorm)
        # Rescale if the stem exceeds full scale to avoid clipping on save.
        final_max = np.max(np.abs(final_np))
        if final_max > 1.0:
            final_np = final_np / final_max * 0.99
        out_file = os.path.join(args.output_path, f"{filename}_{stem_name}.wav")
        print(f"Saving {stem_name} to {out_file}...")
        torchaudio.save(out_file, torch.from_numpy(final_np), SR)

    print(f"✅ Done in {time.time() - start_time:.2f}s")
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()