From 15d6655cd1b677d65e8c261b92c4523dc8e5db53 Mon Sep 17 00:00:00 2001 From: Peter Kiers Date: Tue, 11 Nov 2025 19:27:04 +0000 Subject: [PATCH] Added coreml provider in feature Changed overrun in ffmpeg and some fixes Added simd feature for interleaved/planar conversion Added overlap and transition power --- Cargo.toml | 6 +- cli/Cargo.toml | 2 + cli/src/cli.rs | 14 ++ cli/src/generate.rs | 96 +++++---- src/audio_ops.rs | 496 ++++++++++++++++++++++++++++++++++++++++++++ src/demucs.rs | 243 ++++++++++++++++++---- src/lib.rs | 1 + src/nistem.rs | 2 +- src/track.rs | 174 ++++++++++------ 9 files changed, 883 insertions(+), 151 deletions(-) create mode 100644 src/audio_ops.rs diff --git a/Cargo.toml b/Cargo.toml index 8ac7414..ae950aa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,8 @@ optional = true [features] default = [] -cuda = ["ort/cuda" ] +cuda = ["ort/cuda"] +coreml = ["ort/coreml"] +simd = [] -benchmark = ["chrono" ] +benchmark = ["chrono"] diff --git a/cli/Cargo.toml b/cli/Cargo.toml index ce9da24..1d52df1 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -22,3 +22,5 @@ path = ".." [features] default = [] cuda = ["stemgen/cuda"] +coreml = ["stemgen/coreml"] +simd = ["stemgen/simd"] diff --git a/cli/src/cli.rs b/cli/src/cli.rs index cc1be4a..e7cbfe6 100644 --- a/cli/src/cli.rs +++ b/cli/src/cli.rs @@ -106,6 +106,20 @@ pub struct GenerateArgs { pub thread: usize, #[arg(long, default_value_t = false)] pub preserved_original_as_master: bool, + #[arg( + long, + value_name = "FLOAT", + help = "Overlap between segments (0.0-0.99). Higher values give better quality but slower processing. 0 = no overlap (fastest), 0.25 = 25%% overlap (recommended)", + default_value_t = 0.25 + )] + pub overlap: f32, + #[arg( + long, + value_name = "FLOAT", + help = "Transition power for windowing (typically 1.0). Higher values create sharper transitions.", + default_value_t = 1.0 + )] + pub transition_power: f32, } #[derive(Debug, Subcommand)] diff --git a/cli/src/generate.rs b/cli/src/generate.rs index dd697cd..9405652 100644 --- a/cli/src/generate.rs +++ b/cli/src/generate.rs @@ -3,6 +3,7 @@ use std::{ffi::OsStr, path::PathBuf}; use glob::glob; use indicatif::{ProgressBar, ProgressStyle}; use stemgen::{ + audio_ops::planar_to_interleaved, demucs::{Demucs, DemusOpts}, nistem::{self, NIStem}, track::Track, @@ -51,10 +52,11 @@ pub fn generate(ctx: &Cli, command: &GenerateArgs) -> Result> = command.files.iter().map(|raw|glob(&raw)).collect(); @@ -99,8 +101,8 @@ pub fn generate(ctx: &Cli, command: &GenerateArgs) -> Result Result-"), ); - loop { - let mut buf: Vec = vec![0f32; 343980 * 2]; - let mut original_packets = Vec::with_capacity(512); - let mut original_buffer: Vec = Vec::with_capacity(512); - - let (data, eof) = loop { - let size = input.read( - if matches!(nistem, NIStem::PreservedMaster(..)) { - Some(&mut original_packets) - } else { - None - }, - &mut buf, - )?; - read += size; - if matches!(nistem, NIStem::ConsistentStream(..)) { - original_buffer.extend(buf[..size].to_vec()); - } - if let Some(mut data) = demucs.send(&buf[..size])? { - if matches!(nistem, NIStem::ConsistentStream(..)) { - data.insert(0, original_buffer); - } - break (data, false) - } - if size != buf.len() { - let mut data = demucs.flush()?; - if matches!(nistem, NIStem::ConsistentStream(..)) { - data.insert(0, original_buffer); - } - break (data, true); - } - }; - pb.set_position(read as u64 / sample_rate); - match nistem { - NIStem::PreservedMaster(..) => nistem.write_preserved(original_packets, data)?, - NIStem::ConsistentStream(..) => nistem.write_consistent(data)?, - } + // Read entire audio file into memory + let total_samples = input.total_samples() as usize; + let mut audio_buffer: Vec = vec![0f32; total_samples]; + let mut original_packets = Vec::with_capacity(512); + let mut read_offset = 0; + + pb.set_message("Reading audio..."); + while read_offset < audio_buffer.len() { + let remaining = audio_buffer.len() - read_offset; + let chunk_size = std::cmp::min(343980 * 2, remaining); + let size = input.read( + if matches!(nistem, NIStem::PreservedMaster(..)) { + Some(&mut original_packets) + } else { + None + }, + &mut audio_buffer[read_offset..read_offset + chunk_size], + )?; - if eof { + if size == 0 { + audio_buffer.truncate(read_offset); break; } + read_offset += size; + pb.set_position((read_offset as u64 * 10) / total_samples as u64); + } + + // Process with demucs using overlap + pb.set_message("Processing stems..."); + let stems = demucs.process(&audio_buffer, |current, total| { + let progress = 10 + (current as u64 * 80) / total as u64; + pb.set_position(progress); + })?; + pb.set_position(90); + + // Write stems + pb.set_message("Writing output..."); + + // Convert stems from planar to interleaved + let stems_interleaved: Vec> = stems.into_iter() + .map(|[left, right]| planar_to_interleaved(&left, &right)) + .collect(); + + match nistem { + NIStem::PreservedMaster(..) => { + nistem.write_preserved(original_packets, stems_interleaved)?; + }, + NIStem::ConsistentStream(..) => { + // Original audio is already interleaved, prepend it to stems + let mut data_with_original = vec![audio_buffer]; + data_with_original.extend(stems_interleaved); + nistem.write_consistent(data_with_original)?; + } } pb.finish_with_message(format!("downloaded {}", filename.display())); diff --git a/src/audio_ops.rs b/src/audio_ops.rs new file mode 100644 index 0000000..97a412e --- /dev/null +++ b/src/audio_ops.rs @@ -0,0 +1,496 @@ +#[cfg(all(feature = "simd", target_arch = "x86_64"))] +pub fn planar_to_interleaved(left: &[f32], right: &[f32]) -> Vec { + use std::arch::x86_64::*; + + assert_eq!(left.len(), right.len(), "Left and right channels must have same length"); + + let len = left.len(); + let mut interleaved = vec![0f32; len * 2]; + + // Process 8 samples at a time with AVX (256-bit) + if is_x86_feature_detected!("avx") { + let chunks = len / 8; + + unsafe { + for i in 0..chunks { + let l_offset = i * 8; + let r_offset = i * 8; + + // Load 8 floats from each channel + let l = _mm256_loadu_ps(left.as_ptr().add(l_offset)); + let r = _mm256_loadu_ps(right.as_ptr().add(r_offset)); + + // Interleave: unpack low and high parts + let l_low = _mm256_castps256_ps128(l); + let l_high = _mm256_extractf128_ps(l, 1); + let r_low = _mm256_castps256_ps128(r); + let r_high = _mm256_extractf128_ps(r, 1); + + let interleaved_0 = _mm_unpacklo_ps(l_low, r_low); + let interleaved_1 = _mm_unpackhi_ps(l_low, r_low); + let interleaved_2 = _mm_unpacklo_ps(l_high, r_high); + let interleaved_3 = _mm_unpackhi_ps(l_high, r_high); + + // Store results + let out_offset = i * 16; + _mm_storeu_ps(interleaved.as_mut_ptr().add(out_offset), interleaved_0); + _mm_storeu_ps(interleaved.as_mut_ptr().add(out_offset + 4), interleaved_1); + _mm_storeu_ps(interleaved.as_mut_ptr().add(out_offset + 8), interleaved_2); + _mm_storeu_ps(interleaved.as_mut_ptr().add(out_offset + 12), interleaved_3); + } + } + + // Handle remaining samples + for i in chunks * 8..len { + let out_idx = i * 2; + interleaved[out_idx] = left[i]; + interleaved[out_idx + 1] = right[i]; + } + } else { + return planar_to_interleaved_scalar(left, right); + } + + interleaved +} + +#[cfg(all(feature = "simd", target_arch = "aarch64"))] +pub fn planar_to_interleaved(left: &[f32], right: &[f32]) -> Vec { + use std::arch::aarch64::*; + + assert_eq!(left.len(), right.len(), "Left and right channels must have same length"); + + let len = left.len(); + let mut interleaved = vec![0f32; len * 2]; + + // Process 4 samples at a time with NEON (128-bit) + let chunks = len / 4; + + unsafe { + for i in 0..chunks { + let l_offset = i * 4; + let r_offset = i * 4; + + // Load 4 floats from each channel + let l = vld1q_f32(left.as_ptr().add(l_offset)); + let r = vld1q_f32(right.as_ptr().add(r_offset)); + + // Interleave using zip + let result = vzip1q_f32(l, r); + let result2 = vzip2q_f32(l, r); + + // Store results + let out_offset = i * 8; + vst1q_f32(interleaved.as_mut_ptr().add(out_offset), result); + vst1q_f32(interleaved.as_mut_ptr().add(out_offset + 4), result2); + } + } + + // Handle remaining samples + for i in chunks * 4..len { + let out_idx = i * 2; + interleaved[out_idx] = left[i]; + interleaved[out_idx + 1] = right[i]; + } + + interleaved +} + +#[cfg(all(feature = "simd", not(any(target_arch = "x86_64", target_arch = "aarch64"))))] +pub fn planar_to_interleaved(left: &[f32], right: &[f32]) -> Vec { + planar_to_interleaved_scalar(left, right) +} + +#[cfg(not(feature = "simd"))] +pub fn planar_to_interleaved(left: &[f32], right: &[f32]) -> Vec { + planar_to_interleaved_scalar(left, right) +} + +pub fn planar_to_interleaved_scalar(left: &[f32], right: &[f32]) -> Vec { + assert_eq!(left.len(), right.len(), "Left and right channels must have same length"); + + let mut interleaved = Vec::with_capacity(left.len() * 2); + for (l, r) in left.iter().zip(right.iter()) { + interleaved.push(*l); + interleaved.push(*r); + } + interleaved +} + +#[cfg(all(feature = "simd", target_arch = "x86_64"))] +pub fn interleaved_to_planar(interleaved: &[f32]) -> [Vec; 2] { + use std::arch::x86_64::*; + + assert_eq!(interleaved.len() % 2, 0, "Interleaved buffer must have even length"); + + let frame_count = interleaved.len() / 2; + let mut left = vec![0f32; frame_count]; + let mut right = vec![0f32; frame_count]; + + if is_x86_feature_detected!("avx") { + let chunks = frame_count / 8; + + unsafe { + for i in 0..chunks { + let in_offset = i * 16; + + // Load 16 interleaved samples (8 frames) + let data0 = _mm_loadu_ps(interleaved.as_ptr().add(in_offset)); + let data1 = _mm_loadu_ps(interleaved.as_ptr().add(in_offset + 4)); + let data2 = _mm_loadu_ps(interleaved.as_ptr().add(in_offset + 8)); + let data3 = _mm_loadu_ps(interleaved.as_ptr().add(in_offset + 12)); + + // De-interleave using shuffle + let l0 = _mm_shuffle_ps(data0, data1, 0b10001000); // L0 L1 L2 L3 + let r0 = _mm_shuffle_ps(data0, data1, 0b11011101); // R0 R1 R2 R3 + let l1 = _mm_shuffle_ps(data2, data3, 0b10001000); // L4 L5 L6 L7 + let r1 = _mm_shuffle_ps(data2, data3, 0b11011101); // R4 R5 R6 R7 + + // Store results + let out_offset = i * 8; + _mm_storeu_ps(left.as_mut_ptr().add(out_offset), l0); + _mm_storeu_ps(left.as_mut_ptr().add(out_offset + 4), l1); + _mm_storeu_ps(right.as_mut_ptr().add(out_offset), r0); + _mm_storeu_ps(right.as_mut_ptr().add(out_offset + 4), r1); + } + } + + // Handle remaining samples + for i in chunks * 8..frame_count { + let idx = i * 2; + left[i] = interleaved[idx]; + right[i] = interleaved[idx + 1]; + } + } else { + return interleaved_to_planar_scalar(interleaved); + } + + [left, right] +} + +#[cfg(all(feature = "simd", target_arch = "aarch64"))] +pub fn interleaved_to_planar(interleaved: &[f32]) -> [Vec; 2] { + use std::arch::aarch64::*; + + assert_eq!(interleaved.len() % 2, 0, "Interleaved buffer must have even length"); + + let frame_count = interleaved.len() / 2; + let mut left = vec![0f32; frame_count]; + let mut right = vec![0f32; frame_count]; + + let chunks = frame_count / 4; + + unsafe { + for i in 0..chunks { + let in_offset = i * 8; + + // Load 8 interleaved samples (4 frames) + let data1 = vld1q_f32(interleaved.as_ptr().add(in_offset)); + let data2 = vld1q_f32(interleaved.as_ptr().add(in_offset + 4)); + + // De-interleave using unzip + let result = vuzp1q_f32(data1, data2); + let result2 = vuzp2q_f32(data1, data2); + + // Store results + let out_offset = i * 4; + vst1q_f32(left.as_mut_ptr().add(out_offset), result); + vst1q_f32(right.as_mut_ptr().add(out_offset), result2); + } + } + + // Handle remaining samples + for i in chunks * 4..frame_count { + let idx = i * 2; + left[i] = interleaved[idx]; + right[i] = interleaved[idx + 1]; + } + + [left, right] +} + +#[cfg(all(feature = "simd", not(any(target_arch = "x86_64", target_arch = "aarch64"))))] +pub fn interleaved_to_planar(interleaved: &[f32]) -> [Vec; 2] { + interleaved_to_planar_scalar(interleaved) +} + +#[cfg(not(feature = "simd"))] +pub fn interleaved_to_planar(interleaved: &[f32]) -> [Vec; 2] { + interleaved_to_planar_scalar(interleaved) +} + +pub fn interleaved_to_planar_scalar(interleaved: &[f32]) -> [Vec; 2] { + assert_eq!(interleaved.len() % 2, 0, "Interleaved buffer must have even length"); + + let frame_count = interleaved.len() / 2; + let mut left = Vec::with_capacity(frame_count); + let mut right = Vec::with_capacity(frame_count); + + for chunk in interleaved.chunks(2) { + left.push(chunk[0]); + right.push(chunk[1]); + } + + [left, right] +} + +/// Multiply buffer by weights and accumulate into output +/// output[i] += buffer[i] * weight[i] +#[cfg(all(feature = "simd", target_arch = "x86_64"))] +pub fn weighted_accumulate(buffer: &[f32], weights: &[f32], output: &mut [f32]) { + use std::arch::x86_64::*; + + assert_eq!(buffer.len(), weights.len()); + assert_eq!(buffer.len(), output.len()); + + if is_x86_feature_detected!("avx") { + let len = buffer.len(); + let chunks = len / 8; + + unsafe { + for i in 0..chunks { + let offset = i * 8; + + let buf = _mm256_loadu_ps(buffer.as_ptr().add(offset)); + let weight = _mm256_loadu_ps(weights.as_ptr().add(offset)); + let out = _mm256_loadu_ps(output.as_ptr().add(offset)); + + let result = _mm256_add_ps(out, _mm256_mul_ps(buf, weight)); + _mm256_storeu_ps(output.as_mut_ptr().add(offset), result); + } + } + + // Handle remaining samples + for i in chunks * 8..len { + output[i] += buffer[i] * weights[i]; + } + } else { + weighted_accumulate_scalar(buffer, weights, output); + } +} + +#[cfg(all(feature = "simd", target_arch = "aarch64"))] +pub fn weighted_accumulate(buffer: &[f32], weights: &[f32], output: &mut [f32]) { + use std::arch::aarch64::*; + + assert_eq!(buffer.len(), weights.len()); + assert_eq!(buffer.len(), output.len()); + + let len = buffer.len(); + let chunks = len / 4; + + unsafe { + for i in 0..chunks { + let offset = i * 4; + + let buf = vld1q_f32(buffer.as_ptr().add(offset)); + let weight = vld1q_f32(weights.as_ptr().add(offset)); + let out = vld1q_f32(output.as_ptr().add(offset)); + + let result = vmlaq_f32(out, buf, weight); // out + buf * weight + vst1q_f32(output.as_mut_ptr().add(offset), result); + } + } + + // Handle remaining samples + for i in chunks * 4..len { + output[i] += buffer[i] * weights[i]; + } +} + +#[cfg(all(feature = "simd", not(any(target_arch = "x86_64", target_arch = "aarch64"))))] +pub fn weighted_accumulate(buffer: &[f32], weights: &[f32], output: &mut [f32]) { + weighted_accumulate_scalar(buffer, weights, output); +} + +#[cfg(not(feature = "simd"))] +pub fn weighted_accumulate(buffer: &[f32], weights: &[f32], output: &mut [f32]) { + weighted_accumulate_scalar(buffer, weights, output); +} + +pub fn weighted_accumulate_scalar(buffer: &[f32], weights: &[f32], output: &mut [f32]) { + assert_eq!(buffer.len(), weights.len()); + assert_eq!(buffer.len(), output.len()); + + for i in 0..buffer.len() { + output[i] += buffer[i] * weights[i]; + } +} + +/// Accumulate weights: output[i] += weight[i] +#[cfg(all(feature = "simd", target_arch = "x86_64"))] +pub fn accumulate_weights(weights: &[f32], output: &mut [f32]) { + use std::arch::x86_64::*; + + assert_eq!(weights.len(), output.len()); + + if is_x86_feature_detected!("avx") { + let len = weights.len(); + let chunks = len / 8; + + unsafe { + for i in 0..chunks { + let offset = i * 8; + + let weight = _mm256_loadu_ps(weights.as_ptr().add(offset)); + let out = _mm256_loadu_ps(output.as_ptr().add(offset)); + + let result = _mm256_add_ps(out, weight); + _mm256_storeu_ps(output.as_mut_ptr().add(offset), result); + } + } + + for i in chunks * 8..len { + output[i] += weights[i]; + } + } else { + accumulate_weights_scalar(weights, output); + } +} + +#[cfg(all(feature = "simd", target_arch = "aarch64"))] +pub fn accumulate_weights(weights: &[f32], output: &mut [f32]) { + use std::arch::aarch64::*; + + assert_eq!(weights.len(), output.len()); + + let len = weights.len(); + let chunks = len / 4; + + unsafe { + for i in 0..chunks { + let offset = i * 4; + + let weight = vld1q_f32(weights.as_ptr().add(offset)); + let out = vld1q_f32(output.as_ptr().add(offset)); + + let result = vaddq_f32(out, weight); + vst1q_f32(output.as_mut_ptr().add(offset), result); + } + } + + for i in chunks * 4..len { + output[i] += weights[i]; + } +} + +#[cfg(all(feature = "simd", not(any(target_arch = "x86_64", target_arch = "aarch64"))))] +pub fn accumulate_weights(weights: &[f32], output: &mut [f32]) { + accumulate_weights_scalar(weights, output); +} + +#[cfg(not(feature = "simd"))] +pub fn accumulate_weights(weights: &[f32], output: &mut [f32]) { + accumulate_weights_scalar(weights, output); +} + +pub fn accumulate_weights_scalar(weights: &[f32], output: &mut [f32]) { + assert_eq!(weights.len(), output.len()); + + for i in 0..weights.len() { + output[i] += weights[i]; + } +} + +/// Normalize buffer by weights: buffer[i] /= weights[i] (if weight > 0) +#[cfg(all(feature = "simd", target_arch = "x86_64"))] +pub fn normalize_by_weights(buffer: &mut [f32], weights: &[f32]) { + use std::arch::x86_64::*; + + assert_eq!(buffer.len(), weights.len()); + + if is_x86_feature_detected!("avx") { + let len = buffer.len(); + let chunks = len / 8; + + unsafe { + let zero = _mm256_setzero_ps(); + + for i in 0..chunks { + let offset = i * 8; + + let buf = _mm256_loadu_ps(buffer.as_ptr().add(offset)); + let weight = _mm256_loadu_ps(weights.as_ptr().add(offset)); + + // Check if weight > 0 + let mask = _mm256_cmp_ps(weight, zero, _CMP_GT_OQ); + + // Divide where weight > 0 + let result = _mm256_div_ps(buf, weight); + + // Blend: keep original where weight == 0, use result where weight > 0 + let blended = _mm256_blendv_ps(buf, result, mask); + + _mm256_storeu_ps(buffer.as_mut_ptr().add(offset), blended); + } + } + + for i in chunks * 8..len { + let w = weights[i]; + if w > 0.0 { + buffer[i] /= w; + } + } + } else { + normalize_by_weights_scalar(buffer, weights); + } +} + +#[cfg(all(feature = "simd", target_arch = "aarch64"))] +pub fn normalize_by_weights(buffer: &mut [f32], weights: &[f32]) { + use std::arch::aarch64::*; + + assert_eq!(buffer.len(), weights.len()); + + let len = buffer.len(); + let chunks = len / 4; + + unsafe { + let zero = vdupq_n_f32(0.0); + + for i in 0..chunks { + let offset = i * 4; + + let buf = vld1q_f32(buffer.as_ptr().add(offset)); + let weight = vld1q_f32(weights.as_ptr().add(offset)); + + // Check if weight > 0 + let mask = vcgtq_f32(weight, zero); + + // Divide where weight > 0 + let result = vdivq_f32(buf, weight); + + // Blend using mask + let blended = vbslq_f32(mask, result, buf); + + vst1q_f32(buffer.as_mut_ptr().add(offset), blended); + } + } + + for i in chunks * 4..len { + let w = weights[i]; + if w > 0.0 { + buffer[i] /= w; + } + } +} + +#[cfg(all(feature = "simd", not(any(target_arch = "x86_64", target_arch = "aarch64"))))] +pub fn normalize_by_weights(buffer: &mut [f32], weights: &[f32]) { + normalize_by_weights_scalar(buffer, weights); +} + +#[cfg(not(feature = "simd"))] +pub fn normalize_by_weights(buffer: &mut [f32], weights: &[f32]) { + normalize_by_weights_scalar(buffer, weights); +} + +pub fn normalize_by_weights_scalar(buffer: &mut [f32], weights: &[f32]) { + assert_eq!(buffer.len(), weights.len()); + + for i in 0..buffer.len() { + let w = weights[i]; + if w > 0.0 { + buffer[i] /= w; + } + } +} diff --git a/src/demucs.rs b/src/demucs.rs index beceb55..ded8a01 100644 --- a/src/demucs.rs +++ b/src/demucs.rs @@ -4,17 +4,29 @@ use ort::value::ValueType; use ndarray::{s, ArrayViewMut, ShapeBuilder}; use ort::{session::{builder::GraphOptimizationLevel, Session}, value::Tensor}; +use crate::audio_ops::{weighted_accumulate, accumulate_weights, normalize_by_weights}; + #[cfg(feature = "cuda")] use ort::{execution_providers::CUDAExecutionProvider}; +#[cfg(feature = "coreml")] +use ort::execution_providers::{CoreMLExecutionProvider, coreml::{CoreMLComputeUnits}}; + use crate::constant::DEFAULT_MODEL; +const SEGMENT_SAMPLES: usize = 343980; +const CHANNELS: usize = 2; +const NUM_STEMS: usize = 4; + #[derive(Debug)] pub struct Demucs { session: Session, input_name: String, output_name: String, input_buffer: Vec, + overlap: f32, + transition_power: f32, + segment_output_buffer: Vec<[Vec; CHANNELS]>, } #[derive(Debug, Clone)] @@ -34,7 +46,9 @@ pub enum Device { #[default] CPU, #[cfg(feature = "cuda")] - CUDA + CUDA, + #[cfg(feature = "coreml")] + CoreML } impl std::fmt::Display for Device { @@ -42,6 +56,8 @@ impl std::fmt::Display for Device { match self { #[cfg(feature = "cuda")] Device::CUDA => write!(f, "cuda"), + #[cfg(feature = "coreml")] + Device::CoreML => write!(f, "coreml"), Device::CPU => write!(f, "cpu"), } } @@ -54,6 +70,8 @@ impl TryFrom<&str> for Device { match value { #[cfg(feature = "cuda")] "cuda" => Ok(Device::CUDA), + #[cfg(feature = "coreml")] + "coreml" => Ok(Device::CoreML), "cpu" => Ok(Device::CPU), _ => Err("unsupported device".to_owned()), } @@ -88,12 +106,19 @@ impl TryFrom<&str> for Model { pub struct DemusOpts { pub device: Device, - pub threads: usize + pub threads: usize, + pub overlap: f32, + pub transition_power: f32, } impl Default for DemusOpts { fn default() -> Self { - Self { threads: 2, device: Device::CPU } + Self { + threads: 2, + device: Device::CPU, + overlap: 0.25, + transition_power: 1.0, + } } } @@ -113,6 +138,14 @@ impl Demucs { .build() .error_on_failure() ], + #[cfg(feature = "coreml")] + Device::CoreML => vec![ + CoreMLExecutionProvider::default() + // FIXME: There is currently a huge memory leak with CoreML runtime in ort crate + .with_compute_units(CoreMLComputeUnits::CPUAndGPU) // Use GPU for hardware acceleration + .build() + .error_on_failure() + ], Device::CPU => vec![] }) .commit()?; @@ -168,68 +201,192 @@ impl Demucs { } }?; + // Pre-allocate segment output buffer + let segment_output_buffer: Vec<[Vec; CHANNELS]> = (0..NUM_STEMS) + .map(|_| [ + Vec::with_capacity(SEGMENT_SAMPLES), + Vec::with_capacity(SEGMENT_SAMPLES) + ]) + .collect(); + Ok(Self { session, input_name, output_name, - input_buffer: Vec::with_capacity(2 * 343980), + input_buffer: Vec::with_capacity(CHANNELS * SEGMENT_SAMPLES), + overlap: ops.overlap, + transition_power: ops.transition_power, + segment_output_buffer, }) } - fn process(&mut self) -> Result>, Box> { - let tensor = Tensor::::from_array(ArrayViewMut::from_shape((1, 2, 343980).strides((343980 * 2, 1, 2)), &mut self.input_buffer)?.to_owned())?; + /// Process a single segment through the model + /// Reuses internal buffer to avoid allocations + fn process_segment(&mut self) -> Result<(), Box> { + let tensor = Tensor::::from_array( + ArrayViewMut::from_shape( + (1, CHANNELS, SEGMENT_SAMPLES).strides((SEGMENT_SAMPLES * CHANNELS, 1, CHANNELS)), + &mut self.input_buffer + )?.to_owned() + )?; let result = self.session.run(ort::inputs! { &self.input_name => tensor })?; let output = result[self.output_name.as_str()].try_extract_array::()?; - let mut stems = vec![Vec::new(); 4]; - for (i, stem) in stems.iter_mut().enumerate() { // Iterate over the 4 items - let mut offset = stem.len(); - stem.resize_with(offset+2 * 343980, ||0.0f32); - - let l_slice = output.slice(s![0, i, 0, ..]); // All L values for item i - let r_slice = output.slice(s![0, i, 1, ..]); // All R values for item i - - for (l, r) in l_slice.iter().zip(r_slice.iter()) { - stem[offset] = *l; - stem[offset + 1] = *r; - offset += 2; - } - } - if self.input_buffer.len() == 2 * 343980 { - self.input_buffer.clear(); - } else { - let leftover = self.input_buffer.len() - 2 * 343980; - let (left, right) = self.input_buffer.split_at_mut(2 * 343980); - left[..leftover].copy_from_slice(right); - self.input_buffer.resize(leftover, 0.0); + + // Reuse buffer - clear and refill instead of allocating + for i in 0..NUM_STEMS { + let l_slice = output.slice(s![0, i, 0, ..]); + let r_slice = output.slice(s![0, i, 1, ..]); + + self.segment_output_buffer[i][0].clear(); + self.segment_output_buffer[i][1].clear(); + self.segment_output_buffer[i][0].extend(l_slice.iter().copied()); + self.segment_output_buffer[i][1].extend(r_slice.iter().copied()); } - Ok(stems) + + Ok(()) } - pub fn send(&mut self, sample_buffer: &[f32]) -> Result>>, Box> { - if sample_buffer.len() % 2 != 0 { - return Err("uneven number of sample".into()); + /// Create triangular weight vector for windowing + fn create_weight_vector(segment_samples: usize, transition_power: f32) -> Vec { + let mut weight = vec![0.0f32; segment_samples]; + let half_segment = segment_samples / 2; + + // First half: linear ramp up + for i in 0..half_segment { + weight[i] = (i + 1) as f32 / half_segment as f32; } - self.input_buffer.extend_from_slice(sample_buffer); + // Second half: linear ramp down + for i in half_segment..segment_samples { + weight[i] = (segment_samples - i) as f32 / half_segment as f32; + } - if self.input_buffer.len() >= 2 * 343980 { - Ok(Some(self.process()?)) - } else { - Ok(None) + // Apply transition power + if transition_power != 1.0 { + for w in weight.iter_mut() { + *w = w.powf(transition_power); + } } + + weight } - pub fn flush(&mut self) -> Result>, Box> { - let buffer_size = self.input_buffer.len(); - self.input_buffer.resize(2 * 343980, 0.0); - let mut data = self.process()?; - for stem in data.iter_mut() { - stem.resize(buffer_size, 0.0f32); + /// Process entire audio with optional overlap for better quality + /// When overlap=0, processes segments sequentially without windowing (fast) + /// When overlap>0, uses overlapping windows with blending (better quality) + /// Returns planar format: Vec<[left_channel, right_channel]> + pub fn process(&mut self, audio: &[f32], mut progress_callback: F) -> Result; CHANNELS]>, Box> + where + F: FnMut(usize, usize), + { + if audio.len() % CHANNELS != 0 { + return Err("audio length must be multiple of channel count".into()); + } + + let length = audio.len() / CHANNELS; + + // Fast path: no overlap + if self.overlap == 0.0 { + // Pre-allocate output buffers with exact size (no reallocation needed) + let mut output: Vec<[Vec; CHANNELS]> = (0..NUM_STEMS) + .map(|_| [ + Vec::with_capacity(length), + Vec::with_capacity(length) + ]) + .collect(); + + let mut segment_offset = 0; + + while segment_offset < length { + let chunk_length = std::cmp::min(SEGMENT_SAMPLES, length - segment_offset); + + // Prepare input buffer - zero pad if needed + self.input_buffer.clear(); + self.input_buffer.resize(CHANNELS * SEGMENT_SAMPLES, 0.0); + + // Copy audio chunk (audio is interleaved) + let sample_count = chunk_length * CHANNELS; + let src_offset = segment_offset * CHANNELS; + self.input_buffer[..sample_count].copy_from_slice(&audio[src_offset..src_offset + sample_count]); + + // Process segment - fills segment_output_buffer + self.process_segment()?; + + // Copy from reused buffer to output (trim to actual chunk length) + for stem_idx in 0..NUM_STEMS { + output[stem_idx][0].extend_from_slice(&self.segment_output_buffer[stem_idx][0][..chunk_length]); + output[stem_idx][1].extend_from_slice(&self.segment_output_buffer[stem_idx][1][..chunk_length]); + } + + segment_offset += SEGMENT_SAMPLES; + progress_callback(segment_offset, length); + } + + return Ok(output); + } + + // Quality path: with overlap and windowing + let stride_samples = ((1.0 - self.overlap) * SEGMENT_SAMPLES as f32) as usize; + let weight = Self::create_weight_vector(SEGMENT_SAMPLES, self.transition_power); + + // Pre-allocate output buffers with exact size + let mut output: Vec<[Vec; CHANNELS]> = (0..NUM_STEMS) + .map(|_| [vec![0.0f32; length], vec![0.0f32; length]]) + .collect(); + let mut sum_weight = vec![0.0f32; length]; + + let mut segment_offset = 0; + while segment_offset < length { + let chunk_length = std::cmp::min(SEGMENT_SAMPLES, length - segment_offset); + + // Prepare input buffer + self.input_buffer.clear(); + self.input_buffer.resize(CHANNELS * SEGMENT_SAMPLES, 0.0); + + // Copy audio chunk (audio is interleaved) + let sample_count = chunk_length * CHANNELS; + let src_offset = segment_offset * CHANNELS; + self.input_buffer[..sample_count].copy_from_slice(&audio[src_offset..src_offset + sample_count]); + + // Process segment - fills segment_output_buffer + self.process_segment()?; + + // Apply weights and accumulate (planar format - better cache locality) + for stem_idx in 0..NUM_STEMS { + // Left channel + weighted_accumulate( + &self.segment_output_buffer[stem_idx][0][..chunk_length], + &weight[..chunk_length], + &mut output[stem_idx][0][segment_offset..segment_offset + chunk_length] + ); + // Right channel + weighted_accumulate( + &self.segment_output_buffer[stem_idx][1][..chunk_length], + &weight[..chunk_length], + &mut output[stem_idx][1][segment_offset..segment_offset + chunk_length] + ); + } + + // Accumulate weights + accumulate_weights( + &weight[..chunk_length], + &mut sum_weight[segment_offset..segment_offset + chunk_length] + ); + + segment_offset += stride_samples; + progress_callback(segment_offset.min(length), length); } - Ok(data) + + // Normalize by sum of weights (planar format) + for stem_idx in 0..NUM_STEMS { + normalize_by_weights(&mut output[stem_idx][0], &sum_weight); + normalize_by_weights(&mut output[stem_idx][1], &sum_weight); + } + + Ok(output) } } diff --git a/src/lib.rs b/src/lib.rs index 9cdd298..47c7151 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +pub mod audio_ops; pub mod constant; pub mod demucs; pub mod nistem; diff --git a/src/nistem.rs b/src/nistem.rs index f935c78..9348962 100644 --- a/src/nistem.rs +++ b/src/nistem.rs @@ -485,7 +485,7 @@ impl NIStem { } for (stream_idx, ((idx, encoder, resampler, timestamp), mut frames)) in inner.idx_encoders.iter_mut().zip(stems).enumerate() { let frame_size = 2 * encoder.frame_size() as usize; - if !inner.overrun[stream_idx].is_empty(){ + if !inner.overrun[stream_idx].is_empty() { frames = { let mut v = inner.overrun[stream_idx].clone(); v.extend(frames); diff --git a/src/track.rs b/src/track.rs index 3a2ce3b..af501e3 100644 --- a/src/track.rs +++ b/src/track.rs @@ -1,27 +1,29 @@ use std::{collections::HashMap, path::PathBuf}; use ffmpeg_next::{ - codec, decoder, ffi::av_rescale_q, format::{self, context}, frame::Audio, media, software::resampling, Packet, Rational + codec, decoder, format::{self, context}, frame::Audio, media, software::resampling, ChannelLayout, Packet, Rational }; use taglib::AttachedPicture; use crate::constant::{Metadata, MetadataValue}; +const OUTPUT_SAMPLE_RATE: i32 = 44100; +const OUTPUT_CHANNELS: usize = 2; + pub struct Track { path: PathBuf, ctx: context::Input, index: usize, resampler: resampling::context::Context, decoder: decoder::Audio, - overrun: [f32; 10240], - overrun_len: usize, + overrun: Vec, + eof_sent: bool, } impl Track { pub fn new(path: &PathBuf) -> Result> { let ctx = format::input(&path)?; - // format::context::input::dump(&ctx, 0, Some(path.to_str().ok_or("unable to read path")?)); let stream = ctx .streams() .best(media::Type::Audio) @@ -32,13 +34,20 @@ impl Track { ffmpeg_next::codec::context::Context::from_parameters(stream.parameters())?; let decoder = context_decoder.decoder().audio()?; - let resampler = ffmpeg_next::software::resampling::context::Context::get( + // Use proper channel layout if decoder has empty layout + let input_layout = if decoder.channel_layout().channels() == 2 && decoder.channel_layout().bits() == 0 { + ChannelLayout::STEREO + } else { + decoder.channel_layout() + }; + + let resampler = resampling::context::Context::get( decoder.format(), - decoder.channel_layout(), + input_layout, decoder.rate(), format::Sample::F32(format::sample::Type::Packed), - ffmpeg_next::ChannelLayout::STEREO, - 44100, + ChannelLayout::STEREO, + OUTPUT_SAMPLE_RATE as u32, )?; Ok(Self { @@ -47,8 +56,8 @@ impl Track { index, resampler, decoder, - overrun: [0f32; 10240], - overrun_len: Default::default(), + overrun: Vec::new(), + eof_sent: false, }) } @@ -68,9 +77,11 @@ impl Track { } pub fn total_samples(&self) -> i64 { let stream = self.ctx.stream(self.index).unwrap(); - unsafe { - av_rescale_q(stream.duration() * self.decoder.channels() as i64, stream.time_base().into(), self.decoder.time_base().into()) - } + // Calculate duration in seconds + let duration_sec = stream.time_base().numerator() as i64 * stream.duration() + / stream.time_base().denominator() as i64; + // Multiply by output sample rate and channels for interleaved sample count + duration_sec * OUTPUT_SAMPLE_RATE as i64 * OUTPUT_CHANNELS as i64 } } @@ -80,85 +91,120 @@ impl Track { mut original_packets: Option<&mut Vec>, buf: &mut [f32], ) -> Result> { - let mut read = 0; - - if self.overrun_len > 0 && self.overrun_len <= buf.len() { - buf[..self.overrun_len].copy_from_slice(&self.overrun[..self.overrun_len]); - read = self.overrun_len; - self.overrun_len = 0; - } else if self.overrun_len > buf.len() { - buf.copy_from_slice(&self.overrun[..buf.len()]); - self.overrun_len -= buf.len(); - return Ok(buf.len()); - } - let mut packets = self.ctx.packets(); + let mut write_pos = 0; - let mut process = |mut resampled: Audio, buf: &mut [f32], read: usize| { - let output = resampled.plane_mut(0); + if !self.overrun.is_empty() { + let to_copy = std::cmp::min(self.overrun.len(), buf.len()); + buf[..to_copy].copy_from_slice(&self.overrun[..to_copy]); + write_pos += to_copy; - if output.len() > buf.len() - read { - let (left, right) = output.split_at_mut(buf.len() - read); - buf[read..].copy_from_slice(left); - self.overrun[..right.len()].copy_from_slice(right); - self.overrun_len = right.len(); - return buf.len() - read; + self.overrun.drain(..to_copy); + + if write_pos >= buf.len() { + return Ok(write_pos); } + } - buf[read..read + output.len()].copy_from_slice(output); - output.len() - }; + if self.eof_sent { + return Ok(write_pos); + } + + let mut packets = self.ctx.packets(); - while read < buf.len() { + while write_pos < buf.len() { + // Read next packet let eof = if let Some((stream, packet)) = packets.next() { if stream.index() != self.index { continue; } - original_packets = if let Some(original_packets) = original_packets { + if let Some(ref mut original_packets) = original_packets { original_packets.push(packet.clone()); - Some(original_packets) - } else { - None - }; - // println!("packet {:?}", packet.pts()); + } self.decoder.send_packet(&packet)?; false } else { - self.decoder.send_eof()?; + if !self.eof_sent { + self.decoder.send_eof()?; + self.eof_sent = true; + } true }; let mut decoded = Audio::empty(); while self.decoder.receive_frame(&mut decoded).is_ok() { + if decoded.channel_layout().channels() == 2 && decoded.channel_layout().bits() == 0 { + decoded.set_channel_layout(ChannelLayout::STEREO); + } + let mut resampled = Audio::empty(); self.resampler.run(&decoded, &mut resampled)?; - resampled.set_samples(resampled.samples() * decoded.planes()); // FIXME seems to be a bug upstream? - // println!("frame {:?}", resampled.pts()); - read += process(resampled, buf, read); + + // For packed stereo: resampled.samples() returns frame count + // The actual float count in the buffer is samples * channels + let frame_count = resampled.samples(); + let channels = resampled.channel_layout().channels() as usize; + let float_count = frame_count * channels; + + let output = resampled.plane::(0); + + // SAFETY: For packed format, the underlying buffer contains frame_count * channels floats, + // but plane() returns a slice with wrong length. We reconstruct the correct slice. + let output_correct = unsafe { + std::slice::from_raw_parts(output.as_ptr(), float_count) + }; + + let remaining_space = buf.len() - write_pos; + + if output_correct.len() <= remaining_space { + buf[write_pos..write_pos + output_correct.len()].copy_from_slice(output_correct); + write_pos += output_correct.len(); + } else { + buf[write_pos..].copy_from_slice(&output_correct[..remaining_space]); + write_pos = buf.len(); + + self.overrun.extend_from_slice(&output_correct[remaining_space..]); + return Ok(write_pos); + } } + if eof { - let mut finished = false; - while !finished { - let mut resampled = Audio::new( - self.resampler.output().format, - 1024, - self.resampler.output().channel_layout, - ); - - finished = match self.resampler.flush(&mut resampled) { - Ok(None) => true, - Ok(_) | Err(_) => false, - }; - if resampled.planes() == 0 { - break; + // Flush resampler + loop { + let mut resampled = Audio::empty(); + match self.resampler.flush(&mut resampled) { + Ok(Some(_)) => { + // For packed stereo: same fix as above + let frame_count = resampled.samples(); + let channels = resampled.channel_layout().channels() as usize; + let float_count = frame_count * channels; + + let output = resampled.plane::(0); + + let output_correct = unsafe { + std::slice::from_raw_parts(output.as_ptr(), float_count) + }; + + let remaining_space = buf.len() - write_pos; + + if output_correct.len() <= remaining_space { + buf[write_pos..write_pos + output_correct.len()].copy_from_slice(output_correct); + write_pos += output_correct.len(); + } else { + buf[write_pos..].copy_from_slice(&output_correct[..remaining_space]); + write_pos = buf.len(); + self.overrun.extend_from_slice(&output_correct[remaining_space..]); + return Ok(write_pos); + } + } + _ => break, } - resampled.set_samples(resampled.samples() * decoded.planes()); // FIXME seems to be a bug upstream? - read += process(resampled, buf, read); } break; } } - Ok(read) + + Ok(write_pos) } pub fn tags(&self) -> HashMap { taglib::File::new(&self.path)