diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt index 67d5bb5e2..95763ebfa 100644 --- a/src/torchcodec/_core/CMakeLists.txt +++ b/src/torchcodec/_core/CMakeLists.txt @@ -136,6 +136,7 @@ function(make_torchcodec_libraries ValidationUtils.cpp Transform.cpp Metadata.cpp + WavDecoder.cpp ) if(ENABLE_CUDA) diff --git a/src/torchcodec/_core/WavDecoder.cpp b/src/torchcodec/_core/WavDecoder.cpp new file mode 100644 index 000000000..7dc559b97 --- /dev/null +++ b/src/torchcodec/_core/WavDecoder.cpp @@ -0,0 +1,412 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "WavDecoder.h" + +#include +#include +#include +#include + +namespace facebook::torchcodec { +namespace { + +template +T readLittleEndian(const uint8_t* data) { + T value; + std::memcpy(&value, data, sizeof(T)); + return value; +} + +bool checkFourCC(const uint8_t* data, const char* expected) { + return std::memcmp(data, expected, 4) == 0; +} + +} // namespace + +// WavFileReader implementation +WavFileReader::WavFileReader(const std::string& path) : file_(nullptr) { + file_ = std::fopen(path.c_str(), "rb"); + if (!file_) { + throw std::runtime_error("Failed to open WAV file: " + path); + } +} + +WavFileReader::~WavFileReader() { + if (file_) { + std::fclose(file_); + } +} + +int64_t WavFileReader::read(void* buffer, int64_t size) { + if (!file_) { + return -1; + } + size_t bytesRead = std::fread(buffer, 1, static_cast(size), file_); + return static_cast(bytesRead); +} + +int64_t WavFileReader::seek(int64_t position) { + if (!file_) { + return -1; + } + if (std::fseek(file_, static_cast(position), SEEK_SET) != 0) { + return -1; + } + return position; +} + +// WavTensorReader implementation +WavTensorReader::WavTensorReader(const torch::Tensor& data) + : data_(data), currentPos_(0) { + 
TORCH_CHECK(data.is_contiguous(), "WAV data tensor must be contiguous"); + TORCH_CHECK( + data.scalar_type() == torch::kUInt8, "WAV data tensor must be uint8"); +} + +int64_t WavTensorReader::read(void* buffer, int64_t size) { + int64_t available = data_.numel() - currentPos_; + int64_t toRead = std::min(size, available); + if (toRead <= 0) { + return 0; + } + + const uint8_t* src = data_.data_ptr() + currentPos_; + std::memcpy(buffer, src, static_cast(toRead)); + currentPos_ += toRead; + return toRead; +} + +int64_t WavTensorReader::seek(int64_t position) { + if (position < 0 || position > data_.numel()) { + return -1; + } + currentPos_ = position; + return currentPos_; +} + +// WavDecoder implementation +WavDecoder::WavDecoder(std::unique_ptr reader) + : reader_(std::move(reader)) { + parseHeader(); +} + +bool WavDecoder::isWavFile(const void* data, size_t size) { + if (size < 12) { + return false; + } + const uint8_t* bytes = static_cast(data); + // Check for RIFF....WAVE + return checkFourCC(bytes, "RIFF") && checkFourCC(bytes + 8, "WAVE"); +} + +void WavDecoder::parseHeader() { + // Read enough for header parsing (typical WAV headers are < 100 bytes) + // TODO: source? 
+ constexpr int64_t headerBufferSize = 256; + std::vector buffer(headerBufferSize); + + reader_->seek(0); + int64_t bytesRead = reader_->read(buffer.data(), headerBufferSize); + if (bytesRead < 44) { + throw std::runtime_error("WAV data too small to contain valid header"); + } + + const uint8_t* data = buffer.data(); + + // Verify RIFF header + if (!checkFourCC(data, "RIFF")) { + throw std::runtime_error("Missing RIFF header"); + } + + // Verify WAVE format + if (!checkFourCC(data + 8, "WAVE")) { + throw std::runtime_error("Missing WAVE format identifier"); + } + + // Find and parse fmt chunk + int64_t offset = 12; + bool foundFmt = false; + + while (offset + 8 <= bytesRead) { + if (checkFourCC(data + offset, "fmt ")) { + uint32_t fmtSize = readLittleEndian(data + offset + 4); + + if (offset + 8 + fmtSize > bytesRead) { + throw std::runtime_error("fmt chunk extends beyond buffer"); + } + + if (fmtSize < 16) { + throw std::runtime_error("fmt chunk too small"); + } + + const uint8_t* fmtData = data + offset + 8; + // TODO: explain https://en.wikipedia.org/wiki/WAV#WAV_file_header + header_.audioFormat = readLittleEndian(fmtData); + header_.numChannels = readLittleEndian(fmtData + 2); + header_.sampleRate = readLittleEndian(fmtData + 4); + header_.byteRate = readLittleEndian(fmtData + 8); + header_.blockAlign = readLittleEndian(fmtData + 12); + header_.bitsPerSample = readLittleEndian(fmtData + 14); + + // Parse extended format fields for WAVE_FORMAT_EXTENSIBLE + if (header_.audioFormat == WAV_FORMAT_EXTENSIBLE) { + // Extended format requires at least 40 bytes total (16 base + 2 cbSize + // + 22 extension) + if (fmtSize < 40) { + throw std::runtime_error( + "WAVE_FORMAT_EXTENSIBLE fmt chunk too small"); + } + + header_.validBitsPerSample = readLittleEndian(fmtData + 18); + header_.channelMask = readLittleEndian(fmtData + 20); + // SubFormat GUID starts at offset 24, first 2 bytes are the format code + header_.subFormat = readLittleEndian(fmtData + 24); + } + + 
foundFmt = true; + // Chunk data is word-aligned in RIFF: an odd-sized chunk is followed by + // one pad byte that its size field does not count. + offset += 8 + static_cast<int64_t>(fmtSize) + (fmtSize & 1); + break; + } + // Skip unknown chunks. The advance is computed in 64 bits: in 32-bit + // unsigned arithmetic `8 + chunkSize` wraps to 0 for a size of 0xFFFFFFF8, + // which would leave offset unchanged and loop forever. + uint32_t chunkSize = readLittleEndian(data + offset + 4); + offset += 8 + static_cast<int64_t>(chunkSize) + (chunkSize & 1); + } + + if (!foundFmt) { + throw std::runtime_error("fmt chunk not found"); + } + + while (offset + 8 <= bytesRead) { + if (checkFourCC(data + offset, "data")) { + // Parse data chunk + header_.dataSize = readLittleEndian(data + offset + 4); + header_.dataOffset = offset + 8; + return; + } + + // Skip this chunk (64-bit advance plus pad byte, as above) + uint32_t chunkSize = readLittleEndian(data + offset + 4); + offset += 8 + static_cast<int64_t>(chunkSize) + (chunkSize & 1); + } + + throw std::runtime_error("data chunk not found"); +} + +// Returns true when the parsed header describes audio this decoder can +// convert: PCM (8/16/24/32-bit) or IEEE float (32/64-bit), with non-zero +// channel count, sample rate and block alignment, and tightly packed frames. +bool WavDecoder::isSupported() const { + // Determine effective format (subFormat for extensible, audioFormat + // otherwise) + uint16_t effectiveFormat = header_.audioFormat; + if (header_.audioFormat == WAV_FORMAT_EXTENSIBLE) { + effectiveFormat = header_.subFormat; + } + + // Support PCM and IEEE float formats + if (effectiveFormat != WAV_FORMAT_PCM && + effectiveFormat != WAV_FORMAT_IEEE_FLOAT) { + return false; + } + + // Validate bits per sample + if (effectiveFormat == WAV_FORMAT_PCM) { + if (header_.bitsPerSample != 8 && header_.bitsPerSample != 16 && + header_.bitsPerSample != 24 && header_.bitsPerSample != 32) { + return false; + } + } else if (effectiveFormat == WAV_FORMAT_IEEE_FLOAT) { + if (header_.bitsPerSample != 32 && header_.bitsPerSample != 64) { + return false; + } + } + + // Frames must be tightly packed. getSamplesInRange sizes its read buffer + // from blockAlign while convertSamplesToFloat indexes it by + // numChannels * (bitsPerSample / 8); a smaller blockAlign would make the + // conversion read past the end of that buffer. + if (header_.blockAlign != + header_.numChannels * (header_.bitsPerSample / 8)) { + return false; + } + + return header_.numChannels > 0 && header_.sampleRate > 0 && + header_.blockAlign > 0; +} + +bool WavDecoder::isCompatible( + std::optional stream_index, + std::optional sample_rate, + std::optional num_channels) const { + // WAV files only have one stream at index 0 + if (stream_index.has_value() && stream_index.value() != 0) { + return false; + } + // Check sample rate matches if specified (no resampling support) + if (sample_rate.has_value() && + sample_rate.value() != static_cast(header_.sampleRate)) { + return false; + } + // Check channel count matches 
if specified (no remixing support) + if (num_channels.has_value() && + num_channels.value() != static_cast(header_.numChannels)) { + return false; + } + return true; +} + +const WavHeader& WavDecoder::getHeader() const { + return header_; +} + +double WavDecoder::getDurationSeconds() const { + if (header_.blockAlign == 0 || header_.sampleRate == 0) { + return 0.0; + } + int64_t numSamples = + static_cast(header_.dataSize) / header_.blockAlign; + return static_cast(numSamples) / header_.sampleRate; +} + +torch::Tensor WavDecoder::convertSamplesToFloat( + const void* rawData, + int64_t numSamples, + int64_t numChannels) { + // Output is (numChannels, numSamples) float32 + torch::Tensor output = + torch::empty({numChannels, numSamples}, torch::kFloat32); + float* outPtr = output.data_ptr(); + + const uint8_t* src = static_cast(rawData); + int bytesPerSample = header_.bitsPerSample / 8; + + // Determine effective format (subFormat for extensible, audioFormat + // otherwise) + uint16_t effectiveFormat = header_.audioFormat; + if (header_.audioFormat == WAV_FORMAT_EXTENSIBLE) { + effectiveFormat = header_.subFormat; + } + + // WAV stores samples interleaved: [L R L R ...]. These loops convert to + // float and deinterleave into channel-first layout: (numChannels, numSamples) + // in a single pass to avoid intermediate allocations. 
+ // + // Example with 2 channels (L, R) and 3 samples: + // Input (interleaved): [L0 R0 L1 R1 L2 R2] + // ^read: s * numChannels + c = 0,1,2,3,4,5 + // Output (channel-first): [L0 L1 L2 R0 R1 R2] + // ^write: c * numSamples + s = 0,1,2,3,4,5 + if (effectiveFormat == WAV_FORMAT_IEEE_FLOAT) { + if (header_.bitsPerSample == 32) { + const float* floatSrc = reinterpret_cast(src); + for (int64_t s = 0; s < numSamples; ++s) { + for (int64_t c = 0; c < numChannels; ++c) { + outPtr[c * numSamples + s] = floatSrc[s * numChannels + c]; + } + } + } else if (header_.bitsPerSample == 64) { + const double* doubleSrc = reinterpret_cast(src); + for (int64_t s = 0; s < numSamples; ++s) { + for (int64_t c = 0; c < numChannels; ++c) { + outPtr[c * numSamples + s] = + static_cast(doubleSrc[s * numChannels + c]); + } + } + } + } else { + for (int64_t s = 0; s < numSamples; ++s) { + for (int64_t c = 0; c < numChannels; ++c) { + const uint8_t* samplePtr = src + (s * numChannels + c) * bytesPerSample; + float value = 0.0f; + + switch (header_.bitsPerSample) { + case 8: { + // 8-bit PCM is unsigned (0-255, center at 128) + uint8_t sample = *samplePtr; + value = (static_cast(sample) - 128.0f) / 128.0f; + break; + } + case 16: { + // 16-bit PCM is signed + int16_t sample = readLittleEndian(samplePtr); + value = static_cast(sample) / 32768.0f; + break; + } + case 24: { + // 24-bit PCM is signed, stored in 3 bytes little-endian + int32_t sample = static_cast(samplePtr[0]) | + (static_cast(samplePtr[1]) << 8) | + (static_cast(samplePtr[2]) << 16); + // Sign extend from 24 to 32 bits + if (sample & 0x800000) { + sample |= 0xFF000000; + } + value = static_cast(sample) / 8388608.0f; + break; + } + case 32: { + // 32-bit PCM is signed + int32_t sample = readLittleEndian(samplePtr); + value = static_cast(sample) / 2147483648.0f; + break; + } + } + outPtr[c * numSamples + s] = value; + } + } + } + + return output; +} + +std::tuple WavDecoder::getSamplesInRange( + double startSeconds, + std::optional 
stopSeconds) { + TORCH_CHECK(startSeconds >= 0, "start_seconds must be non-negative"); + if (stopSeconds.has_value()) { + TORCH_CHECK( + stopSeconds.value() >= startSeconds, + "stop_seconds must be >= start_seconds"); + } + + double duration = getDurationSeconds(); + if (startSeconds >= duration) { + // Return empty tensor + return std::make_tuple( + torch::empty({header_.numChannels, 0}, torch::kFloat32), startSeconds); + } + + double actualStop = stopSeconds.value_or(duration); + actualStop = std::min(actualStop, duration); + + // Calculate sample range + int64_t startSample = static_cast(startSeconds * header_.sampleRate); + int64_t stopSample = static_cast(actualStop * header_.sampleRate); + // duration is totalFrames / sampleRate; multiplying it back by sampleRate + // can round just below totalFrames (e.g. 1.0 / 49 * 49 < 1.0), and the + // truncating cast above would then drop the final frame. Use the exact + // frame count whenever the range runs to the end of the data. blockAlign + // is non-zero here because duration > startSeconds >= 0. + int64_t totalFrames = header_.dataSize / header_.blockAlign; + if (actualStop >= duration) { + stopSample = totalFrames; + } + int64_t numSamples = stopSample - startSample; + + if (numSamples <= 0) { + return std::make_tuple( + torch::empty({header_.numChannels, 0}, torch::kFloat32), startSeconds); + } + + // Calculate byte positions + int64_t byteOffset = startSample * header_.blockAlign; + int64_t bytesToRead = numSamples * header_.blockAlign; + + // Seek to position and read. A failed seek does not move the reader to the + // requested position, so return no samples rather than decoding whatever + // bytes the reader currently points at. + if (reader_->seek(static_cast(header_.dataOffset) + byteOffset) < 0) { + return std::make_tuple( + torch::empty({header_.numChannels, 0}, torch::kFloat32), startSeconds); + } + std::vector rawData(bytesToRead); + int64_t bytesRead = reader_->read(rawData.data(), bytesToRead); + + if (bytesRead < bytesToRead) { + // Adjust numSamples if we couldn't read everything + numSamples = bytesRead / header_.blockAlign; + if (numSamples <= 0) { + return std::make_tuple( + torch::empty({header_.numChannels, 0}, torch::kFloat32), + startSeconds); + } + } + + torch::Tensor samples = + convertSamplesToFloat(rawData.data(), numSamples, header_.numChannels); + + // Calculate actual PTS + double ptsSeconds = static_cast(startSample) / header_.sampleRate; + + return std::make_tuple(samples, ptsSeconds); +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/WavDecoder.h b/src/torchcodec/_core/WavDecoder.h new file mode 100644 index 000000000..58c0d6932 --- /dev/null +++ b/src/torchcodec/_core/WavDecoder.h @@ -0,0 +1,120 @@ +// 
Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include +#include +#include + +namespace facebook::torchcodec { + +// WAV format constants +constexpr uint16_t WAV_FORMAT_PCM = 1; +constexpr uint16_t WAV_FORMAT_IEEE_FLOAT = 3; +constexpr uint16_t WAV_FORMAT_EXTENSIBLE = 0xFFFE; + +// Parsed WAV header information +struct WavHeader { + uint16_t audioFormat = 0; // 1 = PCM, 3 = IEEE float + uint16_t numChannels = 0; + uint32_t sampleRate = 0; + uint32_t byteRate = 0; + uint16_t blockAlign = 0; + uint16_t bitsPerSample = 0; + uint64_t dataOffset = 0; // Offset to start of audio data + uint64_t dataSize = 0; // Size of audio data in bytes + + // Extended format fields (WAVE_FORMAT_EXTENSIBLE) + uint16_t validBitsPerSample = 0; + uint32_t channelMask = 0; + uint16_t subFormat = 0; // Extracted from SubFormat GUID (first 2 bytes) +}; + +// Abstract base class for reading WAV data from different sources +class WavReader { + public: + virtual ~WavReader() = default; + + // Read up to `size` bytes into `buffer`. Returns bytes actually read. + virtual int64_t read(void* buffer, int64_t size) = 0; + + // Seek to absolute position. Returns new position or -1 on error. 
+ virtual int64_t seek(int64_t position) = 0; +}; + +// WavReader implementation for file paths +class WavFileReader : public WavReader { + public: + explicit WavFileReader(const std::string& path); + ~WavFileReader() override; + + // file_ is closed by the destructor, so a copy would close the same + // handle twice. The reader is therefore non-copyable. + WavFileReader(const WavFileReader&) = delete; + WavFileReader& operator=(const WavFileReader&) = delete; + + int64_t read(void* buffer, int64_t size) override; + int64_t seek(int64_t position) override; + + private: + std::FILE* file_; +}; + +// WavReader implementation for tensor/bytes data +class WavTensorReader : public WavReader { + public: + explicit WavTensorReader(const torch::Tensor& data); + + int64_t read(void* buffer, int64_t size) override; + int64_t seek(int64_t position) override; + + private: + torch::Tensor data_; + int64_t currentPos_; +}; + +// Main WAV decoder class +class WavDecoder { + public: + explicit WavDecoder(std::unique_ptr reader); + + // Check if this is a supported uncompressed WAV file + // Returns true for PCM and IEEE float formats + bool isSupported() const; + + // Check if the requested parameters are compatible with this WAV file. + // Returns false if resampling or channel mixing would be required. + // WAV files only have one stream, so stream_index must be 0 or nullopt. 
+ bool isCompatible( + std::optional stream_index, + std::optional sample_rate, + std::optional num_channels) const; + + // Get the parsed header + const WavHeader& getHeader() const; + + // Get samples in a time range, returns (samples, pts_seconds) + // samples is shape (num_channels, num_samples) float32 normalized to [-1, 1] + std::tuple getSamplesInRange( + double startSeconds, + std::optional stopSeconds); + + // Get total duration in seconds + double getDurationSeconds() const; + + // Static helper to check if data looks like a WAV file + static bool isWavFile(const void* data, size_t size); + + private: + void parseHeader(); + torch::Tensor convertSamplesToFloat( + const void* rawData, + int64_t numSamples, + int64_t numChannels); + + std::unique_ptr reader_; + WavHeader header_; +}; + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/__init__.py b/src/torchcodec/_core/__init__.py index add9efa90..8c0f357fa 100644 --- a/src/torchcodec/_core/__init__.py +++ b/src/torchcodec/_core/__init__.py @@ -8,12 +8,14 @@ from ._metadata import ( AudioStreamMetadata, ContainerMetadata, + create_audio_metadata_from_wav, get_container_metadata, get_container_metadata_from_header, VideoStreamMetadata, ) from .ops import ( _add_video_stream, + _decode_wav_from_file_like, _get_backend_details, _get_key_frame_indices, _test_frame_pts_equality, @@ -24,6 +26,8 @@ create_from_file, create_from_file_like, create_from_tensor, + decode_wav_from_file, + decode_wav_from_tensor, encode_audio_to_file, encode_audio_to_file_like, encode_audio_to_tensor, @@ -41,6 +45,8 @@ get_frames_in_range, get_json_metadata, get_next_frame, + get_wav_metadata_from_file, + get_wav_metadata_from_tensor, scan_all_streams_to_update_metadata, seek_to_pts, ) diff --git a/src/torchcodec/_core/_metadata.py b/src/torchcodec/_core/_metadata.py index 482d0e1cb..c0200caac 100644 --- a/src/torchcodec/_core/_metadata.py +++ b/src/torchcodec/_core/_metadata.py @@ -140,6 +140,40 @@ class 
AudioStreamMetadata(StreamMetadata): def __repr__(self): return super().__repr__() + @classmethod + def from_json(cls, json: dict, stream_index: int = 0) -> "AudioStreamMetadata": + """Create AudioStreamMetadata from a camelCase-keyed JSON dictionary.""" + return cls( + duration_seconds_from_header=json.get("durationSecondsFromHeader"), + duration_seconds=json.get("durationSeconds"), + bit_rate=json.get("bitRate"), + begin_stream_seconds_from_header=json.get("beginStreamSecondsFromHeader"), + begin_stream_seconds=json.get("beginStreamSeconds"), + codec=json.get("codec"), + stream_index=stream_index, + sample_rate=json.get("sampleRate"), + num_channels=json.get("numChannels"), + sample_format=json.get("sampleFormat"), + ) + + +def create_audio_metadata_from_wav(wav_json: dict) -> AudioStreamMetadata: + """Create AudioStreamMetadata from WAV metadata dict.""" + return AudioStreamMetadata( + duration_seconds_from_header=wav_json.get("duration_seconds"), + begin_stream_seconds_from_header=wav_json.get( + "begin_stream_seconds_from_header", 0.0 + ), + bit_rate=wav_json.get("bit_rate"), + codec=wav_json.get("codec", "pcm"), + stream_index=wav_json.get("stream_index", 0), + duration_seconds=wav_json.get("duration_seconds"), + begin_stream_seconds=wav_json.get("begin_stream_seconds", 0.0), + sample_rate=wav_json.get("sample_rate"), + num_channels=wav_json.get("num_channels"), + sample_format=wav_json.get("sample_format"), + ) + + @dataclass class ContainerMetadata: diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index e35f62388..04f171324 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -14,6 +14,7 @@ #include "Encoder.h" #include "SingleStreamDecoder.h" #include "ValidationUtils.h" +#include "WavDecoder.h" #include "c10/core/SymIntArrayRef.h" #include "c10/util/Exception.h" @@ -79,6 +80,16 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "_test_frame_pts_equality(Tensor(a!) 
decoder, *, int frame_index, float pts_seconds_to_test) -> bool"); m.def("scan_all_streams_to_update_metadata(Tensor(a!) decoder) -> ()"); + m.def( + "decode_wav_from_file(str filename, float start_seconds=0.0, float? stop_seconds=None) -> (Tensor, Tensor)"); + m.def( + "decode_wav_from_tensor(Tensor data, float start_seconds=0.0, float? stop_seconds=None) -> (Tensor, Tensor)"); + m.def( + "_decode_wav_from_file_like(int ctx, float start_seconds=0.0, float? stop_seconds=None) -> (Tensor, Tensor)"); + m.def( + "get_wav_metadata_from_file(str filename, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> str"); + m.def( + "get_wav_metadata_from_tensor(Tensor data, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> str"); } namespace { @@ -1052,6 +1063,184 @@ void scan_all_streams_to_update_metadata(torch::Tensor& decoder) { videoDecoder->scanFileAndUpdateMetadataAndIndex(); } +// The elements of this tuple are: +// 1. The audio samples as a float32 tensor of shape (num_channels, +// num_samples) +// 2. A single float value for the pts of the first sample, in seconds. +using OpsWavOutput = std::tuple; + +OpsWavOutput decode_wav_from_file( + std::string_view filename, + double start_seconds, + std::optional stop_seconds) { + auto reader = std::make_unique(std::string(filename)); + WavDecoder decoder(std::move(reader)); + + TORCH_CHECK( + decoder.isSupported(), + "Unsupported WAV format. Only PCM and IEEE float formats are supported."); + + auto [samples, pts] = decoder.getSamplesInRange(start_seconds, stop_seconds); + return std::make_tuple( + samples, torch::tensor(pts, torch::dtype(torch::kFloat64))); +} + +OpsWavOutput decode_wav_from_tensor( + const torch::Tensor& data, + double start_seconds, + std::optional stop_seconds) { + auto reader = std::make_unique(data); + WavDecoder decoder(std::move(reader)); + + TORCH_CHECK( + decoder.isSupported(), + "Unsupported WAV format. 
Only PCM and IEEE float formats are supported."); + + auto [samples, pts] = decoder.getSamplesInRange(start_seconds, stop_seconds); + return std::make_tuple( + samples, torch::tensor(pts, torch::dtype(torch::kFloat64))); +} + +// Decodes WAV samples from a file-like object: reads the whole object into a +// uint8 tensor through its AVIO callbacks, then decodes that tensor. +OpsWavOutput _decode_wav_from_file_like( + int64_t file_like_context, + double start_seconds, + std::optional stop_seconds) { + // Get the file-like object and read all data into a tensor + auto fileLikeContext = + reinterpret_cast(file_like_context); + TORCH_CHECK( + fileLikeContext != nullptr, "file_like_context must be a valid pointer"); + + // Read all data from file-like object + // First, get the size by seeking to the end. The AVIO callbacks take the + // context's opaque pointer (as the read_packet call below passes), not the + // AVIOContext itself. Calling the seek callback directly does not update + // AVIOContext::pos, so the size is taken from the callback's return value. + auto* avioContext = fileLikeContext->getAVIOContext(); + int64_t totalSize = avioContext->seek(avioContext->opaque, 0, SEEK_END); + TORCH_CHECK( + totalSize >= 0, "Failed to determine the size of the file-like object"); + TORCH_CHECK( + avioContext->seek(avioContext->opaque, 0, SEEK_SET) >= 0, + "Failed to seek to the start of the file-like object"); + + std::vector buffer(totalSize); + int64_t bytesRead = 0; + while (bytesRead < totalSize) { + int ret = fileLikeContext->getAVIOContext()->read_packet( + fileLikeContext->getAVIOContext()->opaque, + buffer.data() + bytesRead, + totalSize - bytesRead); + if (ret <= 0) { + break; + } + bytesRead += ret; + } + + // Create tensor from buffer + torch::Tensor tensorData = + torch::from_blob( + buffer.data(), {static_cast(buffer.size())}, torch::kUInt8) + .clone(); + + auto reader = std::make_unique(tensorData); + WavDecoder decoder(std::move(reader)); + + TORCH_CHECK( + decoder.isSupported(), + "Unsupported WAV format. 
Only PCM and IEEE float formats are supported."); + + auto [samples, pts] = decoder.getSamplesInRange(start_seconds, stop_seconds); + return std::make_tuple( + samples, torch::tensor(pts, torch::dtype(torch::kFloat64))); +} + +std::string buildWavMetadataJson(WavDecoder& decoder) { + const WavHeader& header = decoder.getHeader(); + std::map map; + // Fields matching AudioStreamMetadata + map["sample_rate"] = std::to_string(header.sampleRate); + map["num_channels"] = std::to_string(header.numChannels); + map["duration_seconds"] = fmt::to_string(decoder.getDurationSeconds()); + map["duration_seconds_from_header"] = + fmt::to_string(decoder.getDurationSeconds()); + map["begin_stream_seconds"] = "0.0"; + map["begin_stream_seconds_from_header"] = "0.0"; + map["bit_rate"] = "null"; + map["codec"] = "\"pcm\""; + map["stream_index"] = "0"; + // Derive sample_format string matching FFmpeg's av_get_sample_fmt_name(). + // https://ffmpeg.org/doxygen/6.1/group__lavu__sampfmts.html#gaf9a51ca15301871723577c730b5865c5 + // For WAV_FORMAT_EXTENSIBLE, the actual format is in subFormat. + uint16_t effectiveFormat = header.audioFormat; + if (effectiveFormat == WAV_FORMAT_EXTENSIBLE) { + effectiveFormat = header.subFormat; + } + if (effectiveFormat == WAV_FORMAT_PCM) { + switch (header.bitsPerSample) { + case 8: + map["sample_format"] = "\"u8\""; + break; + case 16: + map["sample_format"] = "\"s16\""; + break; + case 24: + // FFmpeg has no s24; it decodes 24-bit PCM as s32. 
+ map["sample_format"] = "\"s32\""; + break; + case 32: + map["sample_format"] = "\"s32\""; + break; + default: + map["sample_format"] = "null"; + break; + } + } else if (effectiveFormat == WAV_FORMAT_IEEE_FLOAT) { + switch (header.bitsPerSample) { + case 32: + map["sample_format"] = "\"flt\""; + break; + case 64: + map["sample_format"] = "\"dbl\""; + break; + default: + map["sample_format"] = "null"; + break; + } + } else { + map["sample_format"] = "null"; + } + + return mapToJson(map); +} + +std::string get_wav_metadata_from_file( + std::string_view filename, + std::optional stream_index, + std::optional sample_rate, + std::optional num_channels) { + auto reader = std::make_unique(std::string(filename)); + WavDecoder decoder(std::move(reader)); + + if (!decoder.isSupported() || + !decoder.isCompatible(stream_index, sample_rate, num_channels)) { + return ""; + } + + return buildWavMetadataJson(decoder); +} + +std::string get_wav_metadata_from_tensor( + const torch::Tensor& data, + std::optional stream_index, + std::optional sample_rate, + std::optional num_channels) { + auto reader = std::make_unique(data); + WavDecoder decoder(std::move(reader)); + + if (!decoder.isSupported() || + !decoder.isCompatible(stream_index, sample_rate, num_channels)) { + return ""; + } + + return buildWavMetadataJson(decoder); +} + TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { m.impl("create_from_file", &create_from_file); m.impl("create_from_tensor", &create_from_tensor); @@ -1061,6 +1250,11 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { m.impl("encode_video_to_file", &encode_video_to_file); m.impl("encode_video_to_tensor", &encode_video_to_tensor); m.impl("_encode_video_to_file_like", &_encode_video_to_file_like); + m.impl("decode_wav_from_file", &decode_wav_from_file); + m.impl("decode_wav_from_tensor", &decode_wav_from_tensor); + m.impl("_decode_wav_from_file_like", &_decode_wav_from_file_like); + m.impl("get_wav_metadata_from_file", &get_wav_metadata_from_file); 
+ m.impl("get_wav_metadata_from_tensor", &get_wav_metadata_from_tensor); } TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 3188dfc7b..733ac02ac 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -168,6 +168,21 @@ def expose_ffmpeg_dlls(): # noqa: F811 torch.ops.torchcodec_ns._get_json_ffmpeg_library_versions.default ) _get_backend_details = torch.ops.torchcodec_ns._get_backend_details.default +decode_wav_from_file = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns.decode_wav_from_file.default +) +decode_wav_from_tensor = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns.decode_wav_from_tensor.default +) +_decode_wav_from_file_like = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns._decode_wav_from_file_like.default +) +get_wav_metadata_from_file = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns.get_wav_metadata_from_file.default +) +get_wav_metadata_from_tensor = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns.get_wav_metadata_from_tensor.default +) # ============================= @@ -604,3 +619,53 @@ def get_ffmpeg_library_versions(): @register_fake("torchcodec_ns::_get_backend_details") def _get_backend_details_abstract(decoder: torch.Tensor) -> str: return "" + + +@register_fake("torchcodec_ns::decode_wav_from_file") +def decode_wav_from_file_abstract( + filename: str, + start_seconds: float = 0.0, + stop_seconds: float | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + samples_size = [get_ctx().new_dynamic_size() for _ in range(2)] + return (torch.empty(samples_size), torch.empty([], dtype=torch.float64)) + + +@register_fake("torchcodec_ns::decode_wav_from_tensor") +def decode_wav_from_tensor_abstract( + data: torch.Tensor, + start_seconds: float = 0.0, + stop_seconds: float | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + samples_size = [get_ctx().new_dynamic_size() for _ in range(2)] + return 
(torch.empty(samples_size), torch.empty([], dtype=torch.float64)) + + +@register_fake("torchcodec_ns::_decode_wav_from_file_like") +def _decode_wav_from_file_like_abstract( + ctx: int, + start_seconds: float = 0.0, + stop_seconds: float | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + samples_size = [get_ctx().new_dynamic_size() for _ in range(2)] + return (torch.empty(samples_size), torch.empty([], dtype=torch.float64)) + + +@register_fake("torchcodec_ns::get_wav_metadata_from_file") +def get_wav_metadata_from_file_abstract( + filename: str, + stream_index: int | None = None, + sample_rate: int | None = None, + num_channels: int | None = None, +) -> str: + return "" + + +@register_fake("torchcodec_ns::get_wav_metadata_from_tensor") +def get_wav_metadata_from_tensor_abstract( + data: torch.Tensor, + stream_index: int | None = None, + sample_rate: int | None = None, + num_channels: int | None = None, +) -> str: + return "" diff --git a/src/torchcodec/decoders/_audio_decoder.py b/src/torchcodec/decoders/_audio_decoder.py index e1d0e0461..7a77b2359 100644 --- a/src/torchcodec/decoders/_audio_decoder.py +++ b/src/torchcodec/decoders/_audio_decoder.py @@ -13,7 +13,9 @@ from torchcodec import _core as core, AudioSamples from torchcodec.decoders._decoder_utils import ( + _is_uncompressed_wav, create_decoder, + decode_wav, ERROR_REPORTING_INSTRUCTIONS, ) @@ -59,8 +61,30 @@ def __init__( stream_index: int | None = None, sample_rate: int | None = None, num_channels: int | None = None, + use_wav_decoder: ( + bool | None + ) = None, # optionally disable wav decoder for testing ): torch._C._log_api_usage_once("torchcodec.decoders.AudioDecoder") + + # Check if this is an uncompressed WAV file that we can decode directly with WavDecoder. 
+ self._use_wav_decoder = False + self._wav_source = None + + if use_wav_decoder is not False and ( + wav_metadata := _is_uncompressed_wav( + source, stream_index, sample_rate, num_channels + ) + ): + self._use_wav_decoder = True + self._wav_source = source + self.stream_index = 0 + + # Create metadata from WAV JSON + self.metadata = core.create_audio_metadata_from_wav(wav_metadata) + return + + # Fall back to FFmpeg decoder self._decoder = create_decoder(source=source, seek_mode="approximate") container_metadata = core.get_container_metadata(self._decoder) @@ -134,6 +158,26 @@ def get_samples_played_in_range( raise ValueError( f"Invalid start seconds: {start_seconds}. It must be less than or equal to stop seconds ({stop_seconds})." ) + + # Use native WAV decoder if available + if self._use_wav_decoder and self._wav_source is not None: + frames, pts_tensor = decode_wav( + self._wav_source, + start_seconds=start_seconds, + stop_seconds=stop_seconds, + ) + first_pts = pts_tensor.item() + sample_rate = self.metadata.sample_rate + assert sample_rate is not None # mypy + + return AudioSamples( + data=frames, + pts_seconds=first_pts, + duration_seconds=frames.shape[1] / sample_rate, + sample_rate=sample_rate, + ) + + # FFmpeg path frames, first_pts = core.get_frames_by_pts_in_range_audio( self._decoder, start_seconds=start_seconds, diff --git a/src/torchcodec/decoders/_decoder_utils.py b/src/torchcodec/decoders/_decoder_utils.py index 6ca53a330..fc7097418 100644 --- a/src/torchcodec/decoders/_decoder_utils.py +++ b/src/torchcodec/decoders/_decoder_utils.py @@ -7,11 +7,13 @@ import contextvars import io +import json from collections.abc import Generator from contextlib import contextmanager from pathlib import Path +import torch from torch import Tensor from torchcodec import _core as core @@ -111,3 +113,103 @@ def set_cuda_backend(backend: str) -> Generator[None, None, None]: def _get_cuda_backend() -> str: return _CUDA_BACKEND.get() + + +def _is_uncompressed_wav( + 
source, + stream_index: int | None = None, + sample_rate: int | None = None, + num_channels: int | None = None, +) -> dict | None: + """Check if source is an uncompressed WAV file compatible with native decoder. + + Returns metadata dict if compatible, None otherwise (not WAV, unsupported format, + or requested parameters don't match the source). + """ + try: + if isinstance(source, str): + metadata_json = core.get_wav_metadata_from_file( + source, stream_index, sample_rate, num_channels + ) + elif isinstance(source, Path): + metadata_json = core.get_wav_metadata_from_file( + str(source), stream_index, sample_rate, num_channels + ) + elif isinstance(source, bytes): + buffer = torch.frombuffer(source, dtype=torch.uint8) + metadata_json = core.get_wav_metadata_from_tensor( + buffer, stream_index, sample_rate, num_channels + ) + elif isinstance(source, Tensor): + metadata_json = core.get_wav_metadata_from_tensor( + source, stream_index, sample_rate, num_channels + ) + else: + # File-like objects - read all data to get full metadata + current_pos = source.seek(0, io.SEEK_CUR) + source.seek(0) + data = source.read() + source.seek(current_pos) + if len(data) < 12: + return None + buffer = torch.frombuffer(data, dtype=torch.uint8) + metadata_json = core.get_wav_metadata_from_tensor( + buffer, stream_index, sample_rate, num_channels + ) + + if not metadata_json: + return None + + return json.loads(metadata_json) + except Exception: + # In the case of an error, fall back to FFmpeg decoder + return None + + +def decode_wav( + source: str | Path | bytes | Tensor | io.RawIOBase | io.BufferedReader, + start_seconds: float = 0.0, + stop_seconds: float | None = None, +) -> tuple[Tensor, Tensor]: + """Decode audio samples from a WAV file using the native decoder. + + Args: + source: The WAV audio source - can be a file path (str or Path), + raw bytes, a uint8 tensor containing WAV data, or a file-like object. + start_seconds: Start time in seconds for the audio range. 
+ stop_seconds: Stop time in seconds (exclusive). None means decode to end. + + Returns: + A tuple of (samples, pts_seconds) where: + - samples: Float32 tensor of shape (num_channels, num_samples) normalized to [-1, 1] + - pts_seconds: Float64 tensor containing the PTS of the first sample + + Raises: + RuntimeError: If the WAV format is not supported (compressed formats). + """ + import warnings + + if isinstance(source, str): + return core.decode_wav_from_file(source, start_seconds, stop_seconds) + elif isinstance(source, Path): + return core.decode_wav_from_file(str(source), start_seconds, stop_seconds) + elif isinstance(source, bytes): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + buffer = torch.frombuffer(source, dtype=torch.uint8) + return core.decode_wav_from_tensor(buffer, start_seconds, stop_seconds) + elif isinstance(source, Tensor): + return core.decode_wav_from_tensor(source, start_seconds, stop_seconds) + elif hasattr(source, "read") and hasattr(source, "seek"): + # File-like object - read all data and pass to tensor version + source.seek(0) + data = source.read() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + buffer = torch.frombuffer(data, dtype=torch.uint8) + return core.decode_wav_from_tensor(buffer, start_seconds, stop_seconds) + else: + raise TypeError( + f"Unsupported source type: {type(source)}. " + "Expected str, Path, bytes, Tensor, or file-like object." + ) diff --git a/test/test_decoders.py b/test/test_decoders.py index 9e901f826..755033600 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -2317,3 +2317,16 @@ def test_num_channels_errors(self, asset): # FFmpeg fails to find a default layout for certain channel counts, # which causes SwrContext to fail to initialize. 
@pytest.mark.parametrize(
    "asset", (SINE_MONO_S16, SINE_MONO_S32, SINE_16_CHANNEL_S16)
)
def test_native_matches_ffmpeg_full(self, asset):
    """Native WAV decoding must be bit-exact with the FFmpeg path."""
    native_decoder = AudioDecoder(asset.path, use_wav_decoder=True)
    # AudioDecoder silently falls back to FFmpeg when WAV detection fails.
    # Without this guard the test could pass by comparing FFmpeg to itself.
    assert native_decoder._use_wav_decoder
    native_samples = native_decoder.get_all_samples()

    ffmpeg_decoder = AudioDecoder(asset.path, use_wav_decoder=False)
    ffmpeg_samples = ffmpeg_decoder.get_all_samples()

    # atol=0 / rtol=0: require exact equality, not approximate closeness.
    torch.testing.assert_close(
        native_samples.data, ffmpeg_samples.data, atol=0, rtol=0
    )
    assert native_samples.sample_rate == ffmpeg_samples.sample_rate