From 414a353acef3f86a63448fc080357fb4ae652f54 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Tue, 3 Feb 2026 19:59:17 -0500 Subject: [PATCH 1/7] changes --- src/torchcodec/_core/CMakeLists.txt | 1 + src/torchcodec/_core/WavDecoder.cpp | 355 ++++++++++++++++++++++ src/torchcodec/_core/WavDecoder.h | 28 ++ src/torchcodec/_core/__init__.py | 2 + src/torchcodec/_core/_metadata.py | 16 + src/torchcodec/_core/custom_ops.cpp | 27 ++ src/torchcodec/_core/ops.py | 2 + src/torchcodec/decoders/_audio_decoder.py | 77 +++++ test/test_decoders.py | 33 ++ 9 files changed, 541 insertions(+) create mode 100644 src/torchcodec/_core/WavDecoder.cpp create mode 100644 src/torchcodec/_core/WavDecoder.h diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt index 67d5bb5e2..95763ebfa 100644 --- a/src/torchcodec/_core/CMakeLists.txt +++ b/src/torchcodec/_core/CMakeLists.txt @@ -136,6 +136,7 @@ function(make_torchcodec_libraries ValidationUtils.cpp Transform.cpp Metadata.cpp + WavDecoder.cpp ) if(ENABLE_CUDA) diff --git a/src/torchcodec/_core/WavDecoder.cpp b/src/torchcodec/_core/WavDecoder.cpp new file mode 100644 index 000000000..3c5b9d28b --- /dev/null +++ b/src/torchcodec/_core/WavDecoder.cpp @@ -0,0 +1,355 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "WavDecoder.h" + +#include +#include +#include +#include +#include + +namespace facebook::torchcodec { +namespace { + +// PCM format codes in WAV files +constexpr uint16_t WAVE_FORMAT_PCM = 1; +constexpr uint16_t WAVE_FORMAT_IEEE_FLOAT = 3; + +// Read a little-endian value from raw bytes +template +T readLE(const uint8_t* data) { + T value; + std::memcpy(&value, data, sizeof(T)); + return value; +} + +// Check for a 4-byte identifier (FOURCC) at a given offset +bool checkFourCC( + const uint8_t* data, + int64_t size, + int64_t offset, + const char* expected) { + if (offset + 4 > size) { + return false; + } + return std::memcmp(data + offset, expected, 4) == 0; +} + +// WAV header info extracted during parsing +struct WavHeaderInfo { + int64_t dataOffset; // Byte offset to PCM data + int64_t dataSize; // Size of PCM data in bytes + int64_t numSamples; // Total samples per channel + int sampleRate; + int numChannels; + int bitsPerSample; + uint16_t formatCode; // 1=PCM int, 3=IEEE float +}; + +// Read entire file into a vector +std::optional> readFile(const std::string& path) { + std::ifstream file(path, std::ios::binary | std::ios::ate); + if (!file) { + return std::nullopt; + } + + auto size = file.tellg(); + if (size <= 0) { + return std::nullopt; + } + + std::vector data(static_cast(size)); + file.seekg(0, std::ios::beg); + if (!file.read(reinterpret_cast(data.data()), size)) { + return std::nullopt; + } + + return data; +} + +// Parse WAV header from raw bytes +std::optional parseWavHeader(const uint8_t* data, int64_t size) { + // Need at least 44 bytes for minimal WAV header + if (size < 44) { + return std::nullopt; + } + + // Check RIFF/WAVE signature + if (!checkFourCC(data, size, 0, "RIFF") || + !checkFourCC(data, size, 8, "WAVE")) { + return std::nullopt; + } + + // Parse chunks to find fmt and data + int64_t pos = 12; + int numChannels = 0; + int sampleRate = 0; + int bitsPerSample = 0; + uint16_t formatCode = 0; + int64_t dataOffset = 0; + int64_t dataSize = 0; + bool foundFmt = false; + bool foundData = false; + + while (pos + 8 <= size) { + uint32_t chunkSize = readLE(data + pos + 4); + + if (checkFourCC(data, size, pos, "fmt ")) { + if (pos + 8 + 16 > size) { + return std::nullopt; + } + const uint8_t* fmt = data + pos + 8; + formatCode = readLE(fmt); + numChannels = readLE(fmt + 2); + sampleRate = readLE(fmt + 4); + bitsPerSample = readLE(fmt + 14); + foundFmt = true; + } else if (checkFourCC(data, size, pos, "data")) { + dataOffset = pos + 8; + dataSize = chunkSize; + foundData = true; + break; + } + + pos += 8 + chunkSize + (chunkSize & 1); + } + + if (!foundFmt || !foundData) { + return std::nullopt; + } + + // Validate basic parameters + if (numChannels <= 0 || sampleRate <= 0 || bitsPerSample <= 0) { + return std::nullopt; + } + + // Validate format/bitsPerSample combinations + if (formatCode == WAVE_FORMAT_PCM) { + if (bitsPerSample != 8 && bitsPerSample != 16 && bitsPerSample != 24 && + bitsPerSample != 32) { + return std::nullopt; + } + } else if (formatCode == WAVE_FORMAT_IEEE_FLOAT) { + if (bitsPerSample != 32 && bitsPerSample != 64) { + return std::nullopt; + } + } else { + // Unsupported format (extensible, etc.) + return std::nullopt; + } + + int bytesPerSample = bitsPerSample / 8; + int64_t bytesPerFrame = numChannels * bytesPerSample; + int64_t numSamples = dataSize / bytesPerFrame; + + return WavHeaderInfo{ + dataOffset, + dataSize, + numSamples, + sampleRate, + numChannels, + bitsPerSample, + formatCode}; +} + +// Convert 24-bit PCM samples to float32 tensor +torch::Tensor convert24BitPcmToFloat32( + const uint8_t* pcmData, + int64_t numSamples, + int numChannels) { + auto output = torch::empty({numChannels, numSamples}, torch::kFloat32); + float* outPtr = output.data_ptr(); + + constexpr float scale = 1.0f / 8388608.0f; + + for (int64_t s = 0; s < numSamples; ++s) { + for (int c = 0; c < numChannels; ++c) { + size_t byteIdx = (s * numChannels + c) * 3; + int32_t sample = pcmData[byteIdx] | (pcmData[byteIdx + 1] << 8) | + (pcmData[byteIdx + 2] << 16); + // Sign extend from 24 to 32 bits + if (sample & 0x800000) { + sample |= 0xFF000000; + } + outPtr[c * numSamples + s] = sample * scale; + } + } + + return output; +} + +// Convert PCM data to float32 tensor +// Returns tensor of shape (numChannels, numSamples) +torch::Tensor convertPcmToFloat32( + const uint8_t* pcmData, + int64_t numSamples, + int numChannels, + uint16_t formatCode, + int bitsPerSample) { + const bool isMono = (numChannels == 1); + + if (formatCode == WAVE_FORMAT_PCM) { + switch (bitsPerSample) { + case 8: { + auto shape = isMono ? std::vector{numSamples} + : std::vector{numSamples, numChannels}; + auto uintTensor = torch::from_blob( + const_cast(pcmData), shape, torch::kUInt8); + auto floatTensor = + uintTensor.to(torch::kFloat32).sub_(128.0f).div_(128.0f); + if (isMono) { + return floatTensor.unsqueeze(0); + } + return floatTensor.t().contiguous(); + } + case 16: { + auto shape = isMono ? std::vector{numSamples} + : std::vector{numSamples, numChannels}; + auto intTensor = torch::from_blob( + const_cast(pcmData), shape, torch::kInt16); + auto floatTensor = intTensor.to(torch::kFloat32).div_(32768.0f); + if (isMono) { + return floatTensor.unsqueeze(0); + } + return floatTensor.t().contiguous(); + } + case 24: { + return convert24BitPcmToFloat32(pcmData, numSamples, numChannels); + } + case 32: { + auto shape = isMono ? std::vector{numSamples} + : std::vector{numSamples, numChannels}; + auto intTensor = torch::from_blob( + const_cast(pcmData), shape, torch::kInt32); + auto floatTensor = intTensor.to(torch::kFloat32).div_(2147483648.0f); + if (isMono) { + return floatTensor.unsqueeze(0); + } + return floatTensor.t().contiguous(); + } + } + } else if (formatCode == WAVE_FORMAT_IEEE_FLOAT) { + switch (bitsPerSample) { + case 32: { + auto shape = isMono ? std::vector{numSamples} + : std::vector{numSamples, numChannels}; + auto floatTensor = torch::from_blob( + const_cast(pcmData), shape, torch::kFloat32); + if (isMono) { + return floatTensor.clone().unsqueeze(0); + } + return floatTensor.t().contiguous(); + } + case 64: { + auto shape = isMono ? std::vector{numSamples} + : std::vector{numSamples, numChannels}; + auto doubleTensor = torch::from_blob( + const_cast(pcmData), shape, torch::kFloat64); + auto floatTensor = doubleTensor.to(torch::kFloat32); + if (isMono) { + return floatTensor.unsqueeze(0); + } + return floatTensor.t().contiguous(); + } + } + } + + TORCH_CHECK(false, "Unsupported PCM format"); +} + +// Build JSON metadata string compatible with AudioStreamMetadata +std::string buildWavMetadataJson(const WavHeaderInfo& header) { + double durationSeconds = + static_cast(header.numSamples) / header.sampleRate; + double bitRate = static_cast( + header.sampleRate * header.numChannels * header.bitsPerSample); + + // Determine sample format string (FFmpeg style) + std::string sampleFormat; + if (header.formatCode == WAVE_FORMAT_IEEE_FLOAT) { + sampleFormat = (header.bitsPerSample == 32) ? "flt" : "dbl"; + } else { + // PCM integer formats + if (header.bitsPerSample == 8) { + sampleFormat = "u8"; + } else { + sampleFormat = "s" + std::to_string(header.bitsPerSample); + } + } + + std::stringstream ss; + ss << "{\n"; + ss << "\"durationSecondsFromHeader\": " << durationSeconds << ",\n"; + ss << "\"durationSeconds\": " << durationSeconds << ",\n"; + ss << "\"beginStreamSecondsFromHeader\": 0.0,\n"; + ss << "\"beginStreamSeconds\": 0.0,\n"; + ss << "\"bitRate\": " << bitRate << ",\n"; + ss << "\"codec\": \"pcm\",\n"; + ss << "\"sampleRate\": " << header.sampleRate << ",\n"; + ss << "\"numChannels\": " << header.numChannels << ",\n"; + ss << "\"sampleFormat\": \"" << sampleFormat << "\",\n"; + ss << "\"mediaType\": \"audio\"\n"; + ss << "}"; + + return ss.str(); +} + +} // namespace + +std::optional decodeWavFromTensor(const torch::Tensor& data) { + TORCH_CHECK( + data.is_contiguous() && data.dtype() == torch::kUInt8, + "Input tensor must be contiguous uint8"); + + const uint8_t* ptr = data.data_ptr(); + int64_t size = data.numel(); + + auto header = parseWavHeader(ptr, size); + if (!header) { + return std::nullopt; + } + + // Validate data bounds + if (header->dataOffset + header->dataSize > size) { + return std::nullopt; + } + + auto samples = convertPcmToFloat32( + ptr + header->dataOffset, + header->numSamples, + header->numChannels, + header->formatCode, + header->bitsPerSample); + + return WavSamples{samples, buildWavMetadataJson(*header)}; +} + +std::optional decodeWavFromFile(const std::string& path) { + auto fileData = readFile(path); + if (!fileData) { + return std::nullopt; + } + + const uint8_t* ptr = fileData->data(); + int64_t size = static_cast(fileData->size()); + + auto header = parseWavHeader(ptr, size); + if (!header || header->dataOffset + header->dataSize > size) { + return std::nullopt; + } + + auto samples = convertPcmToFloat32( + ptr + header->dataOffset, + header->numSamples, + header->numChannels, + header->formatCode, + header->bitsPerSample); + + return WavSamples{samples, buildWavMetadataJson(*header)}; +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/WavDecoder.h b/src/torchcodec/_core/WavDecoder.h new file mode 100644 index 000000000..fe2e1ce6b --- /dev/null +++ b/src/torchcodec/_core/WavDecoder.h @@ -0,0 +1,28 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include + +namespace facebook::torchcodec { + +struct WavSamples { + torch::Tensor samples; // Shape: (num_channels, num_samples) + std::string metadataJson; // JSON compatible with AudioStreamMetadata +}; + +// Decode WAV from bytes tensor (zero-copy when possible for mono PCM). +// Returns nullopt if the data is not a valid/supported WAV file. +std::optional decodeWavFromTensor(const torch::Tensor& data); + +// Decode WAV from file path (uses mmap for zero-copy access). +// Returns nullopt if the file is not a valid/supported WAV file. +std::optional decodeWavFromFile(const std::string& path); + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/__init__.py b/src/torchcodec/_core/__init__.py index add9efa90..1109c372d 100644 --- a/src/torchcodec/_core/__init__.py +++ b/src/torchcodec/_core/__init__.py @@ -24,6 +24,8 @@ create_from_file, create_from_file_like, create_from_tensor, + decode_wav_from_file, + decode_wav_from_tensor, encode_audio_to_file, encode_audio_to_file_like, encode_audio_to_tensor, diff --git a/src/torchcodec/_core/_metadata.py b/src/torchcodec/_core/_metadata.py index 482d0e1cb..2c0769cff 100644 --- a/src/torchcodec/_core/_metadata.py +++ b/src/torchcodec/_core/_metadata.py @@ -140,6 +140,22 @@ class AudioStreamMetadata(StreamMetadata): def __repr__(self): return super().__repr__() + @classmethod + def from_json(cls, json: dict, stream_index: int = 0) -> "AudioStreamMetadata": + """Create AudioStreamMetadata from a JSON dictionary returned by""" + return cls( + duration_seconds_from_header=json.get("durationSecondsFromHeader"), + duration_seconds=json.get("durationSeconds"), + bit_rate=json.get("bitRate"), + begin_stream_seconds_from_header=json.get("beginStreamSecondsFromHeader"), + begin_stream_seconds=json.get("beginStreamSeconds"), + codec=json.get("codec"), + stream_index=stream_index, + sample_rate=json.get("sampleRate"), + num_channels=json.get("numChannels"), + sample_format=json.get("sampleFormat"), + ) + @dataclass class ContainerMetadata: diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index e35f62388..dffbad9f6 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -14,6 +14,7 @@ #include "Encoder.h" #include "SingleStreamDecoder.h" #include "ValidationUtils.h" +#include "WavDecoder.h" #include "c10/core/SymIntArrayRef.h" #include "c10/util/Exception.h" @@ -79,6 +80,8 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "_test_frame_pts_equality(Tensor(a!) decoder, *, int frame_index, float pts_seconds_to_test) -> bool"); m.def("scan_all_streams_to_update_metadata(Tensor(a!) decoder) -> ()"); + m.def("decode_wav_from_tensor(Tensor data) -> (Tensor, str)"); + m.def("decode_wav_from_file(str path) -> (Tensor, str)"); } namespace { @@ -1052,6 +1055,28 @@ void scan_all_streams_to_update_metadata(torch::Tensor& decoder) { videoDecoder->scanFileAndUpdateMetadataAndIndex(); } +// Slim WAV decode functions - bypass SingleStreamDecoder for direct PCM access +// Returns (samples, metadata_json) or (empty tensor, "") if not a valid WAV +std::tuple decode_wav_from_tensor( + const torch::Tensor& data) { + auto result = decodeWavFromTensor(data); + if (!result) { + return std::make_tuple( + torch::empty({0, 0}, torch::kFloat32), std::string()); + } + return std::make_tuple(result->samples, result->metadataJson); +} + +std::tuple decode_wav_from_file( + const std::string& path) { + auto result = decodeWavFromFile(path); + if (!result) { + return std::make_tuple( + torch::empty({0, 0}, torch::kFloat32), std::string()); + } + return std::make_tuple(result->samples, result->metadataJson); +} + TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { m.impl("create_from_file", &create_from_file); m.impl("create_from_tensor", &create_from_tensor); @@ -1061,6 +1086,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { m.impl("encode_video_to_file", &encode_video_to_file); m.impl("encode_video_to_tensor", &encode_video_to_tensor); m.impl("_encode_video_to_file_like", &_encode_video_to_file_like); + m.impl("decode_wav_from_file", &decode_wav_from_file); } TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { @@ -1092,6 +1118,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { &scan_all_streams_to_update_metadata); m.impl("_get_backend_details", &get_backend_details); + m.impl("decode_wav_from_tensor", &decode_wav_from_tensor); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 3188dfc7b..2e0fbc4fd 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -168,6 +168,8 @@ def expose_ffmpeg_dlls(): # noqa: F811 torch.ops.torchcodec_ns._get_json_ffmpeg_library_versions.default ) _get_backend_details = torch.ops.torchcodec_ns._get_backend_details.default +decode_wav_from_tensor = torch.ops.torchcodec_ns.decode_wav_from_tensor.default +decode_wav_from_file = torch.ops.torchcodec_ns.decode_wav_from_file.default # ============================= diff --git a/src/torchcodec/decoders/_audio_decoder.py b/src/torchcodec/decoders/_audio_decoder.py index e1d0e0461..4b998c24a 100644 --- a/src/torchcodec/decoders/_audio_decoder.py +++ b/src/torchcodec/decoders/_audio_decoder.py @@ -6,12 +6,14 @@ import io +import json from pathlib import Path import torch from torch import Tensor from torchcodec import _core as core, AudioSamples +from torchcodec._core._metadata import AudioStreamMetadata from torchcodec.decoders._decoder_utils import ( create_decoder, ERROR_REPORTING_INSTRUCTIONS, @@ -61,6 +63,25 @@ def __init__( num_channels: int | None = None, ): torch._C._log_api_usage_once("torchcodec.decoders.AudioDecoder") + + # Try WAV fast path: only when no resampling/channel conversion needed + self._wav_samples: AudioSamples | None = None + if stream_index is None and sample_rate is None and num_channels is None: + samples, metadata_json = self._try_decode_wav(source) + if metadata_json: + metadata = json.loads(metadata_json) + self._wav_samples = AudioSamples( + data=samples, + pts_seconds=0.0, + duration_seconds=metadata["durationSeconds"], + sample_rate=metadata["sampleRate"], + ) + self.stream_index = 0 + self._desired_sample_rate = metadata["sampleRate"] + self._decoder = None # type: ignore[assignment] + self.metadata = AudioStreamMetadata.from_json(metadata) + return + self._decoder = create_decoder(source=source, seek_mode="approximate") container_metadata = core.get_container_metadata(self._decoder) @@ -96,6 +117,35 @@ def __init__( num_channels=num_channels, ) + @staticmethod + def _try_decode_wav( + source: str | Path | io.RawIOBase | io.BufferedReader | bytes | Tensor, + ) -> tuple[Tensor, str]: + """Try decoding as WAV. Returns (samples, metadata_json). + + Empty metadata_json means not a valid WAV file. + """ + if isinstance(source, Tensor): + return core.decode_wav_from_tensor(source) + elif isinstance(source, bytes): + return core.decode_wav_from_tensor( + torch.frombuffer(source, dtype=torch.uint8) + ) + elif isinstance(source, (str, Path)): + path = str(source) + if path.startswith(("http://", "https://", "s3://")): + return torch.empty(0), "" + return core.decode_wav_from_file(path) + elif isinstance(source, (io.RawIOBase, io.BufferedReader)) or ( + hasattr(source, "read") and hasattr(source, "seek") + ): + data = source.read() + source.seek(0) + return core.decode_wav_from_tensor( + torch.frombuffer(data, dtype=torch.uint8) + ) + return torch.empty(0), "" + def get_all_samples(self) -> AudioSamples: """Returns all the audio samples from the source. @@ -105,6 +155,9 @@ def get_all_samples(self) -> AudioSamples: Returns: AudioSamples: The samples within the file. """ + # Use WAV fast path if available + if self._wav_samples is not None: + return self._wav_samples return self.get_samples_played_in_range() def get_samples_played_in_range( @@ -134,6 +187,30 @@ def get_samples_played_in_range( raise ValueError( f"Invalid start seconds: {start_seconds}. It must be less than or equal to stop seconds ({stop_seconds})." ) + + # Handle WAV fast path + if self._wav_samples is not None: + sample_rate = self._wav_samples.sample_rate + num_samples = self._wav_samples.data.shape[1] + + start_sample = round(start_seconds * sample_rate) + if stop_seconds is None: + stop_sample = num_samples + else: + stop_sample = round(stop_seconds * sample_rate) + + start_sample = max(0, min(start_sample, num_samples)) + stop_sample = max(0, min(stop_sample, num_samples)) + + data = self._wav_samples.data[:, start_sample:stop_sample] + output_pts = start_sample / sample_rate + return AudioSamples( + data=data, + pts_seconds=output_pts, + duration_seconds=data.shape[1] / sample_rate, + sample_rate=sample_rate, + ) + frames, first_pts = core.get_frames_by_pts_in_range_audio( self._decoder, start_seconds=start_seconds, diff --git a/test/test_decoders.py b/test/test_decoders.py index 9e901f826..2abc6c26b 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -2317,3 +2317,36 @@ def test_num_channels_errors(self, asset): # FFmpeg fails to find a default layout for certain channel counts, # which causes SwrContext to fail to initialize. decoder.get_all_samples() + + # WAV fast path tests + @pytest.mark.parametrize("asset", (SINE_MONO_S16, SINE_MONO_S32)) + def test_wav_fast_path_from_bytes(self, asset): + """Test that WAV files use the fast path when loaded from bytes.""" + with open(asset.path, "rb") as f: + wav_bytes = f.read() + + decoder = AudioDecoder(wav_bytes) + samples = decoder.get_all_samples() + assert samples.data.dtype == torch.float32 + assert samples.data.shape[0] == asset.num_channels + assert samples.sample_rate == asset.sample_rate + + @pytest.mark.parametrize("asset", (SINE_MONO_S16, SINE_MONO_S32)) + def test_wav_fast_path_range_decoding(self, asset): + """Test that range decoding works correctly with fast path.""" + with open(asset.path, "rb") as f: + wav_bytes = f.read() + + decoder = AudioDecoder(wav_bytes) + + # Decode a range + start_seconds = 1.0 + stop_seconds = 2.0 + samples = decoder.get_samples_played_in_range( + start_seconds=start_seconds, stop_seconds=stop_seconds + ) + + expected_num_samples = round((stop_seconds - start_seconds) * asset.sample_rate) + assert samples.data.shape[1] == expected_num_samples + assert samples.pts_seconds == start_seconds + assert samples.duration_seconds == pytest.approx(stop_seconds - start_seconds) From d8bda2d478b6953b874584a7b81ecc0cb62a5230 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Wed, 4 Feb 2026 12:13:54 -0500 Subject: [PATCH 2/7] renaming --- src/torchcodec/_core/WavDecoder.cpp | 44 +++++++++++++- src/torchcodec/_core/WavDecoder.h | 28 +++++++-- src/torchcodec/_core/__init__.py | 4 +- src/torchcodec/_core/custom_ops.cpp | 33 +++++++---- src/torchcodec/_core/ops.py | 8 ++- src/torchcodec/decoders/_audio_decoder.py | 70 +++++++++++++++-------- 6 files changed, 140 insertions(+), 47 deletions(-) diff --git a/src/torchcodec/_core/WavDecoder.cpp b/src/torchcodec/_core/WavDecoder.cpp index 3c5b9d28b..a95b90dee 100644 --- a/src/torchcodec/_core/WavDecoder.cpp +++ b/src/torchcodec/_core/WavDecoder.cpp @@ -298,9 +298,35 @@ std::string buildWavMetadataJson(const WavHeaderInfo& header) { return ss.str(); } +// Validate optional parameters against WAV header +// Returns true if parameters are compatible, false otherwise +bool validateWavParams( + const WavHeaderInfo& header, + std::optional stream_index, + std::optional sample_rate, + std::optional num_channels) { + // WAV files only have one stream at index 0 + if (stream_index.has_value() && stream_index.value() != 0) { + return false; + } + // Check sample rate matches if specified + if (sample_rate.has_value() && sample_rate.value() != header.sampleRate) { + return false; + } + // Check channel count matches if specified + if (num_channels.has_value() && num_channels.value() != header.numChannels) { + return false; + } + return true; +} + } // namespace -std::optional decodeWavFromTensor(const torch::Tensor& data) { +std::optional validateAndDecodeWavFromTensor( + const torch::Tensor& data, + std::optional stream_index, + std::optional sample_rate, + std::optional num_channels) { TORCH_CHECK( data.is_contiguous() && data.dtype() == torch::kUInt8, "Input tensor must be contiguous uint8"); @@ -313,6 +339,11 @@ std::optional decodeWavFromTensor(const torch::Tensor& data) { return std::nullopt; } + // Validate optional parameters + if (!validateWavParams(*header, stream_index, sample_rate, num_channels)) { + return std::nullopt; + } + // Validate data bounds if (header->dataOffset + header->dataSize > size) { return std::nullopt; @@ -328,7 +359,11 @@ std::optional decodeWavFromTensor(const torch::Tensor& data) { return WavSamples{samples, buildWavMetadataJson(*header)}; } -std::optional decodeWavFromFile(const std::string& path) { +std::optional validateAndDecodeWavFromFile( + const std::string& path, + std::optional stream_index, + std::optional sample_rate, + std::optional num_channels) { auto fileData = readFile(path); if (!fileData) { return std::nullopt; @@ -342,6 +377,11 @@ std::optional decodeWavFromFile(const std::string& path) { return std::nullopt; } + // Validate optional parameters + if (!validateWavParams(*header, stream_index, sample_rate, num_channels)) { + return std::nullopt; + } + auto samples = convertPcmToFloat32( ptr + header->dataOffset, header->numSamples, diff --git a/src/torchcodec/_core/WavDecoder.h b/src/torchcodec/_core/WavDecoder.h index fe2e1ce6b..8e4b8b13d 100644 --- a/src/torchcodec/_core/WavDecoder.h +++ b/src/torchcodec/_core/WavDecoder.h @@ -17,12 +17,28 @@ struct WavSamples { std::string metadataJson; // JSON compatible with AudioStreamMetadata }; -// Decode WAV from bytes tensor (zero-copy when possible for mono PCM). -// Returns nullopt if the data is not a valid/supported WAV file. -std::optional decodeWavFromTensor(const torch::Tensor& data); +// Validate parameters and decode WAV from bytes tensor. +// Returns nullopt if: +// - The data is not a valid/supported WAV file +// - stream_index is specified and != 0 (WAV only has one stream) +// - sample_rate is specified and doesn't match the file's sample rate +// - num_channels is specified and doesn't match the file's channel count +std::optional validateAndDecodeWavFromTensor( + const torch::Tensor& data, + std::optional stream_index = std::nullopt, + std::optional sample_rate = std::nullopt, + std::optional num_channels = std::nullopt); -// Decode WAV from file path (uses mmap for zero-copy access). -// Returns nullopt if the file is not a valid/supported WAV file. -std::optional decodeWavFromFile(const std::string& path); +// Validate parameters and decode WAV from file path. +// Returns nullopt if: +// - The file is not a valid/supported WAV file +// - stream_index is specified and != 0 (WAV only has one stream) +// - sample_rate is specified and doesn't match the file's sample rate +// - num_channels is specified and doesn't match the file's channel count +std::optional validateAndDecodeWavFromFile( + const std::string& path, + std::optional stream_index = std::nullopt, + std::optional sample_rate = std::nullopt, + std::optional num_channels = std::nullopt); } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/__init__.py b/src/torchcodec/_core/__init__.py index 1109c372d..efd40c6b4 100644 --- a/src/torchcodec/_core/__init__.py +++ b/src/torchcodec/_core/__init__.py @@ -24,8 +24,6 @@ create_from_file, create_from_file_like, create_from_tensor, - decode_wav_from_file, - decode_wav_from_tensor, encode_audio_to_file, encode_audio_to_file_like, encode_audio_to_tensor, @@ -45,4 +43,6 @@ get_next_frame, scan_all_streams_to_update_metadata, seek_to_pts, + validate_and_decode_wav_from_file, + validate_and_decode_wav_from_tensor, ) diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index dffbad9f6..38e38744f 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -80,8 +80,10 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "_test_frame_pts_equality(Tensor(a!) decoder, *, int frame_index, float pts_seconds_to_test) -> bool"); m.def("scan_all_streams_to_update_metadata(Tensor(a!) decoder) -> ()"); - m.def("decode_wav_from_tensor(Tensor data) -> (Tensor, str)"); - m.def("decode_wav_from_file(str path) -> (Tensor, str)"); + m.def( + "validate_and_decode_wav_from_tensor(Tensor data, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> (Tensor, str)"); + m.def( + "validate_and_decode_wav_from_file(str path, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> (Tensor, str)"); } namespace { @@ -1057,9 +1059,13 @@ void scan_all_streams_to_update_metadata(torch::Tensor& decoder) { // Slim WAV decode functions - bypass SingleStreamDecoder for direct PCM access // Returns (samples, metadata_json) or (empty tensor, "") if not a valid WAV -std::tuple decode_wav_from_tensor( - const torch::Tensor& data) { - auto result = decodeWavFromTensor(data); +std::tuple validate_and_decode_wav_from_tensor( + const torch::Tensor& data, + std::optional stream_index, + std::optional sample_rate, + std::optional num_channels) { + auto result = validateAndDecodeWavFromTensor( + data, stream_index, sample_rate, num_channels); if (!result) { return std::make_tuple( torch::empty({0, 0}, torch::kFloat32), std::string()); @@ -1067,9 +1073,13 @@ std::tuple decode_wav_from_tensor( return std::make_tuple(result->samples, result->metadataJson); } -std::tuple decode_wav_from_file( - const std::string& path) { - auto result = decodeWavFromFile(path); +std::tuple validate_and_decode_wav_from_file( + const std::string& path, + std::optional stream_index, + std::optional sample_rate, + std::optional num_channels) { + auto result = validateAndDecodeWavFromFile( + path, stream_index, sample_rate, num_channels); if (!result) { return std::make_tuple( torch::empty({0, 0}, torch::kFloat32), std::string()); @@ -1086,7 +1096,8 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { m.impl("encode_video_to_file", &encode_video_to_file); m.impl("encode_video_to_tensor", &encode_video_to_tensor); m.impl("_encode_video_to_file_like", &_encode_video_to_file_like); - m.impl("decode_wav_from_file", &decode_wav_from_file); + m.impl( + "validate_and_decode_wav_from_file", &validate_and_decode_wav_from_file); } TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { @@ -1118,7 +1129,9 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { &scan_all_streams_to_update_metadata); m.impl("_get_backend_details", &get_backend_details); - m.impl("decode_wav_from_tensor", &decode_wav_from_tensor); + m.impl( + "validate_and_decode_wav_from_tensor", + &validate_and_decode_wav_from_tensor); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 2e0fbc4fd..5a28aacce 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -168,8 +168,12 @@ def expose_ffmpeg_dlls(): # noqa: F811 torch.ops.torchcodec_ns._get_json_ffmpeg_library_versions.default ) _get_backend_details = torch.ops.torchcodec_ns._get_backend_details.default -decode_wav_from_tensor = torch.ops.torchcodec_ns.decode_wav_from_tensor.default -decode_wav_from_file = torch.ops.torchcodec_ns.decode_wav_from_file.default +validate_and_decode_wav_from_tensor = ( + torch.ops.torchcodec_ns.validate_and_decode_wav_from_tensor.default +) +validate_and_decode_wav_from_file = ( + torch.ops.torchcodec_ns.validate_and_decode_wav_from_file.default +) # ============================= diff --git a/src/torchcodec/decoders/_audio_decoder.py b/src/torchcodec/decoders/_audio_decoder.py index 4b998c24a..19f4cb690 100644 --- a/src/torchcodec/decoders/_audio_decoder.py +++ b/src/torchcodec/decoders/_audio_decoder.py @@ -64,23 +64,24 @@ def __init__( ): torch._C._log_api_usage_once("torchcodec.decoders.AudioDecoder") - # Try WAV fast path: only when no resampling/channel conversion needed + # Try WAV fast path self._wav_samples: AudioSamples | None = None - if stream_index is None and sample_rate is None and num_channels is None: - samples, metadata_json = self._try_decode_wav(source) - if metadata_json: - metadata = json.loads(metadata_json) - self._wav_samples = AudioSamples( - data=samples, - pts_seconds=0.0, - duration_seconds=metadata["durationSeconds"], - sample_rate=metadata["sampleRate"], - ) - self.stream_index = 0 - self._desired_sample_rate = metadata["sampleRate"] - self._decoder = None # type: ignore[assignment] - self.metadata = AudioStreamMetadata.from_json(metadata) - return + samples, metadata_json = self._decode_wav( + source, stream_index, sample_rate, num_channels + ) + if metadata_json: + metadata = json.loads(metadata_json) + self._wav_samples = AudioSamples( + data=samples, + pts_seconds=0.0, + duration_seconds=metadata["durationSeconds"], + sample_rate=metadata["sampleRate"], + ) + self.stream_index = 0 + self._desired_sample_rate = metadata["sampleRate"] + self._decoder = None # type: ignore[assignment] + self.metadata = AudioStreamMetadata.from_json(metadata) + return self._decoder = create_decoder(source=source, seek_mode="approximate") @@ -118,31 +119,50 @@ def __init__( ) @staticmethod - def _try_decode_wav( + def _decode_wav( source: str | Path | io.RawIOBase | io.BufferedReader | bytes | Tensor, + stream_index: int | None = None, + sample_rate: int | None = None, + num_channels: int | None = None, ) -> tuple[Tensor, str]: - """Try decoding as WAV. Returns (samples, metadata_json). + """Decode WAV if valid and parameters match. Returns (samples, metadata_json). - Empty metadata_json means not a valid WAV file. + Empty metadata_json means not a valid WAV file or parameters don't match. """ if isinstance(source, Tensor): - return core.decode_wav_from_tensor(source) + return core.validate_and_decode_wav_from_tensor( + source, + stream_index=stream_index, + sample_rate=sample_rate, + num_channels=num_channels, + ) elif isinstance(source, bytes): - return core.decode_wav_from_tensor( - torch.frombuffer(source, dtype=torch.uint8) + return core.validate_and_decode_wav_from_tensor( + torch.frombuffer(source, dtype=torch.uint8), + stream_index=stream_index, + sample_rate=sample_rate, + num_channels=num_channels, ) elif isinstance(source, (str, Path)): path = str(source) if path.startswith(("http://", "https://", "s3://")): return torch.empty(0), "" - return core.decode_wav_from_file(path) + return core.validate_and_decode_wav_from_file( + path, + stream_index=stream_index, + sample_rate=sample_rate, + num_channels=num_channels, + ) elif isinstance(source, (io.RawIOBase, io.BufferedReader)) or ( hasattr(source, "read") and hasattr(source, "seek") ): data = source.read() source.seek(0) - return core.decode_wav_from_tensor( - torch.frombuffer(data, dtype=torch.uint8) + return core.validate_and_decode_wav_from_tensor( + torch.frombuffer(data, dtype=torch.uint8), + stream_index=stream_index, + sample_rate=sample_rate, + num_channels=num_channels, ) return torch.empty(0), "" From 70f774adb1f45c275bade42c194a7f0bbf66116b Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Wed, 4 Feb 2026 13:27:42 -0500 Subject: [PATCH 3/7] reuse mapToJson for metadata --- src/torchcodec/_core/WavDecoder.cpp | 46 +++-------------------- src/torchcodec/_core/WavDecoder.h | 4 +- src/torchcodec/_core/custom_ops.cpp | 30 ++++++++++----- src/torchcodec/decoders/_audio_decoder.py | 3 +- 4 files changed, 30 insertions(+), 53 deletions(-) diff --git a/src/torchcodec/_core/WavDecoder.cpp b/src/torchcodec/_core/WavDecoder.cpp index a95b90dee..470dffe10 100644 --- a/src/torchcodec/_core/WavDecoder.cpp +++ b/src/torchcodec/_core/WavDecoder.cpp @@ -9,7 +9,6 @@ #include #include #include -#include #include namespace facebook::torchcodec { @@ -261,43 +260,6 @@ torch::Tensor convertPcmToFloat32( TORCH_CHECK(false, "Unsupported PCM format"); } -// Build JSON metadata string compatible with AudioStreamMetadata -std::string buildWavMetadataJson(const WavHeaderInfo& header) { - double durationSeconds = - static_cast(header.numSamples) / header.sampleRate; - double bitRate = static_cast( - header.sampleRate * header.numChannels * header.bitsPerSample); - - // Determine sample format string (FFmpeg style) - std::string sampleFormat; - if (header.formatCode == WAVE_FORMAT_IEEE_FLOAT) { - sampleFormat = (header.bitsPerSample == 32) ? "flt" : "dbl"; - } else { - // PCM integer formats - if (header.bitsPerSample == 8) { - sampleFormat = "u8"; - } else { - sampleFormat = "s" + std::to_string(header.bitsPerSample); - } - } - - std::stringstream ss; - ss << "{\n"; - ss << "\"durationSecondsFromHeader\": " << durationSeconds << ",\n"; - ss << "\"durationSeconds\": " << durationSeconds << ",\n"; - ss << "\"beginStreamSecondsFromHeader\": 0.0,\n"; - ss << "\"beginStreamSeconds\": 0.0,\n"; - ss << "\"bitRate\": " << bitRate << ",\n"; - ss << "\"codec\": \"pcm\",\n"; - ss << "\"sampleRate\": " << header.sampleRate << ",\n"; - ss << "\"numChannels\": " << header.numChannels << ",\n"; - ss << "\"sampleFormat\": \"" << sampleFormat << "\",\n"; - ss << "\"mediaType\": \"audio\"\n"; - ss << "}"; - - return ss.str(); -} - // Validate optional parameters against WAV header // Returns true if parameters are compatible, false otherwise bool validateWavParams( @@ -356,7 +318,9 @@ std::optional validateAndDecodeWavFromTensor( header->formatCode, header->bitsPerSample); - return WavSamples{samples, buildWavMetadataJson(*header)}; + double durationSeconds = + static_cast(header->numSamples) / header->sampleRate; + return WavSamples{samples, header->sampleRate, durationSeconds}; } std::optional validateAndDecodeWavFromFile( @@ -389,7 +353,9 @@ std::optional validateAndDecodeWavFromFile( header->formatCode, header->bitsPerSample); - return WavSamples{samples, buildWavMetadataJson(*header)}; + double durationSeconds = + static_cast(header->numSamples) / header->sampleRate; + return WavSamples{samples, header->sampleRate, durationSeconds}; } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/WavDecoder.h b/src/torchcodec/_core/WavDecoder.h index 8e4b8b13d..18fbf2166 100644 --- a/src/torchcodec/_core/WavDecoder.h +++ b/src/torchcodec/_core/WavDecoder.h @@ -8,13 +8,13 @@ #include #include -#include namespace facebook::torchcodec { struct WavSamples { torch::Tensor samples; // Shape: (num_channels, num_samples) - std::string metadataJson; // JSON compatible with AudioStreamMetadata + int64_t sampleRate; + double durationSeconds; }; // Validate parameters and decode WAV from bytes tensor. diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 38e38744f..91317ca38 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -1057,8 +1057,20 @@ void scan_all_streams_to_update_metadata(torch::Tensor& decoder) { videoDecoder->scanFileAndUpdateMetadataAndIndex(); } -// Slim WAV decode functions - bypass SingleStreamDecoder for direct PCM access -// Returns (samples, metadata_json) or (empty tensor, "") if not a valid WAV +// Build JSON metadata for WAV samples +std::string buildWavMetadataJson(const WavSamples& wav) { + std::map map; + map["durationSecondsFromHeader"] = std::to_string(wav.durationSeconds); + map["durationSeconds"] = std::to_string(wav.durationSeconds); + map["beginStreamSecondsFromHeader"] = "0.0"; + map["beginStreamSeconds"] = "0.0"; + map["codec"] = "\"pcm\""; + map["sampleRate"] = std::to_string(wav.sampleRate); + map["numChannels"] = std::to_string(wav.samples.size(0)); + return mapToJson(map); +} + +// Returns (samples, metadata_json) or (empty tensor, "") if not valid std::tuple validate_and_decode_wav_from_tensor( const torch::Tensor& data, std::optional stream_index, @@ -1066,11 +1078,10 @@ std::tuple validate_and_decode_wav_from_tensor( std::optional num_channels) { auto result = validateAndDecodeWavFromTensor( data, stream_index, sample_rate, num_channels); - if (!result) { - return std::make_tuple( - torch::empty({0, 0}, torch::kFloat32), std::string()); + if (result) { + return std::make_tuple(result->samples, buildWavMetadataJson(*result)); } - return std::make_tuple(result->samples, result->metadataJson); + return std::make_tuple(torch::empty({0, 0}, torch::kFloat32), std::string()); } std::tuple validate_and_decode_wav_from_file( @@ -1080,11 +1091,10 @@ std::tuple validate_and_decode_wav_from_file( std::optional num_channels) { auto result = validateAndDecodeWavFromFile( path, stream_index, sample_rate, num_channels); - if (!result) { - return std::make_tuple( - torch::empty({0, 0}, torch::kFloat32), std::string()); + if (result) { + return std::make_tuple(result->samples, buildWavMetadataJson(*result)); } - return std::make_tuple(result->samples, result->metadataJson); + return std::make_tuple(torch::empty({0, 0}, torch::kFloat32), std::string()); } TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { diff --git a/src/torchcodec/decoders/_audio_decoder.py b/src/torchcodec/decoders/_audio_decoder.py index 19f4cb690..f8012cd46 100644 --- a/src/torchcodec/decoders/_audio_decoder.py +++ b/src/torchcodec/decoders/_audio_decoder.py @@ -125,8 +125,9 @@ def _decode_wav( sample_rate: int | None = None, num_channels: int | None = None, ) -> tuple[Tensor, str]: - """Decode WAV if valid and parameters match. Returns (samples, metadata_json). + """Decode WAV if valid and parameters match. + Returns (samples, metadata_json). Empty metadata_json means not a valid WAV file or parameters don't match. """ if isinstance(source, Tensor): From d57f06adba2e2a9778688070d05e668aca574644 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Wed, 4 Feb 2026 14:44:37 -0500 Subject: [PATCH 4/7] add comments --- src/torchcodec/_core/WavDecoder.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/torchcodec/_core/WavDecoder.cpp b/src/torchcodec/_core/WavDecoder.cpp index 470dffe10..6f24f6755 100644 --- a/src/torchcodec/_core/WavDecoder.cpp +++ b/src/torchcodec/_core/WavDecoder.cpp @@ -196,8 +196,10 @@ torch::Tensor convertPcmToFloat32( case 8: { auto shape = isMono ? std::vector{numSamples} : std::vector{numSamples, numChannels}; + // Interpret raw bytes as uint8 auto uintTensor = torch::from_blob( const_cast(pcmData), shape, torch::kUInt8); + // Convert to float32, then normalize from [0, 255] to [-1, 1] auto floatTensor = uintTensor.to(torch::kFloat32).sub_(128.0f).div_(128.0f); if (isMono) { @@ -208,8 +210,10 @@ torch::Tensor convertPcmToFloat32( case 16: { auto shape = isMono ? std::vector{numSamples} : std::vector{numSamples, numChannels}; + // Interpret raw bytes as int16 auto intTensor = torch::from_blob( const_cast(pcmData), shape, torch::kInt16); + // Convert to float32, then normalize from [-32768, 32767] to [-1, 1] auto floatTensor = intTensor.to(torch::kFloat32).div_(32768.0f); if (isMono) { return floatTensor.unsqueeze(0); @@ -222,8 +226,10 @@ torch::Tensor convertPcmToFloat32( case 32: { auto shape = isMono ? std::vector{numSamples} : std::vector{numSamples, numChannels}; + // Interpret raw bytes as int32 auto intTensor = torch::from_blob( const_cast(pcmData), shape, torch::kInt32); + // Convert to float32, then normalize from [-2^31, 2^31-1] to [-1, 1] auto floatTensor = intTensor.to(torch::kFloat32).div_(2147483648.0f); if (isMono) { return floatTensor.unsqueeze(0); @@ -236,6 +242,7 @@ torch::Tensor convertPcmToFloat32( case 32: { auto shape = isMono ? std::vector{numSamples} : std::vector{numSamples, numChannels}; + // Interpret raw bytes as float32 (already normalized by convention) auto floatTensor = torch::from_blob( const_cast(pcmData), shape, torch::kFloat32); if (isMono) { @@ -246,8 +253,10 @@ torch::Tensor convertPcmToFloat32( case 64: { auto shape = isMono ? std::vector{numSamples} : std::vector{numSamples, numChannels}; + // Interpret raw bytes as float64 auto doubleTensor = torch::from_blob( const_cast(pcmData), shape, torch::kFloat64); + // Convert to float32 (already normalized by convention) auto floatTensor = doubleTensor.to(torch::kFloat32); if (isMono) { return floatTensor.unsqueeze(0); From 05664a44dd01b513b6a2160327d4055531a3fa5e Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Mon, 9 Feb 2026 16:01:24 -0500 Subject: [PATCH 5/7] refactor to not consume entire file upfront --- src/torchcodec/_core/WavDecoder.cpp | 590 ++++++++++++---------- src/torchcodec/_core/WavDecoder.h | 130 ++++- src/torchcodec/_core/__init__.py | 8 +- src/torchcodec/_core/_metadata.py | 18 + src/torchcodec/_core/custom_ops.cpp | 164 ++++-- src/torchcodec/_core/ops.py | 67 ++- src/torchcodec/decoders/_audio_decoder.py | 120 ++--- src/torchcodec/decoders/_decoder_utils.py | 102 ++++ test/test_decoders.py | 42 +- 9 files changed, 783 insertions(+), 458 deletions(-) diff --git a/src/torchcodec/_core/WavDecoder.cpp b/src/torchcodec/_core/WavDecoder.cpp index 6f24f6755..29347fb00 100644 --- a/src/torchcodec/_core/WavDecoder.cpp +++ b/src/torchcodec/_core/WavDecoder.cpp @@ -6,365 +6,401 @@ #include "WavDecoder.h" -#include +#include +#include #include -#include -#include +#include namespace facebook::torchcodec { namespace { -// PCM format codes in WAV files -constexpr uint16_t WAVE_FORMAT_PCM = 1; -constexpr uint16_t WAVE_FORMAT_IEEE_FLOAT = 3; - -// Read a little-endian value from raw bytes template -T readLE(const uint8_t* data) { +T readLittleEndian(const uint8_t* data) { T value; std::memcpy(&value, data, sizeof(T)); return value; } -// Check for a 4-byte identifier (FOURCC) at a given offset -bool checkFourCC( - const uint8_t* data, - int64_t size, - int64_t offset, - const char* expected) { - if (offset + 4 > size) { - return false; +bool checkFourCC(const uint8_t* data, const char* expected) { + return std::memcmp(data, expected, 4) == 0; +} + +} // namespace + +// WavFileReader implementation +WavFileReader::WavFileReader(const std::string& path) : file_(nullptr) { + file_ = std::fopen(path.c_str(), "rb"); + if (!file_) { + throw std::runtime_error("Failed to open WAV file: " + path); } - return std::memcmp(data + offset, expected, 4) == 0; } -// WAV header info extracted during parsing -struct WavHeaderInfo { - int64_t dataOffset; // Byte offset to PCM data - int64_t dataSize; // Size of PCM data in bytes - int64_t numSamples; // Total samples per channel - int sampleRate; - int numChannels; - int bitsPerSample; - uint16_t formatCode; // 1=PCM int, 3=IEEE float -}; - -// Read entire file into a vector -std::optional> readFile(const std::string& path) { - std::ifstream file(path, std::ios::binary | std::ios::ate); - if (!file) { - return std::nullopt; +WavFileReader::~WavFileReader() { + if (file_) { + std::fclose(file_); } +} - auto size = file.tellg(); - if (size <= 0) { - return std::nullopt; +int64_t WavFileReader::read(void* buffer, int64_t size) { + if (!file_) { + return -1; } + size_t bytesRead = std::fread(buffer, 1, static_cast(size), file_); + return static_cast(bytesRead); +} - std::vector data(static_cast(size)); - file.seekg(0, std::ios::beg); - if (!file.read(reinterpret_cast(data.data()), size)) { - return std::nullopt; +int64_t WavFileReader::seek(int64_t position) { + if (!file_) { + return -1; + } + if (std::fseek(file_, static_cast(position), SEEK_SET) != 0) { + return -1; } + return position; +} - return data; +// WavTensorReader implementation +WavTensorReader::WavTensorReader(const torch::Tensor& data) + : data_(data), currentPos_(0) { + TORCH_CHECK(data.is_contiguous(), "WAV data tensor must be contiguous"); + TORCH_CHECK( + data.scalar_type() == torch::kUInt8, "WAV data tensor must be uint8"); } -// Parse WAV header from raw bytes -std::optional parseWavHeader(const uint8_t* data, int64_t size) { - // Need at least 44 bytes for minimal WAV header - if (size < 44) { - return std::nullopt; +int64_t WavTensorReader::read(void* buffer, int64_t size) { + int64_t available = data_.numel() - currentPos_; + int64_t toRead = std::min(size, available); + if (toRead <= 0) { + return 0; } - // Check RIFF/WAVE signature - if (!checkFourCC(data, size, 0, "RIFF") || - !checkFourCC(data, size, 8, "WAVE")) { - return std::nullopt; + const uint8_t* src = data_.data_ptr() + currentPos_; + std::memcpy(buffer, src, static_cast(toRead)); + currentPos_ += toRead; + return toRead; +} + +int64_t WavTensorReader::seek(int64_t position) { + if (position < 0 || position > data_.numel()) { + return -1; } + currentPos_ = position; + return currentPos_; +} - // Parse chunks to find fmt and data - int64_t pos = 12; - int numChannels = 0; - int sampleRate = 0; - int bitsPerSample = 0; - uint16_t formatCode = 0; - int64_t dataOffset = 0; - int64_t dataSize = 0; - bool foundFmt = false; - bool foundData = false; +// WavDecoder implementation +WavDecoder::WavDecoder(std::unique_ptr reader) + : reader_(std::move(reader)) { + parseHeader(); +} - while (pos + 8 <= size) { - uint32_t chunkSize = readLE(data + pos + 4); +bool WavDecoder::isWavFile(const void* data, size_t size) { + if (size < 12) { + return false; + } + const uint8_t* bytes = static_cast(data); + // Check for RIFF....WAVE + return checkFourCC(bytes, "RIFF") && checkFourCC(bytes + 8, "WAVE"); +} - if (checkFourCC(data, size, pos, "fmt ")) { - if (pos + 8 + 16 > size) { - return std::nullopt; - } - const uint8_t* fmt = data + pos + 8; - formatCode = readLE(fmt); - numChannels = readLE(fmt + 2); - sampleRate = readLE(fmt + 4); - bitsPerSample = readLE(fmt + 14); - foundFmt = true; - } else if (checkFourCC(data, size, pos, "data")) { - dataOffset = pos + 8; - dataSize = chunkSize; - foundData = true; - break; - } +void WavDecoder::parseHeader() { + // Read enough for header parsing (typical WAV headers are < 100 bytes) + // TODO: source? + constexpr int64_t headerBufferSize = 256; + std::vector buffer(headerBufferSize); - pos += 8 + chunkSize + (chunkSize & 1); + reader_->seek(0); + int64_t bytesRead = reader_->read(buffer.data(), headerBufferSize); + if (bytesRead < 44) { + throw std::runtime_error("WAV data too small to contain valid header"); } - if (!foundFmt || !foundData) { - return std::nullopt; - } + const uint8_t* data = buffer.data(); - // Validate basic parameters - if (numChannels <= 0 || sampleRate <= 0 || bitsPerSample <= 0) { - return std::nullopt; + // Verify RIFF header + if (!checkFourCC(data, "RIFF")) { + throw std::runtime_error("Missing RIFF header"); } - // Validate format/bitsPerSample combinations - if (formatCode == WAVE_FORMAT_PCM) { - if (bitsPerSample != 8 && bitsPerSample != 16 && bitsPerSample != 24 && - bitsPerSample != 32) { - return std::nullopt; - } - } else if (formatCode == WAVE_FORMAT_IEEE_FLOAT) { - if (bitsPerSample != 32 && bitsPerSample != 64) { - return std::nullopt; - } - } else { - // Unsupported format (extensible, etc.) - return std::nullopt; + // Verify WAVE format + if (!checkFourCC(data + 8, "WAVE")) { + throw std::runtime_error("Missing WAVE format identifier"); } - int bytesPerSample = bitsPerSample / 8; - int64_t bytesPerFrame = numChannels * bytesPerSample; - int64_t numSamples = dataSize / bytesPerFrame; - - return WavHeaderInfo{ - dataOffset, - dataSize, - numSamples, - sampleRate, - numChannels, - bitsPerSample, - formatCode}; -} + // Find and parse fmt chunk + int64_t offset = 12; + bool foundFmt = false; -// Convert 24-bit PCM samples to float32 tensor -torch::Tensor convert24BitPcmToFloat32( - const uint8_t* pcmData, - int64_t numSamples, - int numChannels) { - auto output = torch::empty({numChannels, numSamples}, torch::kFloat32); - float* outPtr = output.data_ptr(); + while (offset + 8 <= bytesRead) { + if (checkFourCC(data + offset, "fmt ")) { + uint32_t fmtSize = readLittleEndian(data + offset + 4); - constexpr float scale = 1.0f / 8388608.0f; + if (offset + 8 + fmtSize > bytesRead) { + throw std::runtime_error("fmt chunk extends beyond buffer"); + } - for (int64_t s = 0; s < numSamples; ++s) { - for (int c = 0; c < numChannels; ++c) { - size_t byteIdx = (s * numChannels + c) * 3; - int32_t sample = pcmData[byteIdx] | (pcmData[byteIdx + 1] << 8) | - (pcmData[byteIdx + 2] << 16); - // Sign extend from 24 to 32 bits - if (sample & 0x800000) { - sample |= 0xFF000000; + if (fmtSize < 16) { + throw std::runtime_error("fmt chunk too small"); } - outPtr[c * numSamples + s] = sample * scale; + + const uint8_t* fmtData = data + offset + 8; + // TODO: explain https://en.wikipedia.org/wiki/WAV#WAV_file_header + header_.audioFormat = readLittleEndian(fmtData); + header_.numChannels = readLittleEndian(fmtData + 2); + header_.sampleRate = readLittleEndian(fmtData + 4); + header_.byteRate = readLittleEndian(fmtData + 8); + header_.blockAlign = readLittleEndian(fmtData + 12); + header_.bitsPerSample = readLittleEndian(fmtData + 14); + + // Parse extended format fields for WAVE_FORMAT_EXTENSIBLE + if (header_.audioFormat == WAV_FORMAT_EXTENSIBLE) { + // Extended format requires at least 40 bytes total (16 base + 2 cbSize + // + 22 extension) + if (fmtSize < 40) { + throw std::runtime_error( + "WAVE_FORMAT_EXTENSIBLE fmt chunk too small"); + } + + header_.validBitsPerSample = readLittleEndian(fmtData + 18); + header_.channelMask = readLittleEndian(fmtData + 20); + // SubFormat GUID starts at offset 24, first 2 bytes are the format code + header_.subFormat = readLittleEndian(fmtData + 24); + } + + foundFmt = true; + offset += 8 + fmtSize; + break; } + // Skip unknown chunks + uint32_t chunkSize = readLittleEndian(data + offset + 4); + offset += 8 + chunkSize; } - return output; + if (!foundFmt) { + throw std::runtime_error("fmt chunk not found"); + } + + while (offset + 8 <= bytesRead) { + if (checkFourCC(data + offset, "data")) { + // Parse data chunk + header_.dataSize = readLittleEndian(data + offset + 4); + header_.dataOffset = offset + 8; + return; + } + + // Skip this chunk + uint32_t chunkSize = readLittleEndian(data + offset + 4); + offset += 8 + chunkSize; + } + + throw std::runtime_error("data chunk not found"); } -// Convert PCM data to float32 tensor -// Returns tensor of shape (numChannels, numSamples) -torch::Tensor convertPcmToFloat32( - const uint8_t* pcmData, - int64_t numSamples, - int numChannels, - uint16_t formatCode, - int bitsPerSample) { - const bool isMono = (numChannels == 1); - - if (formatCode == WAVE_FORMAT_PCM) { - switch (bitsPerSample) { - case 8: { - auto shape = isMono ? std::vector{numSamples} - : std::vector{numSamples, numChannels}; - // Interpret raw bytes as uint8 - auto uintTensor = torch::from_blob( - const_cast(pcmData), shape, torch::kUInt8); - // Convert to float32, then normalize from [0, 255] to [-1, 1] - auto floatTensor = - uintTensor.to(torch::kFloat32).sub_(128.0f).div_(128.0f); - if (isMono) { - return floatTensor.unsqueeze(0); - } - return floatTensor.t().contiguous(); - } - case 16: { - auto shape = isMono ? std::vector{numSamples} - : std::vector{numSamples, numChannels}; - // Interpret raw bytes as int16 - auto intTensor = torch::from_blob( - const_cast(pcmData), shape, torch::kInt16); - // Convert to float32, then normalize from [-32768, 32767] to [-1, 1] - auto floatTensor = intTensor.to(torch::kFloat32).div_(32768.0f); - if (isMono) { - return floatTensor.unsqueeze(0); - } - return floatTensor.t().contiguous(); - } - case 24: { - return convert24BitPcmToFloat32(pcmData, numSamples, numChannels); - } - case 32: { - auto shape = isMono ? std::vector{numSamples} - : std::vector{numSamples, numChannels}; - // Interpret raw bytes as int32 - auto intTensor = torch::from_blob( - const_cast(pcmData), shape, torch::kInt32); - // Convert to float32, then normalize from [-2^31, 2^31-1] to [-1, 1] - auto floatTensor = intTensor.to(torch::kFloat32).div_(2147483648.0f); - if (isMono) { - return floatTensor.unsqueeze(0); - } - return floatTensor.t().contiguous(); - } +bool WavDecoder::isSupported() const { + // Determine effective format (subFormat for extensible, audioFormat + // otherwise) + uint16_t effectiveFormat = header_.audioFormat; + if (header_.audioFormat == WAV_FORMAT_EXTENSIBLE) { + effectiveFormat = header_.subFormat; + } + + // Support PCM and IEEE float formats + if (effectiveFormat != WAV_FORMAT_PCM && + effectiveFormat != WAV_FORMAT_IEEE_FLOAT) { + return false; + } + + // Validate bits per sample + if (effectiveFormat == WAV_FORMAT_PCM) { + if (header_.bitsPerSample != 8 && header_.bitsPerSample != 16 && + header_.bitsPerSample != 24 && header_.bitsPerSample != 32) { + return false; } - } else if (formatCode == WAVE_FORMAT_IEEE_FLOAT) { - switch (bitsPerSample) { - case 32: { - auto shape = isMono ? std::vector{numSamples} - : std::vector{numSamples, numChannels}; - // Interpret raw bytes as float32 (already normalized by convention) - auto floatTensor = torch::from_blob( - const_cast(pcmData), shape, torch::kFloat32); - if (isMono) { - return floatTensor.clone().unsqueeze(0); - } - return floatTensor.t().contiguous(); - } - case 64: { - auto shape = isMono ? std::vector{numSamples} - : std::vector{numSamples, numChannels}; - // Interpret raw bytes as float64 - auto doubleTensor = torch::from_blob( - const_cast(pcmData), shape, torch::kFloat64); - // Convert to float32 (already normalized by convention) - auto floatTensor = doubleTensor.to(torch::kFloat32); - if (isMono) { - return floatTensor.unsqueeze(0); - } - return floatTensor.t().contiguous(); - } + } else if (effectiveFormat == WAV_FORMAT_IEEE_FLOAT) { + if (header_.bitsPerSample != 32 && header_.bitsPerSample != 64) { + return false; } } - TORCH_CHECK(false, "Unsupported PCM format"); + return header_.numChannels > 0 && header_.sampleRate > 0 && + header_.blockAlign > 0; } -// Validate optional parameters against WAV header -// Returns true if parameters are compatible, false otherwise -bool validateWavParams( - const WavHeaderInfo& header, +bool WavDecoder::isCompatible( std::optional stream_index, std::optional sample_rate, - std::optional num_channels) { + std::optional num_channels) const { // WAV files only have one stream at index 0 if (stream_index.has_value() && stream_index.value() != 0) { return false; } - // Check sample rate matches if specified - if (sample_rate.has_value() && sample_rate.value() != header.sampleRate) { + // Check sample rate matches if specified (no resampling support) + if (sample_rate.has_value() && + sample_rate.value() != static_cast(header_.sampleRate)) { return false; } - // Check channel count matches if specified - if (num_channels.has_value() && num_channels.value() != header.numChannels) { + // Check channel count matches if specified (no remixing support) + if (num_channels.has_value() && + num_channels.value() != static_cast(header_.numChannels)) { return false; } return true; } -} // namespace +const WavHeader& WavDecoder::getHeader() const { + return header_; +} -std::optional validateAndDecodeWavFromTensor( - const torch::Tensor& data, - std::optional stream_index, - std::optional sample_rate, - std::optional num_channels) { - TORCH_CHECK( - data.is_contiguous() && data.dtype() == torch::kUInt8, - "Input tensor must be contiguous uint8"); +double WavDecoder::getDurationSeconds() const { + if (header_.blockAlign == 0 || header_.sampleRate == 0) { + return 0.0; + } + int64_t numSamples = + static_cast(header_.dataSize) / header_.blockAlign; + return static_cast(numSamples) / header_.sampleRate; +} - const uint8_t* ptr = data.data_ptr(); - int64_t size = data.numel(); +torch::Tensor WavDecoder::convertSamplesToFloat( + const void* rawData, + int64_t numSamples, + int64_t numChannels) { + // Output is (numChannels, numSamples) float32 + torch::Tensor output = + torch::empty({numChannels, numSamples}, torch::kFloat32); + float* outPtr = output.data_ptr(); - auto header = parseWavHeader(ptr, size); - if (!header) { - return std::nullopt; - } + const uint8_t* src = static_cast(rawData); + int bytesPerSample = header_.bitsPerSample / 8; - // Validate optional parameters - if (!validateWavParams(*header, stream_index, sample_rate, num_channels)) { - return std::nullopt; + // Determine effective format (subFormat for extensible, audioFormat + // otherwise) + uint16_t effectiveFormat = header_.audioFormat; + if (header_.audioFormat == WAV_FORMAT_EXTENSIBLE) { + effectiveFormat = header_.subFormat; } - // Validate data bounds - if (header->dataOffset + header->dataSize > size) { - return std::nullopt; + if (effectiveFormat == WAV_FORMAT_IEEE_FLOAT) { + if (header_.bitsPerSample == 32) { + // 32-bit float - just copy and deinterleave + const float* floatSrc = reinterpret_cast(src); + for (int64_t s = 0; s < numSamples; ++s) { + for (int64_t c = 0; c < numChannels; ++c) { + outPtr[c * numSamples + s] = floatSrc[s * numChannels + c]; + } + } + } else if (header_.bitsPerSample == 64) { + // 64-bit float - convert to 32-bit and deinterleave + const double* doubleSrc = reinterpret_cast(src); + for (int64_t s = 0; s < numSamples; ++s) { + for (int64_t c = 0; c < numChannels; ++c) { + outPtr[c * numSamples + s] = + static_cast(doubleSrc[s * numChannels + c]); + } + } + } + } else { + // PCM format - convert to normalized float + for (int64_t s = 0; s < numSamples; ++s) { + for (int64_t c = 0; c < numChannels; ++c) { + const uint8_t* samplePtr = src + (s * numChannels + c) * bytesPerSample; + float value = 0.0f; + + switch (header_.bitsPerSample) { + case 8: { + // 8-bit PCM is unsigned (0-255, center at 128) + uint8_t sample = *samplePtr; + value = (static_cast(sample) - 128.0f) / 128.0f; + break; + } + case 16: { + // 16-bit PCM is signed + int16_t sample = readLittleEndian(samplePtr); + value = static_cast(sample) / 32768.0f; + break; + } + case 24: { + // 24-bit PCM is signed, stored in 3 bytes little-endian + int32_t sample = static_cast(samplePtr[0]) | + (static_cast(samplePtr[1]) << 8) | + (static_cast(samplePtr[2]) << 16); + // Sign extend from 24 to 32 bits + if (sample & 0x800000) { + sample |= 0xFF000000; + } + value = static_cast(sample) / 8388608.0f; + break; + } + case 32: { + // 32-bit PCM is signed + int32_t sample = readLittleEndian(samplePtr); + value = static_cast(sample) / 2147483648.0f; + break; + } + } + outPtr[c * numSamples + s] = value; + } + } } - auto samples = convertPcmToFloat32( - ptr + header->dataOffset, - header->numSamples, - header->numChannels, - header->formatCode, - header->bitsPerSample); - - double durationSeconds = - static_cast(header->numSamples) / header->sampleRate; - return WavSamples{samples, header->sampleRate, durationSeconds}; + return output; } -std::optional validateAndDecodeWavFromFile( - const std::string& path, - std::optional stream_index, - std::optional sample_rate, - std::optional num_channels) { - auto fileData = readFile(path); - if (!fileData) { - return std::nullopt; +std::tuple WavDecoder::getSamplesInRange( + double startSeconds, + std::optional stopSeconds) { + TORCH_CHECK(startSeconds >= 0, "start_seconds must be non-negative"); + if (stopSeconds.has_value()) { + TORCH_CHECK( + stopSeconds.value() >= startSeconds, + "stop_seconds must be >= start_seconds"); } - const uint8_t* ptr = fileData->data(); - int64_t size = static_cast(fileData->size()); + double duration = getDurationSeconds(); + if (startSeconds >= duration) { + // Return empty tensor + return std::make_tuple( + torch::empty({header_.numChannels, 0}, torch::kFloat32), startSeconds); + } + + double actualStop = stopSeconds.value_or(duration); + actualStop = std::min(actualStop, duration); + + // Calculate sample range + int64_t startSample = static_cast(startSeconds * header_.sampleRate); + int64_t stopSample = static_cast(actualStop * header_.sampleRate); + int64_t numSamples = stopSample - startSample; - auto header = parseWavHeader(ptr, size); - if (!header || header->dataOffset + header->dataSize > size) { - return std::nullopt; + if (numSamples <= 0) { + return std::make_tuple( + torch::empty({header_.numChannels, 0}, torch::kFloat32), startSeconds); } - // Validate optional parameters - if (!validateWavParams(*header, stream_index, sample_rate, num_channels)) { - return std::nullopt; + // Calculate byte positions + int64_t byteOffset = startSample * header_.blockAlign; + int64_t bytesToRead = numSamples * header_.blockAlign; + + // Seek to position and read + reader_->seek(static_cast(header_.dataOffset) + byteOffset); + std::vector rawData(bytesToRead); + int64_t bytesRead = reader_->read(rawData.data(), bytesToRead); + + if (bytesRead < bytesToRead) { + // Adjust numSamples if we couldn't read everything + numSamples = bytesRead / header_.blockAlign; + if (numSamples <= 0) { + return std::make_tuple( + torch::empty({header_.numChannels, 0}, torch::kFloat32), + startSeconds); + } } - auto samples = convertPcmToFloat32( - ptr + header->dataOffset, - header->numSamples, - header->numChannels, - header->formatCode, - header->bitsPerSample); + torch::Tensor samples = + convertSamplesToFloat(rawData.data(), numSamples, header_.numChannels); + + // Calculate actual PTS + double ptsSeconds = static_cast(startSample) / header_.sampleRate; - double durationSeconds = - static_cast(header->numSamples) / header->sampleRate; - return WavSamples{samples, header->sampleRate, durationSeconds}; + return std::make_tuple(samples, ptsSeconds); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/WavDecoder.h b/src/torchcodec/_core/WavDecoder.h index 18fbf2166..58c0d6932 100644 --- a/src/torchcodec/_core/WavDecoder.h +++ b/src/torchcodec/_core/WavDecoder.h @@ -7,38 +7,114 @@ #pragma once #include +#include +#include #include +#include namespace facebook::torchcodec { -struct WavSamples { - torch::Tensor samples; // Shape: (num_channels, num_samples) - int64_t sampleRate; - double durationSeconds; +// WAV format constants +constexpr uint16_t WAV_FORMAT_PCM = 1; +constexpr uint16_t WAV_FORMAT_IEEE_FLOAT = 3; +constexpr uint16_t WAV_FORMAT_EXTENSIBLE = 0xFFFE; + +// Parsed WAV header information +struct WavHeader { + uint16_t audioFormat = 0; // 1 = PCM, 3 = IEEE float + uint16_t numChannels = 0; + uint32_t sampleRate = 0; + uint32_t byteRate = 0; + uint16_t blockAlign = 0; + uint16_t bitsPerSample = 0; + uint64_t dataOffset = 0; // Offset to start of audio data + uint64_t dataSize = 0; // Size of audio data in bytes + + // Extended format fields (WAVE_FORMAT_EXTENSIBLE) + uint16_t validBitsPerSample = 0; + uint32_t channelMask = 0; + uint16_t subFormat = 0; // Extracted from SubFormat GUID (first 2 bytes) +}; + +// Abstract base class for reading WAV data from different sources +class WavReader { + public: + virtual ~WavReader() = default; + + // Read up to `size` bytes into `buffer`. Returns bytes actually read. + virtual int64_t read(void* buffer, int64_t size) = 0; + + // Seek to absolute position. Returns new position or -1 on error. + virtual int64_t seek(int64_t position) = 0; }; -// Validate parameters and decode WAV from bytes tensor. -// Returns nullopt if: -// - The data is not a valid/supported WAV file -// - stream_index is specified and != 0 (WAV only has one stream) -// - sample_rate is specified and doesn't match the file's sample rate -// - num_channels is specified and doesn't match the file's channel count -std::optional validateAndDecodeWavFromTensor( - const torch::Tensor& data, - std::optional stream_index = std::nullopt, - std::optional sample_rate = std::nullopt, - std::optional num_channels = std::nullopt); - -// Validate parameters and decode WAV from file path. -// Returns nullopt if: -// - The file is not a valid/supported WAV file -// - stream_index is specified and != 0 (WAV only has one stream) -// - sample_rate is specified and doesn't match the file's sample rate -// - num_channels is specified and doesn't match the file's channel count -std::optional validateAndDecodeWavFromFile( - const std::string& path, - std::optional stream_index = std::nullopt, - std::optional sample_rate = std::nullopt, - std::optional num_channels = std::nullopt); +// WavReader implementation for file paths +class WavFileReader : public WavReader { + public: + explicit WavFileReader(const std::string& path); + ~WavFileReader() override; + + int64_t read(void* buffer, int64_t size) override; + int64_t seek(int64_t position) override; + + private: + std::FILE* file_; +}; + +// WavReader implementation for tensor/bytes data +class WavTensorReader : public WavReader { + public: + explicit WavTensorReader(const torch::Tensor& data); + + int64_t read(void* buffer, int64_t size) override; + int64_t seek(int64_t position) override; + + private: + torch::Tensor data_; + int64_t currentPos_; +}; + +// Main WAV decoder class +class WavDecoder { + public: + explicit WavDecoder(std::unique_ptr reader); + + // Check if this is a supported uncompressed WAV file + // Returns true for PCM and IEEE float formats + bool isSupported() const; + + // Check if the requested parameters are compatible with this WAV file. + // Returns false if resampling or channel mixing would be required. + // WAV files only have one stream, so stream_index must be 0 or nullopt. + bool isCompatible( + std::optional stream_index, + std::optional sample_rate, + std::optional num_channels) const; + + // Get the parsed header + const WavHeader& getHeader() const; + + // Get samples in a time range, returns (samples, pts_seconds) + // samples is shape (num_channels, num_samples) float32 normalized to [-1, 1] + std::tuple getSamplesInRange( + double startSeconds, + std::optional stopSeconds); + + // Get total duration in seconds + double getDurationSeconds() const; + + // Static helper to check if data looks like a WAV file + static bool isWavFile(const void* data, size_t size); + + private: + void parseHeader(); + torch::Tensor convertSamplesToFloat( + const void* rawData, + int64_t numSamples, + int64_t numChannels); + + std::unique_ptr reader_; + WavHeader header_; +}; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/__init__.py b/src/torchcodec/_core/__init__.py index efd40c6b4..8c0f357fa 100644 --- a/src/torchcodec/_core/__init__.py +++ b/src/torchcodec/_core/__init__.py @@ -8,12 +8,14 @@ from ._metadata import ( AudioStreamMetadata, ContainerMetadata, + create_audio_metadata_from_wav, get_container_metadata, get_container_metadata_from_header, VideoStreamMetadata, ) from .ops import ( _add_video_stream, + _decode_wav_from_file_like, _get_backend_details, _get_key_frame_indices, _test_frame_pts_equality, @@ -24,6 +26,8 @@ create_from_file, create_from_file_like, create_from_tensor, + decode_wav_from_file, + decode_wav_from_tensor, encode_audio_to_file, encode_audio_to_file_like, encode_audio_to_tensor, @@ -41,8 +45,8 @@ get_frames_in_range, get_json_metadata, get_next_frame, + get_wav_metadata_from_file, + get_wav_metadata_from_tensor, scan_all_streams_to_update_metadata, seek_to_pts, - validate_and_decode_wav_from_file, - validate_and_decode_wav_from_tensor, ) diff --git a/src/torchcodec/_core/_metadata.py b/src/torchcodec/_core/_metadata.py index 2c0769cff..c0200caac 100644 --- a/src/torchcodec/_core/_metadata.py +++ b/src/torchcodec/_core/_metadata.py @@ -157,6 +157,24 @@ def from_json(cls, json: dict, stream_index: int = 0) -> "AudioStreamMetadata": ) +def create_audio_metadata_from_wav(wav_json: dict) -> AudioStreamMetadata: + """Create AudioStreamMetadata from WAV metadata dict.""" + return AudioStreamMetadata( + duration_seconds_from_header=wav_json.get("duration_seconds"), + begin_stream_seconds_from_header=wav_json.get( + "begin_stream_seconds_from_header", 0.0 + ), + bit_rate=wav_json.get("bit_rate"), + codec=wav_json.get("codec", "pcm"), + stream_index=wav_json.get("stream_index", 0), + duration_seconds=wav_json.get("duration_seconds"), + begin_stream_seconds=wav_json.get("begin_stream_seconds", 0.0), + sample_rate=wav_json.get("sample_rate"), + num_channels=wav_json.get("num_channels"), + sample_format=wav_json.get("sample_format"), + ) + + @dataclass class ContainerMetadata: duration_seconds_from_header: float | None diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 91317ca38..f11433be5 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -81,9 +81,15 @@ TORCH_LIBRARY(torchcodec_ns, m) { "_test_frame_pts_equality(Tensor(a!) decoder, *, int frame_index, float pts_seconds_to_test) -> bool"); m.def("scan_all_streams_to_update_metadata(Tensor(a!) decoder) -> ()"); m.def( - "validate_and_decode_wav_from_tensor(Tensor data, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> (Tensor, str)"); + "decode_wav_from_file(str filename, float start_seconds=0.0, float? stop_seconds=None) -> (Tensor, Tensor)"); m.def( - "validate_and_decode_wav_from_file(str path, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> (Tensor, str)"); + "decode_wav_from_tensor(Tensor data, float start_seconds=0.0, float? stop_seconds=None) -> (Tensor, Tensor)"); + m.def( + "_decode_wav_from_file_like(int ctx, float start_seconds=0.0, float? stop_seconds=None) -> (Tensor, Tensor)"); + m.def( + "get_wav_metadata_from_file(str filename, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> str"); + m.def( + "get_wav_metadata_from_tensor(Tensor data, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> str"); } namespace { @@ -1057,44 +1063,142 @@ void scan_all_streams_to_update_metadata(torch::Tensor& decoder) { videoDecoder->scanFileAndUpdateMetadataAndIndex(); } -// Build JSON metadata for WAV samples -std::string buildWavMetadataJson(const WavSamples& wav) { +// The elements of this tuple are: +// 1. The audio samples as a float32 tensor of shape (num_channels, +// num_samples) +// 2. A single float value for the pts of the first sample, in seconds. +using OpsWavOutput = std::tuple; + +OpsWavOutput decode_wav_from_file( + std::string_view filename, + double start_seconds, + std::optional stop_seconds) { + auto reader = std::make_unique(std::string(filename)); + WavDecoder decoder(std::move(reader)); + + TORCH_CHECK( + decoder.isSupported(), + "Unsupported WAV format. Only PCM and IEEE float formats are supported."); + + auto [samples, pts] = decoder.getSamplesInRange(start_seconds, stop_seconds); + return std::make_tuple( + samples, torch::tensor(pts, torch::dtype(torch::kFloat64))); +} + +OpsWavOutput decode_wav_from_tensor( + const torch::Tensor& data, + double start_seconds, + std::optional stop_seconds) { + auto reader = std::make_unique(data); + WavDecoder decoder(std::move(reader)); + + TORCH_CHECK( + decoder.isSupported(), + "Unsupported WAV format. Only PCM and IEEE float formats are supported."); + + auto [samples, pts] = decoder.getSamplesInRange(start_seconds, stop_seconds); + return std::make_tuple( + samples, torch::tensor(pts, torch::dtype(torch::kFloat64))); +} + +OpsWavOutput _decode_wav_from_file_like( + int64_t file_like_context, + double start_seconds, + std::optional stop_seconds) { + // Get the file-like object and read all data into a tensor + auto fileLikeContext = + reinterpret_cast(file_like_context); + TORCH_CHECK( + fileLikeContext != nullptr, "file_like_context must be a valid pointer"); + + // Read all data from file-like object + // First, try to get size by seeking to end + fileLikeContext->getAVIOContext()->seek( + fileLikeContext->getAVIOContext(), 0, SEEK_END); + int64_t totalSize = fileLikeContext->getAVIOContext()->pos; + fileLikeContext->getAVIOContext()->seek( + fileLikeContext->getAVIOContext(), 0, SEEK_SET); + + std::vector buffer(totalSize); + int64_t bytesRead = 0; + while (bytesRead < totalSize) { + int ret = fileLikeContext->getAVIOContext()->read_packet( + fileLikeContext->getAVIOContext()->opaque, + buffer.data() + bytesRead, + totalSize - bytesRead); + if (ret <= 0) { + break; + } + bytesRead += ret; + } + + // Create tensor from buffer + torch::Tensor tensorData = + torch::from_blob( + buffer.data(), {static_cast(buffer.size())}, torch::kUInt8) + .clone(); + + auto reader = std::make_unique(tensorData); + WavDecoder decoder(std::move(reader)); + + TORCH_CHECK( + decoder.isSupported(), + "Unsupported WAV format. Only PCM and IEEE float formats are supported."); + + auto [samples, pts] = decoder.getSamplesInRange(start_seconds, stop_seconds); + return std::make_tuple( + samples, torch::tensor(pts, torch::dtype(torch::kFloat64))); +} + +std::string buildWavMetadataJson(WavDecoder& decoder) { + const WavHeader& header = decoder.getHeader(); std::map map; - map["durationSecondsFromHeader"] = std::to_string(wav.durationSeconds); - map["durationSeconds"] = std::to_string(wav.durationSeconds); - map["beginStreamSecondsFromHeader"] = "0.0"; - map["beginStreamSeconds"] = "0.0"; + // Fields matching AudioStreamMetadata + map["sample_rate"] = std::to_string(header.sampleRate); + map["num_channels"] = std::to_string(header.numChannels); + map["duration_seconds"] = fmt::to_string(decoder.getDurationSeconds()); + map["duration_seconds_from_header"] = + fmt::to_string(decoder.getDurationSeconds()); + map["begin_stream_seconds"] = "0.0"; + map["begin_stream_seconds_from_header"] = "0.0"; + map["bit_rate"] = "null"; map["codec"] = "\"pcm\""; - map["sampleRate"] = std::to_string(wav.sampleRate); - map["numChannels"] = std::to_string(wav.samples.size(0)); + map["stream_index"] = "0"; + map["sample_format"] = "null"; + return mapToJson(map); } -// Returns (samples, metadata_json) or (empty tensor, "") if not valid -std::tuple validate_and_decode_wav_from_tensor( - const torch::Tensor& data, +std::string get_wav_metadata_from_file( + std::string_view filename, std::optional stream_index, std::optional sample_rate, std::optional num_channels) { - auto result = validateAndDecodeWavFromTensor( - data, stream_index, sample_rate, num_channels); - if (result) { - return std::make_tuple(result->samples, buildWavMetadataJson(*result)); + auto reader = std::make_unique(std::string(filename)); + WavDecoder decoder(std::move(reader)); + + if (!decoder.isSupported() || + !decoder.isCompatible(stream_index, sample_rate, num_channels)) { + return ""; } - return std::make_tuple(torch::empty({0, 0}, torch::kFloat32), std::string()); + + return buildWavMetadataJson(decoder); } -std::tuple validate_and_decode_wav_from_file( - const std::string& path, +std::string get_wav_metadata_from_tensor( + const torch::Tensor& data, std::optional stream_index, std::optional sample_rate, std::optional num_channels) { - auto result = validateAndDecodeWavFromFile( - path, stream_index, sample_rate, num_channels); - if (result) { - return std::make_tuple(result->samples, buildWavMetadataJson(*result)); + auto reader = std::make_unique(data); + WavDecoder decoder(std::move(reader)); + + if (!decoder.isSupported() || + !decoder.isCompatible(stream_index, sample_rate, num_channels)) { + return ""; } - return std::make_tuple(torch::empty({0, 0}, torch::kFloat32), std::string()); + + return buildWavMetadataJson(decoder); } TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { @@ -1106,8 +1210,11 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { m.impl("encode_video_to_file", &encode_video_to_file); m.impl("encode_video_to_tensor", &encode_video_to_tensor); m.impl("_encode_video_to_file_like", &_encode_video_to_file_like); - m.impl( - "validate_and_decode_wav_from_file", &validate_and_decode_wav_from_file); + m.impl("decode_wav_from_file", &decode_wav_from_file); + m.impl("decode_wav_from_tensor", &decode_wav_from_tensor); + m.impl("_decode_wav_from_file_like", &_decode_wav_from_file_like); + m.impl("get_wav_metadata_from_file", &get_wav_metadata_from_file); + m.impl("get_wav_metadata_from_tensor", &get_wav_metadata_from_tensor); } TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { @@ -1139,9 +1246,6 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { &scan_all_streams_to_update_metadata); m.impl("_get_backend_details", &get_backend_details); - m.impl( - "validate_and_decode_wav_from_tensor", - &validate_and_decode_wav_from_tensor); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 5a28aacce..733ac02ac 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -168,11 +168,20 @@ def expose_ffmpeg_dlls(): # noqa: F811 torch.ops.torchcodec_ns._get_json_ffmpeg_library_versions.default ) _get_backend_details = torch.ops.torchcodec_ns._get_backend_details.default -validate_and_decode_wav_from_tensor = ( - torch.ops.torchcodec_ns.validate_and_decode_wav_from_tensor.default +decode_wav_from_file = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns.decode_wav_from_file.default ) -validate_and_decode_wav_from_file = ( - torch.ops.torchcodec_ns.validate_and_decode_wav_from_file.default +decode_wav_from_tensor = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns.decode_wav_from_tensor.default +) +_decode_wav_from_file_like = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns._decode_wav_from_file_like.default +) +get_wav_metadata_from_file = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns.get_wav_metadata_from_file.default +) +get_wav_metadata_from_tensor = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns.get_wav_metadata_from_tensor.default ) @@ -610,3 +619,53 @@ def get_ffmpeg_library_versions(): @register_fake("torchcodec_ns::_get_backend_details") def _get_backend_details_abstract(decoder: torch.Tensor) -> str: return "" + + +@register_fake("torchcodec_ns::decode_wav_from_file") +def decode_wav_from_file_abstract( + filename: str, + start_seconds: float = 0.0, + stop_seconds: float | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + samples_size = [get_ctx().new_dynamic_size() for _ in range(2)] + return (torch.empty(samples_size), torch.empty([], dtype=torch.float64)) + + +@register_fake("torchcodec_ns::decode_wav_from_tensor") +def decode_wav_from_tensor_abstract( + data: torch.Tensor, + start_seconds: float = 0.0, + stop_seconds: float | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + samples_size = [get_ctx().new_dynamic_size() for _ in range(2)] + return (torch.empty(samples_size), torch.empty([], dtype=torch.float64)) + + +@register_fake("torchcodec_ns::_decode_wav_from_file_like") +def _decode_wav_from_file_like_abstract( + ctx: int, + start_seconds: float = 0.0, + stop_seconds: float | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + samples_size = [get_ctx().new_dynamic_size() for _ in range(2)] + return (torch.empty(samples_size), torch.empty([], dtype=torch.float64)) + + +@register_fake("torchcodec_ns::get_wav_metadata_from_file") +def get_wav_metadata_from_file_abstract( + filename: str, + stream_index: int | None = None, + sample_rate: int | None = None, + num_channels: int | None = None, +) -> str: + return "" + + +@register_fake("torchcodec_ns::get_wav_metadata_from_tensor") +def get_wav_metadata_from_tensor_abstract( + data: torch.Tensor, + stream_index: int | None = None, + sample_rate: int | None = None, + num_channels: int | None = None, +) -> str: + return "" diff --git a/src/torchcodec/decoders/_audio_decoder.py b/src/torchcodec/decoders/_audio_decoder.py index f8012cd46..7a77b2359 100644 --- a/src/torchcodec/decoders/_audio_decoder.py +++ b/src/torchcodec/decoders/_audio_decoder.py @@ -6,16 +6,16 @@ import io -import json from pathlib import Path import torch from torch import Tensor from torchcodec import _core as core, AudioSamples -from torchcodec._core._metadata import AudioStreamMetadata from torchcodec.decoders._decoder_utils import ( + _is_uncompressed_wav, create_decoder, + decode_wav, ERROR_REPORTING_INSTRUCTIONS, ) @@ -61,28 +61,30 @@ def __init__( stream_index: int | None = None, sample_rate: int | None = None, num_channels: int | None = None, + use_wav_decoder: ( + bool | None + ) = None, # optionally disable wav decoder for testing ): torch._C._log_api_usage_once("torchcodec.decoders.AudioDecoder") - # Try WAV fast path - self._wav_samples: AudioSamples | None = None - samples, metadata_json = self._decode_wav( - source, stream_index, sample_rate, num_channels - ) - if metadata_json: - metadata = json.loads(metadata_json) - self._wav_samples = AudioSamples( - data=samples, - pts_seconds=0.0, - duration_seconds=metadata["durationSeconds"], - sample_rate=metadata["sampleRate"], + # Check if this is an uncompressed WAV file that we can decode directly with WavDecoder. + self._use_wav_decoder = False + self._wav_source = None + + if use_wav_decoder is not False and ( + wav_metadata := _is_uncompressed_wav( + source, stream_index, sample_rate, num_channels ) + ): + self._use_wav_decoder = True + self._wav_source = source self.stream_index = 0 - self._desired_sample_rate = metadata["sampleRate"] - self._decoder = None # type: ignore[assignment] - self.metadata = AudioStreamMetadata.from_json(metadata) + + # Create metadata from WAV JSON + self.metadata = core.create_audio_metadata_from_wav(wav_metadata) return + # Fall back to FFmpeg decoder self._decoder = create_decoder(source=source, seek_mode="approximate") container_metadata = core.get_container_metadata(self._decoder) @@ -118,55 +120,6 @@ def __init__( num_channels=num_channels, ) - @staticmethod - def _decode_wav( - source: str | Path | io.RawIOBase | io.BufferedReader | bytes | Tensor, - stream_index: int | None = None, - sample_rate: int | None = None, - num_channels: int | None = None, - ) -> tuple[Tensor, str]: - """Decode WAV if valid and parameters match. - - Returns (samples, metadata_json). - Empty metadata_json means not a valid WAV file or parameters don't match. - """ - if isinstance(source, Tensor): - return core.validate_and_decode_wav_from_tensor( - source, - stream_index=stream_index, - sample_rate=sample_rate, - num_channels=num_channels, - ) - elif isinstance(source, bytes): - return core.validate_and_decode_wav_from_tensor( - torch.frombuffer(source, dtype=torch.uint8), - stream_index=stream_index, - sample_rate=sample_rate, - num_channels=num_channels, - ) - elif isinstance(source, (str, Path)): - path = str(source) - if path.startswith(("http://", "https://", "s3://")): - return torch.empty(0), "" - return core.validate_and_decode_wav_from_file( - path, - stream_index=stream_index, - sample_rate=sample_rate, - num_channels=num_channels, - ) - elif isinstance(source, (io.RawIOBase, io.BufferedReader)) or ( - hasattr(source, "read") and hasattr(source, "seek") - ): - data = source.read() - source.seek(0) - return core.validate_and_decode_wav_from_tensor( - torch.frombuffer(data, dtype=torch.uint8), - stream_index=stream_index, - sample_rate=sample_rate, - num_channels=num_channels, - ) - return torch.empty(0), "" - def get_all_samples(self) -> AudioSamples: """Returns all the audio samples from the source. @@ -176,9 +129,6 @@ def get_all_samples(self) -> AudioSamples: Returns: AudioSamples: The samples within the file. """ - # Use WAV fast path if available - if self._wav_samples is not None: - return self._wav_samples return self.get_samples_played_in_range() def get_samples_played_in_range( @@ -209,29 +159,25 @@ def get_samples_played_in_range( f"Invalid start seconds: {start_seconds}. It must be less than or equal to stop seconds ({stop_seconds})." ) - # Handle WAV fast path - if self._wav_samples is not None: - sample_rate = self._wav_samples.sample_rate - num_samples = self._wav_samples.data.shape[1] - - start_sample = round(start_seconds * sample_rate) - if stop_seconds is None: - stop_sample = num_samples - else: - stop_sample = round(stop_seconds * sample_rate) - - start_sample = max(0, min(start_sample, num_samples)) - stop_sample = max(0, min(stop_sample, num_samples)) + # Use native WAV decoder if available + if self._use_wav_decoder and self._wav_source is not None: + frames, pts_tensor = decode_wav( + self._wav_source, + start_seconds=start_seconds, + stop_seconds=stop_seconds, + ) + first_pts = pts_tensor.item() + sample_rate = self.metadata.sample_rate + assert sample_rate is not None # mypy - data = self._wav_samples.data[:, start_sample:stop_sample] - output_pts = start_sample / sample_rate return AudioSamples( - data=data, - pts_seconds=output_pts, - duration_seconds=data.shape[1] / sample_rate, + data=frames, + pts_seconds=first_pts, + duration_seconds=frames.shape[1] / sample_rate, sample_rate=sample_rate, ) + # FFmpeg path frames, first_pts = core.get_frames_by_pts_in_range_audio( self._decoder, start_seconds=start_seconds, diff --git a/src/torchcodec/decoders/_decoder_utils.py b/src/torchcodec/decoders/_decoder_utils.py index 6ca53a330..fc7097418 100644 --- a/src/torchcodec/decoders/_decoder_utils.py +++ b/src/torchcodec/decoders/_decoder_utils.py @@ -7,11 +7,13 @@ import contextvars import io +import json from collections.abc import Generator from contextlib import contextmanager from pathlib import Path +import torch from torch import Tensor from torchcodec import _core as core @@ -111,3 +113,103 @@ def set_cuda_backend(backend: str) -> Generator[None, None, None]: def _get_cuda_backend() -> str: return _CUDA_BACKEND.get() + + +def _is_uncompressed_wav( + source, + stream_index: int | None = None, + sample_rate: int | None = None, + num_channels: int | None = None, +) -> dict | None: + """Check if source is an uncompressed WAV file compatible with native decoder. + + Returns metadata dict if compatible, None otherwise (not WAV, unsupported format, + or requested parameters don't match the source). + """ + try: + if isinstance(source, str): + metadata_json = core.get_wav_metadata_from_file( + source, stream_index, sample_rate, num_channels + ) + elif isinstance(source, Path): + metadata_json = core.get_wav_metadata_from_file( + str(source), stream_index, sample_rate, num_channels + ) + elif isinstance(source, bytes): + buffer = torch.frombuffer(source, dtype=torch.uint8) + metadata_json = core.get_wav_metadata_from_tensor( + buffer, stream_index, sample_rate, num_channels + ) + elif isinstance(source, Tensor): + metadata_json = core.get_wav_metadata_from_tensor( + source, stream_index, sample_rate, num_channels + ) + else: + # File-like objects - read all data to get full metadata + current_pos = source.seek(0, io.SEEK_CUR) + source.seek(0) + data = source.read() + source.seek(current_pos) + if len(data) < 12: + return None + buffer = torch.frombuffer(data, dtype=torch.uint8) + metadata_json = core.get_wav_metadata_from_tensor( + buffer, stream_index, sample_rate, num_channels + ) + + if not metadata_json: + return None + + return json.loads(metadata_json) + except Exception: + # In the case of an error, fall back to FFmpeg decoder + return None + + +def decode_wav( + source: str | Path | bytes | Tensor | io.RawIOBase | io.BufferedReader, + start_seconds: float = 0.0, + stop_seconds: float | None = None, +) -> tuple[Tensor, Tensor]: + """Decode audio samples from a WAV file using the native decoder. + + Args: + source: The WAV audio source - can be a file path (str or Path), + raw bytes, a uint8 tensor containing WAV data, or a file-like object. + start_seconds: Start time in seconds for the audio range. + stop_seconds: Stop time in seconds (exclusive). None means decode to end. + + Returns: + A tuple of (samples, pts_seconds) where: + - samples: Float32 tensor of shape (num_channels, num_samples) normalized to [-1, 1] + - pts_seconds: Float64 tensor containing the PTS of the first sample + + Raises: + RuntimeError: If the WAV format is not supported (compressed formats). + """ + import warnings + + if isinstance(source, str): + return core.decode_wav_from_file(source, start_seconds, stop_seconds) + elif isinstance(source, Path): + return core.decode_wav_from_file(str(source), start_seconds, stop_seconds) + elif isinstance(source, bytes): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + buffer = torch.frombuffer(source, dtype=torch.uint8) + return core.decode_wav_from_tensor(buffer, start_seconds, stop_seconds) + elif isinstance(source, Tensor): + return core.decode_wav_from_tensor(source, start_seconds, stop_seconds) + elif hasattr(source, "read") and hasattr(source, "seek"): + # File-like object - read all data and pass to tensor version + source.seek(0) + data = source.read() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + buffer = torch.frombuffer(data, dtype=torch.uint8) + return core.decode_wav_from_tensor(buffer, start_seconds, stop_seconds) + else: + raise TypeError( + f"Unsupported source type: {type(source)}. " + "Expected str, Path, bytes, Tensor, or file-like object." + ) diff --git a/test/test_decoders.py b/test/test_decoders.py index 2abc6c26b..755033600 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -2318,35 +2318,15 @@ def test_num_channels_errors(self, asset): # which causes SwrContext to fail to initialize. decoder.get_all_samples() - # WAV fast path tests - @pytest.mark.parametrize("asset", (SINE_MONO_S16, SINE_MONO_S32)) - def test_wav_fast_path_from_bytes(self, asset): - """Test that WAV files use the fast path when loaded from bytes.""" - with open(asset.path, "rb") as f: - wav_bytes = f.read() - - decoder = AudioDecoder(wav_bytes) - samples = decoder.get_all_samples() - assert samples.data.dtype == torch.float32 - assert samples.data.shape[0] == asset.num_channels - assert samples.sample_rate == asset.sample_rate - - @pytest.mark.parametrize("asset", (SINE_MONO_S16, SINE_MONO_S32)) - def test_wav_fast_path_range_decoding(self, asset): - """Test that range decoding works correctly with fast path.""" - with open(asset.path, "rb") as f: - wav_bytes = f.read() - - decoder = AudioDecoder(wav_bytes) - - # Decode a range - start_seconds = 1.0 - stop_seconds = 2.0 - samples = decoder.get_samples_played_in_range( - start_seconds=start_seconds, stop_seconds=stop_seconds + @pytest.mark.parametrize( + "asset", (SINE_MONO_S16, SINE_MONO_S32, SINE_16_CHANNEL_S16) + ) + def test_native_matches_ffmpeg_full(self, asset): + native_decoder = AudioDecoder(asset.path, use_wav_decoder=True) + native_samples = native_decoder.get_all_samples() + ffmpeg_decoder = AudioDecoder(asset.path, use_wav_decoder=False) + ffmpeg_samples = ffmpeg_decoder.get_all_samples() + torch.testing.assert_close( + native_samples.data, ffmpeg_samples.data, atol=0, rtol=0 ) - - expected_num_samples = round((stop_seconds - start_seconds) * asset.sample_rate) - assert samples.data.shape[1] == expected_num_samples - assert samples.pts_seconds == start_seconds - assert samples.duration_seconds == pytest.approx(stop_seconds - start_seconds) + assert native_samples.sample_rate == ffmpeg_samples.sample_rate From 055dfc0c4a7f4d82ab39f1459977aa502e31d82d Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Mon, 9 Feb 2026 22:27:42 -0500 Subject: [PATCH 6/7] sample_format metadata map --- src/torchcodec/_core/custom_ops.cpp | 42 ++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index f11433be5..04f171324 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -1164,7 +1164,47 @@ std::string buildWavMetadataJson(WavDecoder& decoder) { map["bit_rate"] = "null"; map["codec"] = "\"pcm\""; map["stream_index"] = "0"; - map["sample_format"] = "null"; + // Derive sample_format string matching FFmpeg's av_get_sample_fmt_name(). + // https://ffmpeg.org/doxygen/6.1/group__lavu__sampfmts.html#gaf9a51ca15301871723577c730b5865c5 + // For WAV_FORMAT_EXTENSIBLE, the actual format is in subFormat. + uint16_t effectiveFormat = header.audioFormat; + if (effectiveFormat == WAV_FORMAT_EXTENSIBLE) { + effectiveFormat = header.subFormat; + } + if (effectiveFormat == WAV_FORMAT_PCM) { + switch (header.bitsPerSample) { + case 8: + map["sample_format"] = "\"u8\""; + break; + case 16: + map["sample_format"] = "\"s16\""; + break; + case 24: + // FFmpeg has no s24; it decodes 24-bit PCM as s32. + map["sample_format"] = "\"s32\""; + break; + case 32: + map["sample_format"] = "\"s32\""; + break; + default: + map["sample_format"] = "null"; + break; + } + } else if (effectiveFormat == WAV_FORMAT_IEEE_FLOAT) { + switch (header.bitsPerSample) { + case 32: + map["sample_format"] = "\"flt\""; + break; + case 64: + map["sample_format"] = "\"dbl\""; + break; + default: + map["sample_format"] = "null"; + break; + } + } else { + map["sample_format"] = "null"; + } return mapToJson(map); } From c01d37067c450c2644c4d7cca4806aaf9e92907d Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Tue, 10 Feb 2026 01:12:31 -0500 Subject: [PATCH 7/7] add comment to convertSamplesToFloat --- src/torchcodec/_core/WavDecoder.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/_core/WavDecoder.cpp b/src/torchcodec/_core/WavDecoder.cpp index 29347fb00..7dc559b97 100644 --- a/src/torchcodec/_core/WavDecoder.cpp +++ b/src/torchcodec/_core/WavDecoder.cpp @@ -279,9 +279,17 @@ torch::Tensor WavDecoder::convertSamplesToFloat( effectiveFormat = header_.subFormat; } + // WAV stores samples interleaved: [L R L R ...]. These loops convert to + // float and deinterleave into channel-first layout: (numChannels, numSamples) + // in a single pass to avoid intermediate allocations. + // + // Example with 2 channels (L, R) and 3 samples: + // Input (interleaved): [L0 R0 L1 R1 L2 R2] + // ^read: s * numChannels + c = 0,1,2,3,4,5 + // Output (channel-first): [L0 L1 L2 R0 R1 R2] + // ^write: c * numSamples + s = 0,1,2,3,4,5 if (effectiveFormat == WAV_FORMAT_IEEE_FLOAT) { if (header_.bitsPerSample == 32) { - // 32-bit float - just copy and deinterleave const float* floatSrc = reinterpret_cast(src); for (int64_t s = 0; s < numSamples; ++s) { for (int64_t c = 0; c < numChannels; ++c) { @@ -289,7 +297,6 @@ torch::Tensor WavDecoder::convertSamplesToFloat( } } } else if (header_.bitsPerSample == 64) { - // 64-bit float - convert to 32-bit and deinterleave const double* doubleSrc = reinterpret_cast(src); for (int64_t s = 0; s < numSamples; ++s) { for (int64_t c = 0; c < numChannels; ++c) { @@ -299,7 +306,6 @@ torch::Tensor WavDecoder::convertSamplesToFloat( } } } else { - // PCM format - convert to normalized float for (int64_t s = 0; s < numSamples; ++s) { for (int64_t c = 0; c < numChannels; ++c) { const uint8_t* samplePtr = src + (s * numChannels + c) * bytesPerSample;