From 512634d7bed1e3c68e0e5ebf01eb0adf5ef7e377 Mon Sep 17 00:00:00 2001
From: Molly Xu
Date: Sun, 1 Feb 2026 21:43:16 -0800
Subject: [PATCH 01/10] refactor

---
 src/torchcodec/_core/CMakeLists.txt         |   1 +
 src/torchcodec/_core/CpuDeviceInterface.cpp | 122 +------------------
 src/torchcodec/_core/CpuDeviceInterface.h   |  32 ++---
 src/torchcodec/_core/SwScale.cpp            | 125 ++++++++++++++++++++
 src/torchcodec/_core/SwScale.h              |  45 +++++++
 5 files changed, 182 insertions(+), 143 deletions(-)
 create mode 100644 src/torchcodec/_core/SwScale.cpp
 create mode 100644 src/torchcodec/_core/SwScale.h

diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt
index 2dc11d237..602eb4b5f 100644
--- a/src/torchcodec/_core/CMakeLists.txt
+++ b/src/torchcodec/_core/CMakeLists.txt
@@ -128,6 +128,7 @@ function(make_torchcodec_libraries
       ValidationUtils.cpp
       Transform.cpp
       Metadata.cpp
+      SwScale.cpp
   )
 
   if(ENABLE_CUDA)
diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp
index 404d0e8af..1f1c7d3d9 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.cpp
+++ b/src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -198,8 +198,10 @@ void CpuDeviceInterface::convertVideoAVFrameToFrameOutput(
     outputTensor = preAllocatedOutputTensor.value_or(
         allocateEmptyHWCTensor(outputDims, torch::kCPU));
 
-    int resultHeight =
-        convertAVFrameToTensorUsingSwScale(avFrame, outputTensor, outputDims);
+    if (!swScale_) {
+      swScale_ = std::make_unique<SwScale>(swsFlags_);
+    }
+    int resultHeight = swScale_->convert(avFrame, outputTensor, outputDims);
 
     // If this check failed, it would mean that the frame wasn't reshaped to
     // the expected height.
@@ -244,122 +246,6 @@ void CpuDeviceInterface::convertVideoAVFrameToFrameOutput(
   }
 }
 
-int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
-    const UniqueAVFrame& avFrame,
-    torch::Tensor& outputTensor,
-    const FrameDims& outputDims) {
-  enum AVPixelFormat frameFormat =
-      static_cast<AVPixelFormat>(avFrame->format);
-
-  bool needsResize =
-      (avFrame->height != outputDims.height ||
-       avFrame->width != outputDims.width);
-
-  // We need to compare the current frame context with our previous frame
-  // context. If they are different, then we need to re-create our colorspace
-  // conversion objects. We create our colorspace conversion objects late so
-  // that we don't have to depend on the unreliable metadata in the header.
-  // And we sometimes re-create them because it's possible for frame
-  // resolution to change mid-stream. Finally, we want to reuse the colorspace
-  // conversion objects as much as possible for performance reasons.
-  SwsFrameContext swsFrameContext(
-      avFrame->width,
-      avFrame->height,
-      frameFormat,
-      needsResize ? avFrame->width : outputDims.width,
-      needsResize ? avFrame->height : outputDims.height);
-
-  if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
-    swsContext_ = createSwsContext(
-        swsFrameContext,
-        avFrame->colorspace,
-
-        // See [Transform and Format Conversion Order] for more on the output
-        // pixel format.
-        /*outputFormat=*/AV_PIX_FMT_RGB24,
-
-        // No flags for color conversion. When resizing is needed, we use a
-        // separate swscale context with the appropriate resize flags.
-        /*swsFlags=*/0);
-    prevSwsFrameContext_ = swsFrameContext;
-  }
-
-  // When resizing is needed, we do sws_scale twice: first convert to RGB24 at
-  // original resolution, then resize in RGB24 space. This ensures transforms
-  // happen in the output color space (RGB24) rather than the input color space
-  // (YUV).
-  //
-  // When no resize is needed, we do color conversion directly into the output
-  // tensor.
-
-  torch::Tensor colorConvertedTensor = needsResize
-      ? allocateEmptyHWCTensor(
-            FrameDims(avFrame->height, avFrame->width), torch::kCPU)
-      : outputTensor;
-
-  uint8_t* colorConvertedPointers[4] = {
-      colorConvertedTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
-  int colorConvertedWidth = static_cast<int>(colorConvertedTensor.sizes()[1]);
-  int colorConvertedLinesizes[4] = {colorConvertedWidth * 3, 0, 0, 0};
-
-  int colorConvertedHeight = sws_scale(
-      swsContext_.get(),
-      avFrame->data,
-      avFrame->linesize,
-      0,
-      avFrame->height,
-      colorConvertedPointers,
-      colorConvertedLinesizes);
-
-  TORCH_CHECK(
-      colorConvertedHeight == avFrame->height,
-      "Color conversion swscale pass failed: colorConvertedHeight != avFrame->height: ",
-      colorConvertedHeight,
-      " != ",
-      avFrame->height);
-
-  if (needsResize) {
-    // Use cached swscale context for resizing, similar to the color conversion
-    // context caching above.
-    SwsFrameContext resizeSwsFrameContext(
-        avFrame->width,
-        avFrame->height,
-        AV_PIX_FMT_RGB24,
-        outputDims.width,
-        outputDims.height);
-
-    if (!resizeSwsContext_ ||
-        prevResizeSwsFrameContext_ != resizeSwsFrameContext) {
-      resizeSwsContext_ = createSwsContext(
-          resizeSwsFrameContext,
-          AVCOL_SPC_RGB,
-          /*outputFormat=*/AV_PIX_FMT_RGB24,
-          /*swsFlags=*/swsFlags_);
-      prevResizeSwsFrameContext_ = resizeSwsFrameContext;
-    }
-
-    uint8_t* srcPointers[4] = {
-        colorConvertedTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
-    int srcLinesizes[4] = {avFrame->width * 3, 0, 0, 0};
-
-    uint8_t* dstPointers[4] = {
-        outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
-    int expectedOutputWidth = static_cast<int>(outputTensor.sizes()[1]);
-    int dstLinesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
-
-    colorConvertedHeight = sws_scale(
-        resizeSwsContext_.get(),
-        srcPointers,
-        srcLinesizes,
-        0,
-        avFrame->height,
-        dstPointers,
-        dstLinesizes);
-  }
-
-  return colorConvertedHeight;
-}
-
 torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
     const UniqueAVFrame& avFrame,
     const FrameDims& outputDims) {
diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h
index ac853947a..04b33312c 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.h
+++ b/src/torchcodec/_core/CpuDeviceInterface.h
@@ -9,6 +9,7 @@
 #include "DeviceInterface.h"
 #include "FFMPEGCommon.h"
 #include "FilterGraph.h"
+#include "SwScale.h"
 
 namespace facebook::torchcodec {
 
@@ -56,11 +57,6 @@ class CpuDeviceInterface : public DeviceInterface {
       FrameOutput& frameOutput,
       std::optional<torch::Tensor> preAllocatedOutputTensor);
 
-  int convertAVFrameToTensorUsingSwScale(
-      const UniqueAVFrame& avFrame,
-      torch::Tensor& outputTensor,
-      const FrameDims& outputDims);
-
   torch::Tensor convertAVFrameToTensorUsingFilterGraph(
       const UniqueAVFrame& avFrame,
       const FrameDims& outputDims);
@@ -80,32 +76,18 @@ class CpuDeviceInterface : public DeviceInterface {
   // resolutions.
   std::optional<FrameDims> resizedOutputDims_;
 
-  // Color-conversion objects. Only one of filterGraph_ and swsContext_ should
+  // Color-conversion objects. Only one of filterGraph_ and swScale_ should
   // be non-null. Which one we use is determined dynamically in
   // getColorConversionLibrary() each time we decode a frame.
   //
-  // Creating both filterGraph_ and swsContext_ is relatively expensive, so we
-  // reuse them across frames. However, it is possbile that subsequent frames
+  // Creating both filterGraph_ and swScale_ is relatively expensive, so we
+  // reuse them across frames. However, it is possible that subsequent frames
   // are different enough (change in dimensions) that we can't reuse the color
-  // conversion object. We store the relevant frame context from the frame used
-  // to create the object last time. We always compare the current frame's info
-  // against the previous one to determine if we need to recreate the color
-  // conversion object.
-  //
-  // TODO: The names of these fields is confusing, as the actual color
-  // conversion object for Sws has "context" in the name, and we use
-  // "context" for the structs we store to know if we need to recreate a
-  // color conversion object. We should clean that up.
+  // conversion object. SwScale and FilterGraph handle context caching
+  // internally.
   std::unique_ptr<FilterGraph> filterGraph_;
   FiltersContext prevFiltersContext_;
-  UniqueSwsContext swsContext_;
-  SwsFrameContext prevSwsFrameContext_;
-
-  // Cached swscale context for resizing in RGB24 space (used in double swscale
-  // path). Like the color conversion context above, we cache this to avoid
-  // recreating it for every frame.
-  UniqueSwsContext resizeSwsContext_;
-  SwsFrameContext prevResizeSwsFrameContext_;
+  std::unique_ptr<SwScale> swScale_;
 
   // We pass these filters to FFmpeg's filtergraph API. It is a simple pipeline
   // of what FFmpeg calls "filters" to apply to decoded frames before returning
diff --git a/src/torchcodec/_core/SwScale.cpp b/src/torchcodec/_core/SwScale.cpp
new file mode 100644
index 000000000..779a2b30d
--- /dev/null
+++ b/src/torchcodec/_core/SwScale.cpp
@@ -0,0 +1,125 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include "SwScale.h"
+#include "Frame.h"
+
+namespace facebook::torchcodec {
+
+SwScale::SwScale(int swsFlags) : swsFlags_(swsFlags) {}
+
+int SwScale::convert(
+    const UniqueAVFrame& avFrame,
+    torch::Tensor& outputTensor,
+    const FrameDims& outputDims) {
+  enum AVPixelFormat frameFormat =
+      static_cast<AVPixelFormat>(avFrame->format);
+
+  bool needsResize =
+      (avFrame->height != outputDims.height ||
+       avFrame->width != outputDims.width);
+
+  // We need to compare the current frame context with our previous frame
+  // context. If they are different, then we need to re-create our colorspace
+  // conversion objects. We create our colorspace conversion objects late so
+  // that we don't have to depend on the unreliable metadata in the header.
+  // And we sometimes re-create them because it's possible for frame
+  // resolution to change mid-stream. Finally, we want to reuse the colorspace
+  // conversion objects as much as possible for performance reasons.
+  SwsFrameContext colorConversionFrameContext(
+      avFrame->width,
+      avFrame->height,
+      frameFormat,
+      needsResize ? avFrame->width : outputDims.width,
+      needsResize ? avFrame->height : outputDims.height);
+
+  if (!colorConversionSwsContext_ ||
+      prevColorConversionFrameContext_ != colorConversionFrameContext) {
+    colorConversionSwsContext_ = createSwsContext(
+        colorConversionFrameContext,
+        avFrame->colorspace,
+
+        // See [Transform and Format Conversion Order] for more on the output
+        // pixel format.
+        /*outputFormat=*/AV_PIX_FMT_RGB24,
+
+        // No flags for color conversion. When resizing is needed, we use a
+        // separate swscale context with the appropriate resize flags.
+        /*swsFlags=*/0);
+    prevColorConversionFrameContext_ = colorConversionFrameContext;
+  }
+
+  // When no resize is needed, we do color conversion directly into the output
+  // tensor.
+
+  torch::Tensor colorConvertedTensor = needsResize
+      ? allocateEmptyHWCTensor(
+            FrameDims(avFrame->height, avFrame->width), torch::kCPU)
+      : outputTensor;
+
+  uint8_t* colorConvertedPointers[4] = {
+      colorConvertedTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
+  int colorConvertedWidth = static_cast<int>(colorConvertedTensor.sizes()[1]);
+  int colorConvertedLinesizes[4] = {colorConvertedWidth * 3, 0, 0, 0};
+
+  int colorConvertedHeight = sws_scale(
+      colorConversionSwsContext_.get(),
+      avFrame->data,
+      avFrame->linesize,
+      0,
+      avFrame->height,
+      colorConvertedPointers,
+      colorConvertedLinesizes);
+
+  TORCH_CHECK(
+      colorConvertedHeight == avFrame->height,
+      "Color conversion swscale pass failed: colorConvertedHeight != avFrame->height: ",
+      colorConvertedHeight,
+      " != ",
+      avFrame->height);
+
+  if (needsResize) {
+    // Use cached swscale context for resizing, similar to the color conversion
+    // context caching above.
+    SwsFrameContext resizeFrameContext(
+        avFrame->width,
+        avFrame->height,
+        AV_PIX_FMT_RGB24,
+        outputDims.width,
+        outputDims.height);
+
+    if (!resizeSwsContext_ || prevResizeFrameContext_ != resizeFrameContext) {
+      resizeSwsContext_ = createSwsContext(
+          resizeFrameContext,
+          AVCOL_SPC_RGB,
+          /*outputFormat=*/AV_PIX_FMT_RGB24,
+          /*swsFlags=*/swsFlags_);
+      prevResizeFrameContext_ = resizeFrameContext;
+    }
+
+    uint8_t* srcPointers[4] = {
+        colorConvertedTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
+    int srcLinesizes[4] = {avFrame->width * 3, 0, 0, 0};
+
+    uint8_t* dstPointers[4] = {
+        outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
+    int expectedOutputWidth = static_cast<int>(outputTensor.sizes()[1]);
+    int dstLinesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
+
+    colorConvertedHeight = sws_scale(
+        resizeSwsContext_.get(),
+        srcPointers,
+        srcLinesizes,
+        0,
+        avFrame->height,
+        dstPointers,
+        dstLinesizes);
+  }
+
+  return colorConvertedHeight;
+}
+
+} // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/SwScale.h b/src/torchcodec/_core/SwScale.h
new file mode 100644
index 000000000..813c6be7a
--- /dev/null
+++ b/src/torchcodec/_core/SwScale.h
@@ -0,0 +1,45 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <torch/types.h>
+#include "FFMPEGCommon.h"
+
+namespace facebook::torchcodec {
+
+struct FrameDims;
+
+// SwScale uses a double swscale path:
+// 1. Color conversion (e.g., YUV -> RGB24) at the original frame resolution
+// 2. Resize in RGB24 space (if resizing is needed)
+//
+// This approach ensures that transforms happen in the output color space
+// (RGB24) rather than the input color space (YUV).
+class SwScale {
+ public:
+  explicit SwScale(int swsFlags = SWS_BILINEAR);
+
+  int convert(
+      const UniqueAVFrame& avFrame,
+      torch::Tensor& outputTensor,
+      const FrameDims& outputDims);
+
+ private:
+  int swsFlags_;
+
+  // Color conversion context (YUV -> RGB24). We cache this to avoid
+  // recreating it for every frame.
+  UniqueSwsContext colorConversionSwsContext_;
+  SwsFrameContext prevColorConversionFrameContext_;
+
+  // Resize context (RGB24 -> RGB24 at different resolution). We cache this
+  // to avoid recreating it for every frame.
+ UniqueSwsContext resizeSwsContext_; + SwsFrameContext prevResizeFrameContext_; +}; + +} // namespace facebook::torchcodec From c5d680722ba9849f6bcdb2b6df21cc8f69cd0ecf Mon Sep 17 00:00:00 2001 From: Molly Xu Date: Mon, 2 Feb 2026 08:22:32 -0800 Subject: [PATCH 02/10] fix format --- src/torchcodec/_core/CpuDeviceInterface.cpp | 19 ++- src/torchcodec/_core/CpuDeviceInterface.h | 9 +- src/torchcodec/_core/SwScale.cpp | 133 +++++++++++--------- src/torchcodec/_core/SwScale.h | 41 ++++-- 4 files changed, 127 insertions(+), 75 deletions(-) diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index 1f1c7d3d9..fe5fbdec5 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -198,10 +198,23 @@ void CpuDeviceInterface::convertVideoAVFrameToFrameOutput( outputTensor = preAllocatedOutputTensor.value_or( allocateEmptyHWCTensor(outputDims, torch::kCPU)); - if (!swScale_) { - swScale_ = std::make_unique(swsFlags_); + enum AVPixelFormat avFrameFormat = + static_cast(avFrame->format); + + SwScaleContext swScaleContext( + avFrame->width, + avFrame->height, + avFrameFormat, + avFrame->colorspace, + outputDims.width, + outputDims.height); + + if (!swScale_ || prevSwScaleContext_ != swScaleContext) { + swScale_ = std::make_unique(swScaleContext, swsFlags_); + prevSwScaleContext_ = swScaleContext; } - int resultHeight = swScale_->convert(avFrame, outputTensor, outputDims); + + int resultHeight = swScale_->convert(avFrame, outputTensor); // If this check failed, it would mean that the frame wasn't reshaped to // the expected height. diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index 04b33312c..fa6c4376a 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -83,11 +83,16 @@ class CpuDeviceInterface : public DeviceInterface { // Creating both filterGraph_ and swScale_ is relatively expensive, so we // reuse them across frames. However, it is possible that subsequent frames // are different enough (change in dimensions) that we can't reuse the color - // conversion object. SwScale and FilterGraph handle context caching - // internally. + // conversion object. + // + // TODO: The names of these fields is confusing, as the actual color + // conversion object for Sws has "context" in the name, and we use + // "context" for the structs we store to know if we need to recreate a + // color conversion object. We should clean that up. std::unique_ptr filterGraph_; FiltersContext prevFiltersContext_; std::unique_ptr swScale_; + SwScaleContext prevSwScaleContext_; // We pass these filters to FFmpeg's filtergraph API. 
It is a simple pipeline // of what FFmpeg calls "filters" to apply to decoded frames before returning diff --git a/src/torchcodec/_core/SwScale.cpp b/src/torchcodec/_core/SwScale.cpp index 779a2b30d..f9bbbc437 100644 --- a/src/torchcodec/_core/SwScale.cpp +++ b/src/torchcodec/_core/SwScale.cpp @@ -9,55 +9,88 @@ namespace facebook::torchcodec { -SwScale::SwScale(int swsFlags) : swsFlags_(swsFlags) {} +SwScaleContext::SwScaleContext( + int inputWidth, + int inputHeight, + AVPixelFormat inputFormat, + AVColorSpace inputColorspace, + int outputWidth, + int outputHeight) + : inputWidth(inputWidth), + inputHeight(inputHeight), + inputFormat(inputFormat), + inputColorspace(inputColorspace), + outputWidth(outputWidth), + outputHeight(outputHeight) {} + +bool SwScaleContext::operator==(const SwScaleContext& other) const { + return inputWidth == other.inputWidth && inputHeight == other.inputHeight && + inputFormat == other.inputFormat && + inputColorspace == other.inputColorspace && + outputWidth == other.outputWidth && outputHeight == other.outputHeight; +} -int SwScale::convert( - const UniqueAVFrame& avFrame, - torch::Tensor& outputTensor, - const FrameDims& outputDims) { - enum AVPixelFormat frameFormat = - static_cast(avFrame->format); +bool SwScaleContext::operator!=(const SwScaleContext& other) const { + return !(*this == other); +} +SwScale::SwScale(const SwScaleContext& context, int swsFlags) + : context_(context), swsFlags_(swsFlags) { bool needsResize = - (avFrame->height != outputDims.height || - avFrame->width != outputDims.width); - - // We need to compare the current frame context with our previous frame - // context. If they are different, then we need to re-create our colorspace - // conversion objects. We create our colorspace conversion objects late so - // that we don't have to depend on the unreliable metadata in the header. - // And we sometimes re-create them because it's possible for frame - // resolution to change mid-stream. Finally, we want to reuse the colorspace - // conversion objects as much as possible for performance reasons. + (context_.inputHeight != context_.outputHeight || + context_.inputWidth != context_.outputWidth); + + // Create color conversion context (input format -> RGB24). + // When resizing is needed, color conversion outputs at input resolution. + // When no resize is needed, color conversion outputs at output resolution. SwsFrameContext colorConversionFrameContext( - avFrame->width, - avFrame->height, - frameFormat, - needsResize ? avFrame->width : outputDims.width, - needsResize ? avFrame->height : outputDims.height); - - if (!colorConversionSwsContext_ || - prevColorConversionFrameContext_ != colorConversionFrameContext) { - colorConversionSwsContext_ = createSwsContext( - colorConversionFrameContext, - avFrame->colorspace, - - // See [Transform and Format Conversion Order] for more on the output - // pixel format. - /*outputFormat=*/AV_PIX_FMT_RGB24, + context_.inputWidth, + context_.inputHeight, + context_.inputFormat, + needsResize ? context_.inputWidth : context_.outputWidth, + needsResize ? context_.inputHeight : context_.outputHeight); + + colorConversionSwsContext_ = createSwsContext( + colorConversionFrameContext, + context_.inputColorspace, + // See [Transform and Format Conversion Order] for more on the output + // pixel format. + /*outputFormat=*/AV_PIX_FMT_RGB24, + // No flags for color conversion. When resizing is needed, we use a + // separate swscale context with the appropriate resize flags. 
+ /*swsFlags=*/0); + + // Create resize context if needed (RGB24 at input resolution -> RGB24 at + // output resolution). + if (needsResize) { + SwsFrameContext resizeFrameContext( + context_.inputWidth, + context_.inputHeight, + AV_PIX_FMT_RGB24, + context_.outputWidth, + context_.outputHeight); - // No flags for color conversion. When resizing is needed, we use a - // separate swscale context with the appropriate resize flags. - /*swsFlags=*/0); - prevColorConversionFrameContext_ = colorConversionFrameContext; + resizeSwsContext_ = createSwsContext( + resizeFrameContext, + AVCOL_SPC_RGB, + /*outputFormat=*/AV_PIX_FMT_RGB24, + /*swsFlags=*/swsFlags_); } +} - // When no resize is needed, we do color conversion directly into the output - // tensor. +int SwScale::convert( + const UniqueAVFrame& avFrame, + torch::Tensor& outputTensor) { + bool needsResize = + (context_.inputHeight != context_.outputHeight || + context_.inputWidth != context_.outputWidth); + // When no resize is needed, we do color conversion directly into the output + // tensor. When resize is needed, we first convert to an intermediate tensor + // at the input resolution, then resize into the output tensor. torch::Tensor colorConvertedTensor = needsResize ? allocateEmptyHWCTensor( - FrameDims(avFrame->height, avFrame->width), torch::kCPU) + FrameDims(context_.inputHeight, context_.inputWidth), torch::kCPU) : outputTensor; uint8_t* colorConvertedPointers[4] = { @@ -82,27 +115,9 @@ int SwScale::convert( avFrame->height); if (needsResize) { - // Use cached swscale context for resizing, similar to the color conversion - // context caching above. - SwsFrameContext resizeFrameContext( - avFrame->width, - avFrame->height, - AV_PIX_FMT_RGB24, - outputDims.width, - outputDims.height); - - if (!resizeSwsContext_ || prevResizeFrameContext_ != resizeFrameContext) { - resizeSwsContext_ = createSwsContext( - resizeFrameContext, - AVCOL_SPC_RGB, - /*outputFormat=*/AV_PIX_FMT_RGB24, - /*swsFlags=*/swsFlags_); - prevResizeFrameContext_ = resizeFrameContext; - } - uint8_t* srcPointers[4] = { colorConvertedTensor.data_ptr(), nullptr, nullptr, nullptr}; - int srcLinesizes[4] = {avFrame->width * 3, 0, 0, 0}; + int srcLinesizes[4] = {context_.inputWidth * 3, 0, 0, 0}; uint8_t* dstPointers[4] = { outputTensor.data_ptr(), nullptr, nullptr, nullptr}; @@ -114,7 +129,7 @@ int SwScale::convert( srcPointers, srcLinesizes, 0, - avFrame->height, + context_.inputHeight, dstPointers, dstLinesizes); } diff --git a/src/torchcodec/_core/SwScale.h b/src/torchcodec/_core/SwScale.h index 813c6be7a..ffefbf465 100644 --- a/src/torchcodec/_core/SwScale.h +++ b/src/torchcodec/_core/SwScale.h @@ -13,33 +13,52 @@ namespace facebook::torchcodec { struct FrameDims; +struct SwScaleContext { + int inputWidth = 0; + int inputHeight = 0; + AVPixelFormat inputFormat = AV_PIX_FMT_NONE; + AVColorSpace inputColorspace = AVCOL_SPC_UNSPECIFIED; + int outputWidth = 0; + int outputHeight = 0; + + SwScaleContext() = default; + SwScaleContext( + int inputWidth, + int inputHeight, + AVPixelFormat inputFormat, + AVColorSpace inputColorspace, + int outputWidth, + int outputHeight); + + bool operator==(const SwScaleContext&) const; + bool operator!=(const SwScaleContext&) const; +}; + // SwScale uses a double swscale path: // 1. Color conversion (e.g., YUV -> RGB24) at the original frame resolution // 2. Resize in RGB24 space (if resizing is needed) // // This approach ensures that transforms happen in the output color space // (RGB24) rather than the input color space (YUV). 
+// +// The caller is responsible for caching SwScale instances and recreating them +// when the context changes, similar to how FilterGraph is managed. class SwScale { public: - explicit SwScale(int swsFlags = SWS_BILINEAR); + SwScale(const SwScaleContext& context, int swsFlags = SWS_BILINEAR); - int convert( - const UniqueAVFrame& avFrame, - torch::Tensor& outputTensor, - const FrameDims& outputDims); + int convert(const UniqueAVFrame& avFrame, torch::Tensor& outputTensor); private: + SwScaleContext context_; int swsFlags_; - // Color conversion context (YUV -> RGB24). We cache this to avoid - // recreating it for every frame. + // Color conversion context (input format -> RGB24 at original resolution). UniqueSwsContext colorConversionSwsContext_; - SwsFrameContext prevColorConversionFrameContext_; - // Resize context (RGB24 -> RGB24 at different resolution). We cache this - // to avoid recreating it for every frame. + // Resize context (RGB24 -> RGB24 at output resolution). + // May be null if no resize is needed. UniqueSwsContext resizeSwsContext_; - SwsFrameContext prevResizeFrameContext_; }; } // namespace facebook::torchcodec From b5dad8e38842018fb32c5c4ba0182d23749d4fbc Mon Sep 17 00:00:00 2001 From: Molly Xu Date: Mon, 2 Feb 2026 08:34:36 -0800 Subject: [PATCH 03/10] fix format --- src/torchcodec/_core/CpuDeviceInterface.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index fa6c4376a..3113ee450 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -83,6 +83,9 @@ class CpuDeviceInterface : public DeviceInterface { // Creating both filterGraph_ and swScale_ is relatively expensive, so we // reuse them across frames. However, it is possible that subsequent frames // are different enough (change in dimensions) that we can't reuse the color + // conversion object. We store the relevant frame context from the frame used + // to create the object last time. We always compare the current frame's info + // against the previous one to determine if we need to recreate the color // conversion object. // // TODO: The names of these fields is confusing, as the actual color From eb436d85d98be2c596115e31080fcb8c643846fd Mon Sep 17 00:00:00 2001 From: Molly Xu Date: Mon, 2 Feb 2026 08:42:57 -0800 Subject: [PATCH 04/10] move resize --- src/torchcodec/_core/SwScale.cpp | 16 ++++++---------- src/torchcodec/_core/SwScale.h | 1 + 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/torchcodec/_core/SwScale.cpp b/src/torchcodec/_core/SwScale.cpp index f9bbbc437..28e882071 100644 --- a/src/torchcodec/_core/SwScale.cpp +++ b/src/torchcodec/_core/SwScale.cpp @@ -36,7 +36,7 @@ bool SwScaleContext::operator!=(const SwScaleContext& other) const { SwScale::SwScale(const SwScaleContext& context, int swsFlags) : context_(context), swsFlags_(swsFlags) { - bool needsResize = + needsResize_ = (context_.inputHeight != context_.outputHeight || context_.inputWidth != context_.outputWidth); @@ -47,8 +47,8 @@ SwScale::SwScale(const SwScaleContext& context, int swsFlags) context_.inputWidth, context_.inputHeight, context_.inputFormat, - needsResize ? context_.inputWidth : context_.outputWidth, - needsResize ? context_.inputHeight : context_.outputHeight); + needsResize_ ? context_.inputWidth : context_.outputWidth, + needsResize_ ? 
context_.inputHeight : context_.outputHeight); colorConversionSwsContext_ = createSwsContext( colorConversionFrameContext, @@ -62,7 +62,7 @@ SwScale::SwScale(const SwScaleContext& context, int swsFlags) // Create resize context if needed (RGB24 at input resolution -> RGB24 at // output resolution). - if (needsResize) { + if (needsResize_) { SwsFrameContext resizeFrameContext( context_.inputWidth, context_.inputHeight, @@ -81,14 +81,10 @@ SwScale::SwScale(const SwScaleContext& context, int swsFlags) int SwScale::convert( const UniqueAVFrame& avFrame, torch::Tensor& outputTensor) { - bool needsResize = - (context_.inputHeight != context_.outputHeight || - context_.inputWidth != context_.outputWidth); - // When no resize is needed, we do color conversion directly into the output // tensor. When resize is needed, we first convert to an intermediate tensor // at the input resolution, then resize into the output tensor. - torch::Tensor colorConvertedTensor = needsResize + torch::Tensor colorConvertedTensor = needsResize_ ? allocateEmptyHWCTensor( FrameDims(context_.inputHeight, context_.inputWidth), torch::kCPU) : outputTensor; @@ -114,7 +110,7 @@ int SwScale::convert( " != ", avFrame->height); - if (needsResize) { + if (needsResize_) { uint8_t* srcPointers[4] = { colorConvertedTensor.data_ptr(), nullptr, nullptr, nullptr}; int srcLinesizes[4] = {context_.inputWidth * 3, 0, 0, 0}; diff --git a/src/torchcodec/_core/SwScale.h b/src/torchcodec/_core/SwScale.h index ffefbf465..6c929309b 100644 --- a/src/torchcodec/_core/SwScale.h +++ b/src/torchcodec/_core/SwScale.h @@ -52,6 +52,7 @@ class SwScale { private: SwScaleContext context_; int swsFlags_; + bool needsResize_; // Color conversion context (input format -> RGB24 at original resolution). UniqueSwsContext colorConversionSwsContext_; From a18366772cfbcb746110c1b26376867fc2c80c6e Mon Sep 17 00:00:00 2001 From: Molly Xu Date: Wed, 11 Feb 2026 13:28:31 -0800 Subject: [PATCH 05/10] address feedback --- .../_core/BetaCudaDeviceInterface.cpp | 11 +-- .../_core/BetaCudaDeviceInterface.h | 2 +- src/torchcodec/_core/CpuDeviceInterface.cpp | 15 ++-- src/torchcodec/_core/CpuDeviceInterface.h | 10 +-- src/torchcodec/_core/CudaDeviceInterface.cpp | 16 ++-- src/torchcodec/_core/CudaDeviceInterface.h | 2 +- src/torchcodec/_core/FFMPEGCommon.cpp | 29 ++++--- src/torchcodec/_core/FFMPEGCommon.h | 15 ++-- src/torchcodec/_core/FilterGraph.cpp | 30 +++---- src/torchcodec/_core/FilterGraph.h | 16 ++-- src/torchcodec/_core/SwScale.cpp | 81 +++++++------------ src/torchcodec/_core/SwScale.h | 29 ++----- 12 files changed, 108 insertions(+), 148 deletions(-) diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp index 45f6ba1a5..3f05b6656 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp @@ -704,17 +704,18 @@ UniqueAVFrame BetaCudaDeviceInterface::transferCpuFrameToGpuNV12( "Failed to allocate NV12 CPU frame buffer: ", getFFMPEGErrorStringFromErrorCode(ret)); - SwsFrameContext swsFrameContext( + SwsFrameConfig swsFrameConfig( width, height, static_cast(cpuFrame->format), + cpuFrame->colorspace, width, height); - if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) { - swsContext_ = createSwsContext( - swsFrameContext, cpuFrame->colorspace, AV_PIX_FMT_NV12, SWS_BILINEAR); - prevSwsFrameContext_ = swsFrameContext; + if (!swsContext_ || prevSwsFrameConfig_ != swsFrameConfig) { + swsContext_ = + createSwsContext(swsFrameConfig, AV_PIX_FMT_NV12, 
SWS_BILINEAR); + prevSwsFrameConfig_ = swsFrameConfig; } int convertedHeight = sws_scale( diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.h b/src/torchcodec/_core/BetaCudaDeviceInterface.h index cefb1a983..12b8bbd2c 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.h +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.h @@ -101,7 +101,7 @@ class BetaCudaDeviceInterface : public DeviceInterface { std::unique_ptr cpuFallback_; bool nvcuvidAvailable_ = false; UniqueSwsContext swsContext_; - SwsFrameContext prevSwsFrameContext_; + SwsFrameConfig prevSwsFrameConfig_; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index fe5fbdec5..19bbdb7da 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -201,7 +201,7 @@ void CpuDeviceInterface::convertVideoAVFrameToFrameOutput( enum AVPixelFormat avFrameFormat = static_cast(avFrame->format); - SwScaleContext swScaleContext( + SwsFrameConfig swsFrameConfig( avFrame->width, avFrame->height, avFrameFormat, @@ -209,9 +209,8 @@ void CpuDeviceInterface::convertVideoAVFrameToFrameOutput( outputDims.width, outputDims.height); - if (!swScale_ || prevSwScaleContext_ != swScaleContext) { - swScale_ = std::make_unique(swScaleContext, swsFlags_); - prevSwScaleContext_ = swScaleContext; + if (!swScale_ || swScale_->getConfig() != swsFrameConfig) { + swScale_ = std::make_unique(swsFrameConfig, swsFlags_); } int resultHeight = swScale_->convert(avFrame, outputTensor); @@ -265,7 +264,7 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph( enum AVPixelFormat avFrameFormat = static_cast(avFrame->format); - FiltersContext filtersContext( + FiltersConfig filtersConfig( avFrame->width, avFrame->height, avFrameFormat, @@ -276,10 +275,10 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph( filters_, timeBase_); - if (!filterGraph_ || prevFiltersContext_ != filtersContext) { + if (!filterGraph_ || prevFiltersConfig_ != filtersConfig) { filterGraph_ = - std::make_unique(filtersContext, videoStreamOptions_); - prevFiltersContext_ = std::move(filtersContext); + std::make_unique(filtersConfig, videoStreamOptions_); + prevFiltersConfig_ = std::move(filtersConfig); } return rgbAVFrameToTensor(filterGraph_->convert(avFrame)); } diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index 3113ee450..803c06686 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -83,19 +83,13 @@ class CpuDeviceInterface : public DeviceInterface { // Creating both filterGraph_ and swScale_ is relatively expensive, so we // reuse them across frames. However, it is possible that subsequent frames // are different enough (change in dimensions) that we can't reuse the color - // conversion object. We store the relevant frame context from the frame used + // conversion object. We store the relevant frame config from the frame used // to create the object last time. We always compare the current frame's info // against the previous one to determine if we need to recreate the color // conversion object. - // - // TODO: The names of these fields is confusing, as the actual color - // conversion object for Sws has "context" in the name, and we use - // "context" for the structs we store to know if we need to recreate a - // color conversion object. We should clean that up. 
   std::unique_ptr<FilterGraph> filterGraph_;
-  FiltersContext prevFiltersContext_;
+  FiltersConfig prevFiltersConfig_;
   std::unique_ptr<SwScale> swScale_;
 
   // We pass these filters to FFmpeg's filtergraph API. It is a simple pipeline
   // of what FFmpeg calls "filters" to apply to decoded frames before returning
diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp
index 11284fbc9..fc7d901c7 100644
--- a/src/torchcodec/_core/CudaDeviceInterface.cpp
+++ b/src/torchcodec/_core/CudaDeviceInterface.cpp
@@ -200,7 +200,7 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24(
   enum AVPixelFormat frameFormat =
       static_cast<AVPixelFormat>(avFrame->format);
 
-  auto newContext = std::make_unique<FiltersContext>(
+  auto newConfig = std::make_unique<FiltersConfig>(
       avFrame->width,
       avFrame->height,
       frameFormat,
@@ -212,22 +212,22 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24(
       timeBase_,
       av_buffer_ref(avFrame->hw_frames_ctx));
 
-  if (!nv12Conversion_ || *nv12ConversionContext_ != *newContext) {
+  if (!nv12Conversion_ || *nv12ConversionConfig_ != *newConfig) {
     nv12Conversion_ =
-        std::make_unique<FilterGraph>(*newContext, videoStreamOptions_);
-    nv12ConversionContext_ = std::move(newContext);
+        std::make_unique<FilterGraph>(*newConfig, videoStreamOptions_);
+    nv12ConversionConfig_ = std::move(newConfig);
   }
   auto filteredAVFrame = nv12Conversion_->convert(avFrame);
 
   // If this check fails it means the frame wasn't
   // reshaped to its expected dimensions by filtergraph.
   TORCH_CHECK(
-      (filteredAVFrame->width == nv12ConversionContext_->outputWidth) &&
-      (filteredAVFrame->height == nv12ConversionContext_->outputHeight),
+      (filteredAVFrame->width == nv12ConversionConfig_->outputWidth) &&
+      (filteredAVFrame->height == nv12ConversionConfig_->outputHeight),
       "Expected frame from filter graph of ",
-      nv12ConversionContext_->outputWidth,
+      nv12ConversionConfig_->outputWidth,
       "x",
-      nv12ConversionContext_->outputHeight,
+      nv12ConversionConfig_->outputHeight,
       ", got ",
       filteredAVFrame->width,
       "x",
diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h
index 267127c68..92378bfb3 100644
--- a/src/torchcodec/_core/CudaDeviceInterface.h
+++ b/src/torchcodec/_core/CudaDeviceInterface.h
@@ -69,7 +69,7 @@ class CudaDeviceInterface : public DeviceInterface {
 
   // This filtergraph instance is only used for NV12 format conversion in
   // maybeConvertAVFrameToNV12().
- std::unique_ptr nv12ConversionContext_; + std::unique_ptr nv12ConversionConfig_; std::unique_ptr nv12Conversion_; bool usingCPUFallback_ = false; diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index d2caf3460..b6d87e7d2 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -616,39 +616,41 @@ int64_t computeSafeDuration( } } -SwsFrameContext::SwsFrameContext( +SwsFrameConfig::SwsFrameConfig( int inputWidth, int inputHeight, AVPixelFormat inputFormat, + AVColorSpace inputColorspace, int outputWidth, int outputHeight) : inputWidth(inputWidth), inputHeight(inputHeight), inputFormat(inputFormat), + inputColorspace(inputColorspace), outputWidth(outputWidth), outputHeight(outputHeight) {} -bool SwsFrameContext::operator==(const SwsFrameContext& other) const { +bool SwsFrameConfig::operator==(const SwsFrameConfig& other) const { return inputWidth == other.inputWidth && inputHeight == other.inputHeight && - inputFormat == other.inputFormat && outputWidth == other.outputWidth && - outputHeight == other.outputHeight; + inputFormat == other.inputFormat && + inputColorspace == other.inputColorspace && + outputWidth == other.outputWidth && outputHeight == other.outputHeight; } -bool SwsFrameContext::operator!=(const SwsFrameContext& other) const { +bool SwsFrameConfig::operator!=(const SwsFrameConfig& other) const { return !(*this == other); } UniqueSwsContext createSwsContext( - const SwsFrameContext& swsFrameContext, - AVColorSpace colorspace, + const SwsFrameConfig& swsFrameConfig, AVPixelFormat outputFormat, int swsFlags) { SwsContext* swsContext = sws_getContext( - swsFrameContext.inputWidth, - swsFrameContext.inputHeight, - swsFrameContext.inputFormat, - swsFrameContext.outputWidth, - swsFrameContext.outputHeight, + swsFrameConfig.inputWidth, + swsFrameConfig.inputHeight, + swsFrameConfig.inputFormat, + swsFrameConfig.outputWidth, + swsFrameConfig.outputHeight, outputFormat, swsFlags, nullptr, @@ -670,7 +672,8 @@ UniqueSwsContext createSwsContext( &saturation); TORCH_CHECK(ret != -1, "sws_getColorspaceDetails returned -1"); - const int* colorspaceTable = sws_getCoefficients(colorspace); + const int* colorspaceTable = + sws_getCoefficients(swsFrameConfig.inputColorspace); ret = sws_setColorspaceDetails( swsContext, colorspaceTable, diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h index 8082b9659..d860fb890 100644 --- a/src/torchcodec/_core/FFMPEGCommon.h +++ b/src/torchcodec/_core/FFMPEGCommon.h @@ -285,29 +285,30 @@ AVFilterContext* createAVFilterContextWithOptions( const AVFilter* buffer, const enum AVPixelFormat outputFormat); -struct SwsFrameContext { +struct SwsFrameConfig { int inputWidth = 0; int inputHeight = 0; AVPixelFormat inputFormat = AV_PIX_FMT_NONE; + AVColorSpace inputColorspace = AVCOL_SPC_UNSPECIFIED; int outputWidth = 0; int outputHeight = 0; - SwsFrameContext() = default; - SwsFrameContext( + SwsFrameConfig() = default; + SwsFrameConfig( int inputWidth, int inputHeight, AVPixelFormat inputFormat, + AVColorSpace inputColorspace, int outputWidth, int outputHeight); - bool operator==(const SwsFrameContext& other) const; - bool operator!=(const SwsFrameContext& other) const; + bool operator==(const SwsFrameConfig& other) const; + bool operator!=(const SwsFrameConfig& other) const; }; // Utility functions for swscale context management UniqueSwsContext createSwsContext( - const SwsFrameContext& swsFrameContext, - AVColorSpace colorspace, + const 
SwsFrameConfig& swsFrameConfig, AVPixelFormat outputFormat, int swsFlags); diff --git a/src/torchcodec/_core/FilterGraph.cpp b/src/torchcodec/_core/FilterGraph.cpp index d0d03b571..9a3b24ffc 100644 --- a/src/torchcodec/_core/FilterGraph.cpp +++ b/src/torchcodec/_core/FilterGraph.cpp @@ -14,7 +14,7 @@ extern "C" { namespace facebook::torchcodec { -FiltersContext::FiltersContext( +FiltersConfig::FiltersConfig( int inputWidth, int inputHeight, AVPixelFormat inputFormat, @@ -40,7 +40,7 @@ bool operator==(const AVRational& lhs, const AVRational& rhs) { return lhs.num == rhs.num && lhs.den == rhs.den; } -bool FiltersContext::operator==(const FiltersContext& other) const { +bool FiltersConfig::operator==(const FiltersConfig& other) const { return inputWidth == other.inputWidth && inputHeight == other.inputHeight && inputFormat == other.inputFormat && outputWidth == other.outputWidth && outputHeight == other.outputHeight && @@ -49,12 +49,12 @@ bool FiltersContext::operator==(const FiltersContext& other) const { hwFramesCtx.get() == other.hwFramesCtx.get(); } -bool FiltersContext::operator!=(const FiltersContext& other) const { +bool FiltersConfig::operator!=(const FiltersConfig& other) const { return !(*this == other); } FilterGraph::FilterGraph( - const FiltersContext& filtersContext, + const FiltersConfig& filtersConfig, const VideoStreamOptions& videoStreamOptions) { filterGraph_.reset(avfilter_graph_alloc()); TORCH_CHECK(filterGraph_.get() != nullptr); @@ -68,13 +68,13 @@ FilterGraph::FilterGraph( UniqueAVBufferSrcParameters srcParams(av_buffersrc_parameters_alloc()); TORCH_CHECK(srcParams, "Failed to allocate buffersrc params"); - srcParams->format = filtersContext.inputFormat; - srcParams->width = filtersContext.inputWidth; - srcParams->height = filtersContext.inputHeight; - srcParams->sample_aspect_ratio = filtersContext.inputAspectRatio; - srcParams->time_base = filtersContext.timeBase; - if (filtersContext.hwFramesCtx) { - srcParams->hw_frames_ctx = av_buffer_ref(filtersContext.hwFramesCtx.get()); + srcParams->format = filtersConfig.inputFormat; + srcParams->width = filtersConfig.inputWidth; + srcParams->height = filtersConfig.inputHeight; + srcParams->sample_aspect_ratio = filtersConfig.inputAspectRatio; + srcParams->time_base = filtersConfig.timeBase; + if (filtersConfig.hwFramesCtx) { + srcParams->hw_frames_ctx = av_buffer_ref(filtersConfig.hwFramesCtx.get()); } sourceContext_ = @@ -98,7 +98,7 @@ FilterGraph::FilterGraph( TORCH_CHECK(bufferSink != nullptr, "Failed to get buffersink filter."); sinkContext_ = createAVFilterContextWithOptions( - filterGraph_.get(), bufferSink, filtersContext.outputFormat); + filterGraph_.get(), bufferSink, filtersConfig.outputFormat); TORCH_CHECK( sinkContext_ != nullptr, "Failed to create and configure buffersink"); @@ -122,7 +122,7 @@ FilterGraph::FilterGraph( AVFilterInOut* inputsTmp = inputs.release(); status = avfilter_graph_parse_ptr( filterGraph_.get(), - filtersContext.filtergraphStr.c_str(), + filtersConfig.filtergraphStr.c_str(), &inputsTmp, &outputsTmp, nullptr); @@ -132,7 +132,7 @@ FilterGraph::FilterGraph( status >= 0, "Failed to parse filter description: ", getFFMPEGErrorStringFromErrorCode(status), - ", provided filters: " + filtersContext.filtergraphStr); + ", provided filters: " + filtersConfig.filtergraphStr); // Check filtergraph validity and configure links and formats. 
status = avfilter_graph_config(filterGraph_.get(), nullptr); @@ -140,7 +140,7 @@ FilterGraph::FilterGraph( status >= 0, "Failed to configure filter graph: ", getFFMPEGErrorStringFromErrorCode(status), - ", provided filters: " + filtersContext.filtergraphStr); + ", provided filters: " + filtersConfig.filtergraphStr); } UniqueAVFrame FilterGraph::convert(const UniqueAVFrame& avFrame) { diff --git a/src/torchcodec/_core/FilterGraph.h b/src/torchcodec/_core/FilterGraph.h index 12475fe47..4e5257d79 100644 --- a/src/torchcodec/_core/FilterGraph.h +++ b/src/torchcodec/_core/FilterGraph.h @@ -11,7 +11,7 @@ namespace facebook::torchcodec { -struct FiltersContext { +struct FiltersConfig { int inputWidth = 0; int inputHeight = 0; AVPixelFormat inputFormat = AV_PIX_FMT_NONE; @@ -23,10 +23,10 @@ struct FiltersContext { AVRational timeBase = {0, 0}; UniqueAVBufferRef hwFramesCtx; - FiltersContext() = default; - FiltersContext(FiltersContext&&) = default; - FiltersContext& operator=(FiltersContext&&) = default; - FiltersContext( + FiltersConfig() = default; + FiltersConfig(FiltersConfig&&) = default; + FiltersConfig& operator=(FiltersConfig&&) = default; + FiltersConfig( int inputWidth, int inputHeight, AVPixelFormat inputFormat, @@ -38,14 +38,14 @@ struct FiltersContext { AVRational timeBase, AVBufferRef* hwFramesCtx = nullptr); - bool operator==(const FiltersContext&) const; - bool operator!=(const FiltersContext&) const; + bool operator==(const FiltersConfig&) const; + bool operator!=(const FiltersConfig&) const; }; class FilterGraph { public: FilterGraph( - const FiltersContext& filtersContext, + const FiltersConfig& filtersConfig, const VideoStreamOptions& videoStreamOptions); UniqueAVFrame convert(const UniqueAVFrame& avFrame); diff --git a/src/torchcodec/_core/SwScale.cpp b/src/torchcodec/_core/SwScale.cpp index 28e882071..718a407db 100644 --- a/src/torchcodec/_core/SwScale.cpp +++ b/src/torchcodec/_core/SwScale.cpp @@ -9,50 +9,25 @@ namespace facebook::torchcodec { -SwScaleContext::SwScaleContext( - int inputWidth, - int inputHeight, - AVPixelFormat inputFormat, - AVColorSpace inputColorspace, - int outputWidth, - int outputHeight) - : inputWidth(inputWidth), - inputHeight(inputHeight), - inputFormat(inputFormat), - inputColorspace(inputColorspace), - outputWidth(outputWidth), - outputHeight(outputHeight) {} - -bool SwScaleContext::operator==(const SwScaleContext& other) const { - return inputWidth == other.inputWidth && inputHeight == other.inputHeight && - inputFormat == other.inputFormat && - inputColorspace == other.inputColorspace && - outputWidth == other.outputWidth && outputHeight == other.outputHeight; -} - -bool SwScaleContext::operator!=(const SwScaleContext& other) const { - return !(*this == other); -} - -SwScale::SwScale(const SwScaleContext& context, int swsFlags) - : context_(context), swsFlags_(swsFlags) { +SwScale::SwScale(const SwsFrameConfig& config, int swsFlags) + : config_(config), swsFlags_(swsFlags) { needsResize_ = - (context_.inputHeight != context_.outputHeight || - context_.inputWidth != context_.outputWidth); + (config_.inputHeight != config_.outputHeight || + config_.inputWidth != config_.outputWidth); // Create color conversion context (input format -> RGB24). - // When resizing is needed, color conversion outputs at input resolution. - // When no resize is needed, color conversion outputs at output resolution. - SwsFrameContext colorConversionFrameContext( - context_.inputWidth, - context_.inputHeight, - context_.inputFormat, - needsResize_ ? 
context_.inputWidth : context_.outputWidth, - needsResize_ ? context_.inputHeight : context_.outputHeight); + // Color conversion always outputs at the input resolution. + // When no resize is needed, input and output resolutions are the same. + SwsFrameConfig colorConversionFrameConfig( + config_.inputWidth, + config_.inputHeight, + config_.inputFormat, + config_.inputColorspace, + config_.inputWidth, + config_.inputHeight); colorConversionSwsContext_ = createSwsContext( - colorConversionFrameContext, - context_.inputColorspace, + colorConversionFrameConfig, // See [Transform and Format Conversion Order] for more on the output // pixel format. /*outputFormat=*/AV_PIX_FMT_RGB24, @@ -63,16 +38,16 @@ SwScale::SwScale(const SwScaleContext& context, int swsFlags) // Create resize context if needed (RGB24 at input resolution -> RGB24 at // output resolution). if (needsResize_) { - SwsFrameContext resizeFrameContext( - context_.inputWidth, - context_.inputHeight, + SwsFrameConfig resizeFrameConfig( + config_.inputWidth, + config_.inputHeight, AV_PIX_FMT_RGB24, - context_.outputWidth, - context_.outputHeight); + AVCOL_SPC_RGB, + config_.outputWidth, + config_.outputHeight); resizeSwsContext_ = createSwsContext( - resizeFrameContext, - AVCOL_SPC_RGB, + resizeFrameConfig, /*outputFormat=*/AV_PIX_FMT_RGB24, /*swsFlags=*/swsFlags_); } @@ -81,12 +56,16 @@ SwScale::SwScale(const SwScaleContext& context, int swsFlags) int SwScale::convert( const UniqueAVFrame& avFrame, torch::Tensor& outputTensor) { + // When resizing is needed, we do sws_scale twice: first convert to RGB24 at + // original resolution, then resize in RGB24 space. This ensures transforms + // happen in the output color space (RGB24) rather than the input color space + // (YUV). + // // When no resize is needed, we do color conversion directly into the output - // tensor. When resize is needed, we first convert to an intermediate tensor - // at the input resolution, then resize into the output tensor. + // tensor. torch::Tensor colorConvertedTensor = needsResize_ ? 
allocateEmptyHWCTensor( - FrameDims(context_.inputHeight, context_.inputWidth), torch::kCPU) + FrameDims(config_.inputHeight, config_.inputWidth), torch::kCPU) : outputTensor; uint8_t* colorConvertedPointers[4] = { @@ -113,7 +92,7 @@ int SwScale::convert( if (needsResize_) { uint8_t* srcPointers[4] = { colorConvertedTensor.data_ptr(), nullptr, nullptr, nullptr}; - int srcLinesizes[4] = {context_.inputWidth * 3, 0, 0, 0}; + int srcLinesizes[4] = {config_.inputWidth * 3, 0, 0, 0}; uint8_t* dstPointers[4] = { outputTensor.data_ptr(), nullptr, nullptr, nullptr}; @@ -125,7 +104,7 @@ int SwScale::convert( srcPointers, srcLinesizes, 0, - context_.inputHeight, + config_.inputHeight, dstPointers, dstLinesizes); } diff --git a/src/torchcodec/_core/SwScale.h b/src/torchcodec/_core/SwScale.h index 6c929309b..60eedc113 100644 --- a/src/torchcodec/_core/SwScale.h +++ b/src/torchcodec/_core/SwScale.h @@ -13,27 +13,6 @@ namespace facebook::torchcodec { struct FrameDims; -struct SwScaleContext { - int inputWidth = 0; - int inputHeight = 0; - AVPixelFormat inputFormat = AV_PIX_FMT_NONE; - AVColorSpace inputColorspace = AVCOL_SPC_UNSPECIFIED; - int outputWidth = 0; - int outputHeight = 0; - - SwScaleContext() = default; - SwScaleContext( - int inputWidth, - int inputHeight, - AVPixelFormat inputFormat, - AVColorSpace inputColorspace, - int outputWidth, - int outputHeight); - - bool operator==(const SwScaleContext&) const; - bool operator!=(const SwScaleContext&) const; -}; - // SwScale uses a double swscale path: // 1. Color conversion (e.g., YUV -> RGB24) at the original frame resolution // 2. Resize in RGB24 space (if resizing is needed) @@ -45,12 +24,16 @@ struct SwScaleContext { // when the context changes, similar to how FilterGraph is managed. class SwScale { public: - SwScale(const SwScaleContext& context, int swsFlags = SWS_BILINEAR); + SwScale(const SwsFrameConfig& config, int swsFlags = SWS_BILINEAR); int convert(const UniqueAVFrame& avFrame, torch::Tensor& outputTensor); + const SwsFrameConfig& getConfig() const { + return config_; + } + private: - SwScaleContext context_; + SwsFrameConfig config_; int swsFlags_; bool needsResize_; From d7e115d090cd417074c88325ac9c6823220db6ef Mon Sep 17 00:00:00 2001 From: Molly Xu Date: Wed, 11 Feb 2026 14:37:18 -0800 Subject: [PATCH 06/10] encoding swscale --- src/torchcodec/_core/CpuDeviceInterface.cpp | 9 +++++---- src/torchcodec/_core/CpuDeviceInterface.h | 4 ++++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index acc899a76..832afa9d5 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -418,8 +418,8 @@ UniqueAVFrame CpuDeviceInterface::convertTensorToAVFrameForEncoding( AVPixelFormat outPixelFormat = codecContext->pix_fmt; // Initialize and cache scaling context if it does not exist - if (!swsContext_) { - swsContext_.reset(sws_getContext( + if (!encodingSwsContext_) { + encodingSwsContext_.reset(sws_getContext( inWidth, inHeight, inPixelFormat, @@ -430,7 +430,8 @@ UniqueAVFrame CpuDeviceInterface::convertTensorToAVFrameForEncoding( nullptr, nullptr, nullptr)); - STD_TORCH_CHECK(swsContext_ != nullptr, "Failed to create scaling context"); + STD_TORCH_CHECK( + encodingSwsContext_ != nullptr, "Failed to create scaling context"); } UniqueAVFrame avFrame(av_frame_alloc()); @@ -469,7 +470,7 @@ UniqueAVFrame CpuDeviceInterface::convertTensorToAVFrameForEncoding( inputFrame->linesize[2] = inWidth; 
status = sws_scale( - swsContext_.get(), + encodingSwsContext_.get(), inputFrame->data, inputFrame->linesize, 0, diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index fe43013c8..8559b5051 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -96,6 +96,10 @@ class CpuDeviceInterface : public DeviceInterface { FiltersConfig prevFiltersConfig_; std::unique_ptr swScale_; + // Cached swscale context for encoding (tensor -> AVFrame pixel format + // conversion). + UniqueSwsContext encodingSwsContext_; + // We pass these filters to FFmpeg's filtergraph API. It is a simple pipeline // of what FFmpeg calls "filters" to apply to decoded frames before returning // them. In the PyTorch ecosystem, we call these "transforms". During From 22911a8bb49ce168550714c2ce5fa3ae1d8c9fa5 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 13 Feb 2026 10:12:57 +0000 Subject: [PATCH 07/10] empty From 540ec6aeb1e21c9b22d946327874a00ac23a51dc Mon Sep 17 00:00:00 2001 From: Molly Xu <64995721+mollyxu@users.noreply.github.com> Date: Fri, 13 Feb 2026 14:28:47 -0500 Subject: [PATCH 08/10] Update src/torchcodec/_core/SwScale.cpp Co-authored-by: Nicolas Hug --- src/torchcodec/_core/SwScale.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/_core/SwScale.cpp b/src/torchcodec/_core/SwScale.cpp index 718a407db..74afddddc 100644 --- a/src/torchcodec/_core/SwScale.cpp +++ b/src/torchcodec/_core/SwScale.cpp @@ -82,7 +82,7 @@ int SwScale::convert( colorConvertedPointers, colorConvertedLinesizes); - TORCH_CHECK( + STD_TORCH_CHECK( colorConvertedHeight == avFrame->height, "Color conversion swscale pass failed: colorConvertedHeight != avFrame->height: ", colorConvertedHeight, From 489a898cfa9e7b72920486b287ebafc20ff0e014 Mon Sep 17 00:00:00 2001 From: Molly Xu <64995721+mollyxu@users.noreply.github.com> Date: Fri, 13 Feb 2026 14:28:57 -0500 Subject: [PATCH 09/10] Update src/torchcodec/_core/SwScale.cpp Co-authored-by: Nicolas Hug --- src/torchcodec/_core/SwScale.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/_core/SwScale.cpp b/src/torchcodec/_core/SwScale.cpp index 74afddddc..507aedf13 100644 --- a/src/torchcodec/_core/SwScale.cpp +++ b/src/torchcodec/_core/SwScale.cpp @@ -65,7 +65,7 @@ int SwScale::convert( // tensor. torch::Tensor colorConvertedTensor = needsResize_ ? 
allocateEmptyHWCTensor( - FrameDims(config_.inputHeight, config_.inputWidth), torch::kCPU) + FrameDims(config_.inputHeight, config_.inputWidth), kStableCPU) : outputTensor; uint8_t* colorConvertedPointers[4] = { From 0580cb1a05a441fb21db812289aa6fe58df79959 Mon Sep 17 00:00:00 2001 From: Molly Xu Date: Fri, 13 Feb 2026 11:37:46 -0800 Subject: [PATCH 10/10] change to swsconfig --- .../_core/BetaCudaDeviceInterface.cpp | 9 ++++---- .../_core/BetaCudaDeviceInterface.h | 2 +- src/torchcodec/_core/CpuDeviceInterface.cpp | 6 +++--- src/torchcodec/_core/FFMPEGCommon.cpp | 21 +++++++++---------- src/torchcodec/_core/FFMPEGCommon.h | 12 +++++------ src/torchcodec/_core/SwScale.cpp | 6 +++--- src/torchcodec/_core/SwScale.h | 6 +++--- 7 files changed, 30 insertions(+), 32 deletions(-) diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp index 9d62e84b2..389114df4 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp @@ -748,7 +748,7 @@ UniqueAVFrame BetaCudaDeviceInterface::transferCpuFrameToGpuNV12( "Failed to allocate NV12 CPU frame buffer: ", getFFMPEGErrorStringFromErrorCode(ret)); - SwsFrameConfig swsFrameConfig( + SwsConfig swsConfig( width, height, static_cast(cpuFrame->format), @@ -756,10 +756,9 @@ UniqueAVFrame BetaCudaDeviceInterface::transferCpuFrameToGpuNV12( width, height); - if (!swsContext_ || prevSwsFrameConfig_ != swsFrameConfig) { - swsContext_ = - createSwsContext(swsFrameConfig, AV_PIX_FMT_NV12, SWS_BILINEAR); - prevSwsFrameConfig_ = swsFrameConfig; + if (!swsContext_ || prevSwsConfig_ != swsConfig) { + swsContext_ = createSwsContext(swsConfig, AV_PIX_FMT_NV12, SWS_BILINEAR); + prevSwsConfig_ = swsConfig; } int convertedHeight = sws_scale( diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.h b/src/torchcodec/_core/BetaCudaDeviceInterface.h index fbcd72969..eb51260f8 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.h +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.h @@ -101,7 +101,7 @@ class BetaCudaDeviceInterface : public DeviceInterface { std::unique_ptr cpuFallback_; bool nvcuvidAvailable_ = false; UniqueSwsContext swsContext_; - SwsFrameConfig prevSwsFrameConfig_; + SwsConfig prevSwsConfig_; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index 75b4f160e..7552b34f3 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -203,7 +203,7 @@ void CpuDeviceInterface::convertVideoAVFrameToFrameOutput( enum AVPixelFormat avFrameFormat = static_cast(avFrame->format); - SwsFrameConfig swsFrameConfig( + SwsConfig swsConfig( avFrame->width, avFrame->height, avFrameFormat, @@ -211,8 +211,8 @@ void CpuDeviceInterface::convertVideoAVFrameToFrameOutput( outputDims.width, outputDims.height); - if (!swScale_ || swScale_->getConfig() != swsFrameConfig) { - swScale_ = std::make_unique(swsFrameConfig, swsFlags_); + if (!swScale_ || swScale_->getConfig() != swsConfig) { + swScale_ = std::make_unique(swsConfig, swsFlags_); } int resultHeight = swScale_->convert(avFrame, outputTensor); diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index 10999028e..e7b3efa35 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -681,7 +681,7 @@ std::optional getRotationFromStream(const AVStream* avStream) { return rotation; } 
-SwsFrameConfig::SwsFrameConfig( +SwsConfig::SwsConfig( int inputWidth, int inputHeight, AVPixelFormat inputFormat, @@ -695,27 +695,27 @@ SwsFrameConfig::SwsFrameConfig( outputWidth(outputWidth), outputHeight(outputHeight) {} -bool SwsFrameConfig::operator==(const SwsFrameConfig& other) const { +bool SwsConfig::operator==(const SwsConfig& other) const { return inputWidth == other.inputWidth && inputHeight == other.inputHeight && inputFormat == other.inputFormat && inputColorspace == other.inputColorspace && outputWidth == other.outputWidth && outputHeight == other.outputHeight; } -bool SwsFrameConfig::operator!=(const SwsFrameConfig& other) const { +bool SwsConfig::operator!=(const SwsConfig& other) const { return !(*this == other); } UniqueSwsContext createSwsContext( - const SwsFrameConfig& swsFrameConfig, + const SwsConfig& swsConfig, AVPixelFormat outputFormat, int swsFlags) { SwsContext* swsContext = sws_getContext( - swsFrameConfig.inputWidth, - swsFrameConfig.inputHeight, - swsFrameConfig.inputFormat, - swsFrameConfig.outputWidth, - swsFrameConfig.outputHeight, + swsConfig.inputWidth, + swsConfig.inputHeight, + swsConfig.inputFormat, + swsConfig.outputWidth, + swsConfig.outputHeight, outputFormat, swsFlags, nullptr, @@ -737,8 +737,7 @@ UniqueSwsContext createSwsContext( &saturation); STD_TORCH_CHECK(ret != -1, "sws_getColorspaceDetails returned -1"); - const int* colorspaceTable = - sws_getCoefficients(swsFrameConfig.inputColorspace); + const int* colorspaceTable = sws_getCoefficients(swsConfig.inputColorspace); ret = sws_setColorspaceDetails( swsContext, colorspaceTable, diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h index c8d046d6b..77738c3c6 100644 --- a/src/torchcodec/_core/FFMPEGCommon.h +++ b/src/torchcodec/_core/FFMPEGCommon.h @@ -291,7 +291,7 @@ AVFilterContext* createAVFilterContextWithOptions( const AVFilter* buffer, const enum AVPixelFormat outputFormat); -struct SwsFrameConfig { +struct SwsConfig { int inputWidth = 0; int inputHeight = 0; AVPixelFormat inputFormat = AV_PIX_FMT_NONE; @@ -299,8 +299,8 @@ struct SwsFrameConfig { int outputWidth = 0; int outputHeight = 0; - SwsFrameConfig() = default; - SwsFrameConfig( + SwsConfig() = default; + SwsConfig( int inputWidth, int inputHeight, AVPixelFormat inputFormat, @@ -308,13 +308,13 @@ struct SwsFrameConfig { int outputWidth, int outputHeight); - bool operator==(const SwsFrameConfig& other) const; - bool operator!=(const SwsFrameConfig& other) const; + bool operator==(const SwsConfig& other) const; + bool operator!=(const SwsConfig& other) const; }; // Utility functions for swscale context management UniqueSwsContext createSwsContext( - const SwsFrameConfig& swsFrameConfig, + const SwsConfig& swsConfig, AVPixelFormat outputFormat, int swsFlags); diff --git a/src/torchcodec/_core/SwScale.cpp b/src/torchcodec/_core/SwScale.cpp index 507aedf13..9b85680c2 100644 --- a/src/torchcodec/_core/SwScale.cpp +++ b/src/torchcodec/_core/SwScale.cpp @@ -9,7 +9,7 @@ namespace facebook::torchcodec { -SwScale::SwScale(const SwsFrameConfig& config, int swsFlags) +SwScale::SwScale(const SwsConfig& config, int swsFlags) : config_(config), swsFlags_(swsFlags) { needsResize_ = (config_.inputHeight != config_.outputHeight || @@ -18,7 +18,7 @@ SwScale::SwScale(const SwsFrameConfig& config, int swsFlags) // Create color conversion context (input format -> RGB24). // Color conversion always outputs at the input resolution. // When no resize is needed, input and output resolutions are the same. 
- SwsFrameConfig colorConversionFrameConfig( + SwsConfig colorConversionFrameConfig( config_.inputWidth, config_.inputHeight, config_.inputFormat, @@ -38,7 +38,7 @@ SwScale::SwScale(const SwsFrameConfig& config, int swsFlags) // Create resize context if needed (RGB24 at input resolution -> RGB24 at // output resolution). if (needsResize_) { - SwsFrameConfig resizeFrameConfig( + SwsConfig resizeFrameConfig( config_.inputWidth, config_.inputHeight, AV_PIX_FMT_RGB24, diff --git a/src/torchcodec/_core/SwScale.h b/src/torchcodec/_core/SwScale.h index 60eedc113..24358f787 100644 --- a/src/torchcodec/_core/SwScale.h +++ b/src/torchcodec/_core/SwScale.h @@ -24,16 +24,16 @@ struct FrameDims; // when the context changes, similar to how FilterGraph is managed. class SwScale { public: - SwScale(const SwsFrameConfig& config, int swsFlags = SWS_BILINEAR); + SwScale(const SwsConfig& config, int swsFlags = SWS_BILINEAR); int convert(const UniqueAVFrame& avFrame, torch::Tensor& outputTensor); - const SwsFrameConfig& getConfig() const { + const SwsConfig& getConfig() const { return config_; } private: - SwsFrameConfig config_; + SwsConfig config_; int swsFlags_; bool needsResize_;
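
Editor's note, appended after the series for clarity and not part of any patch: the sketch below shows the caller-side shape of the final API, after PATCH 10. It is illustrative only. The helper function and the swScale_ member names here are hypothetical stand-ins for the real call site in CpuDeviceInterface::convertVideoAVFrameToFrameOutput, while SwsConfig, SwScale, getConfig(), convert(), allocateEmptyHWCTensor(), and STD_TORCH_CHECK are taken from the patches above.

    // Cached across frames, e.g. as a device-interface member (hypothetical):
    std::unique_ptr<SwScale> swScale_;

    // Hypothetical helper mirroring the patched call site.
    torch::Tensor convertFrameWithSwScale(
        const UniqueAVFrame& avFrame,
        const FrameDims& outputDims,
        int swsFlags) {
      torch::Tensor outputTensor =
          allocateEmptyHWCTensor(outputDims, torch::kCPU);

      // Describe the whole conversion up front: input geometry, pixel format,
      // and colorspace, plus the desired output geometry (SwsConfig, PATCH 10).
      SwsConfig swsConfig(
          avFrame->width,
          avFrame->height,
          static_cast<AVPixelFormat>(avFrame->format),
          avFrame->colorspace,
          outputDims.width,
          outputDims.height);

      // The caching policy lives in the caller: recreate the SwScale object
      // only when the frame/output configuration changes (PATCH 05).
      if (!swScale_ || swScale_->getConfig() != swsConfig) {
        swScale_ = std::make_unique<SwScale>(swsConfig, swsFlags);
      }

      // convert() runs one sws_scale pass, or two when a resize is needed:
      // color conversion to RGB24 at input resolution first, then a resize
      // in RGB24 space into the output tensor.
      int resultHeight = swScale_->convert(avFrame, outputTensor);
      STD_TORCH_CHECK(
          resultHeight == outputDims.height,
          "SwScale conversion produced an unexpected height");
      return outputTensor;
    }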