diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp
index 786e43301..e58c9963e 100644
--- a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp
+++ b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp
@@ -748,17 +748,17 @@ UniqueAVFrame BetaCudaDeviceInterface::transferCpuFrameToGpuNV12(
       "Failed to allocate NV12 CPU frame buffer: ",
       getFFMPEGErrorStringFromErrorCode(ret));
 
-  SwsFrameContext swsFrameContext(
+  SwsConfig swsConfig(
       width,
       height,
       static_cast<AVPixelFormat>(cpuFrame->format),
+      cpuFrame->colorspace,
       width,
       height);
 
-  if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
-    swsContext_ = createSwsContext(
-        swsFrameContext, cpuFrame->colorspace, AV_PIX_FMT_NV12, SWS_BILINEAR);
-    prevSwsFrameContext_ = swsFrameContext;
+  if (!swsContext_ || prevSwsConfig_ != swsConfig) {
+    swsContext_ = createSwsContext(swsConfig, AV_PIX_FMT_NV12, SWS_BILINEAR);
+    prevSwsConfig_ = swsConfig;
   }
 
   int convertedHeight = sws_scale(
diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.h b/src/torchcodec/_core/BetaCudaDeviceInterface.h
index 7dda6e235..fbe78bcc5 100644
--- a/src/torchcodec/_core/BetaCudaDeviceInterface.h
+++ b/src/torchcodec/_core/BetaCudaDeviceInterface.h
@@ -106,7 +106,8 @@ class BetaCudaDeviceInterface : public DeviceInterface {
   std::unique_ptr<CpuDeviceInterface> cpuFallback_;
   bool nvcuvidAvailable_ = false;
   UniqueSwsContext swsContext_;
-  SwsFrameContext prevSwsFrameContext_;
+
+  SwsConfig prevSwsConfig_;
 
   Rotation rotation_ = Rotation::NONE;
 };
diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt
index 6b0aa0fe2..1a1861ce9 100644
--- a/src/torchcodec/_core/CMakeLists.txt
+++ b/src/torchcodec/_core/CMakeLists.txt
@@ -136,6 +136,7 @@ function(make_torchcodec_libraries
       ValidationUtils.cpp
       Transform.cpp
       Metadata.cpp
+      SwScale.cpp
   )
 
   if(ENABLE_CUDA)
diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp
index e4e7c8360..7552b34f3 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.cpp
+++ b/src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -200,8 +200,22 @@ void CpuDeviceInterface::convertVideoAVFrameToFrameOutput(
     outputTensor = preAllocatedOutputTensor.value_or(
         allocateEmptyHWCTensor(outputDims, kStableCPU));
 
-    int resultHeight =
-        convertAVFrameToTensorUsingSwScale(avFrame, outputTensor, outputDims);
+    enum AVPixelFormat avFrameFormat =
+        static_cast<AVPixelFormat>(avFrame->format);
+
+    SwsConfig swsConfig(
+        avFrame->width,
+        avFrame->height,
+        avFrameFormat,
+        avFrame->colorspace,
+        outputDims.width,
+        outputDims.height);
+
+    if (!swScale_ || swScale_->getConfig() != swsConfig) {
+      swScale_ = std::make_unique<SwScale>(swsConfig, swsFlags_);
+    }
+
+    int resultHeight = swScale_->convert(avFrame, outputTensor);
 
     // If this check failed, it would mean that the frame wasn't reshaped to
     // the expected height.
@@ -246,129 +260,13 @@ void CpuDeviceInterface::convertVideoAVFrameToFrameOutput(
   }
 }
 
-int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
-    const UniqueAVFrame& avFrame,
-    torch::Tensor& outputTensor,
-    const FrameDims& outputDims) {
-  enum AVPixelFormat frameFormat =
-      static_cast<AVPixelFormat>(avFrame->format);
-
-  bool needsResize =
-      (avFrame->height != outputDims.height ||
-       avFrame->width != outputDims.width);
-
-  // We need to compare the current frame context with our previous frame
-  // context. If they are different, then we need to re-create our colorspace
-  // conversion objects. We create our colorspace conversion objects late so
-  // that we don't have to depend on the unreliable metadata in the header.
-  // And we sometimes re-create them because it's possible for frame
-  // resolution to change mid-stream. Finally, we want to reuse the colorspace
-  // conversion objects as much as possible for performance reasons.
-  SwsFrameContext swsFrameContext(
-      avFrame->width,
-      avFrame->height,
-      frameFormat,
-      needsResize ? avFrame->width : outputDims.width,
-      needsResize ? avFrame->height : outputDims.height);
-
-  if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
-    swsContext_ = createSwsContext(
-        swsFrameContext,
-        avFrame->colorspace,
-
-        // See [Transform and Format Conversion Order] for more on the output
-        // pixel format.
-        /*outputFormat=*/AV_PIX_FMT_RGB24,
-
-        // No flags for color conversion. When resizing is needed, we use a
-        // separate swscale context with the appropriate resize flags.
-        /*swsFlags=*/0);
-    prevSwsFrameContext_ = swsFrameContext;
-  }
-
-  // When resizing is needed, we do sws_scale twice: first convert to RGB24 at
-  // original resolution, then resize in RGB24 space. This ensures transforms
-  // happen in the output color space (RGB24) rather than the input color space
-  // (YUV).
-  //
-  // When no resize is needed, we do color conversion directly into the output
-  // tensor.
-
-  torch::Tensor colorConvertedTensor = needsResize
-      ? allocateEmptyHWCTensor(
-            FrameDims(avFrame->height, avFrame->width), kStableCPU)
-      : outputTensor;
-
-  uint8_t* colorConvertedPointers[4] = {
-      colorConvertedTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
-  int colorConvertedWidth = static_cast<int>(colorConvertedTensor.sizes()[1]);
-  int colorConvertedLinesizes[4] = {colorConvertedWidth * 3, 0, 0, 0};
-
-  int colorConvertedHeight = sws_scale(
-      swsContext_.get(),
-      avFrame->data,
-      avFrame->linesize,
-      0,
-      avFrame->height,
-      colorConvertedPointers,
-      colorConvertedLinesizes);
-
-  STD_TORCH_CHECK(
-      colorConvertedHeight == avFrame->height,
-      "Color conversion swscale pass failed: colorConvertedHeight != avFrame->height: ",
-      colorConvertedHeight,
-      " != ",
-      avFrame->height);
-
-  if (needsResize) {
-    // Use cached swscale context for resizing, similar to the color conversion
-    // context caching above.
-    SwsFrameContext resizeSwsFrameContext(
-        avFrame->width,
-        avFrame->height,
-        AV_PIX_FMT_RGB24,
-        outputDims.width,
-        outputDims.height);
-
-    if (!resizeSwsContext_ ||
-        prevResizeSwsFrameContext_ != resizeSwsFrameContext) {
-      resizeSwsContext_ = createSwsContext(
-          resizeSwsFrameContext,
-          AVCOL_SPC_RGB,
-          /*outputFormat=*/AV_PIX_FMT_RGB24,
-          /*swsFlags=*/swsFlags_);
-      prevResizeSwsFrameContext_ = resizeSwsFrameContext;
-    }
-
-    uint8_t* srcPointers[4] = {
-        colorConvertedTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
-    int srcLinesizes[4] = {avFrame->width * 3, 0, 0, 0};
-
-    uint8_t* dstPointers[4] = {
-        outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
-    int expectedOutputWidth = static_cast<int>(outputTensor.sizes()[1]);
-    int dstLinesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
-
-    colorConvertedHeight = sws_scale(
-        resizeSwsContext_.get(),
-        srcPointers,
-        srcLinesizes,
-        0,
-        avFrame->height,
-        dstPointers,
-        dstLinesizes);
-  }
-
-  return colorConvertedHeight;
-}
-
 torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
     const UniqueAVFrame& avFrame,
     const FrameDims& outputDims) {
   enum AVPixelFormat avFrameFormat =
       static_cast<AVPixelFormat>(avFrame->format);
 
-  FiltersContext filtersContext(
+  FiltersConfig filtersConfig(
       avFrame->width,
       avFrame->height,
       avFrameFormat,
@@ -379,10 +277,10 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
       filters_,
       timeBase_);
 
-  if (!filterGraph_ || prevFiltersContext_ != filtersContext) {
+  if (!filterGraph_ || prevFiltersConfig_ != filtersConfig) {
     filterGraph_ =
-        std::make_unique<FilterGraph>(filtersContext, videoStreamOptions_);
-    prevFiltersContext_ = std::move(filtersContext);
+        std::make_unique<FilterGraph>(filtersConfig, videoStreamOptions_);
+    prevFiltersConfig_ = std::move(filtersConfig);
   }
   return rgbAVFrameToTensor(filterGraph_->convert(avFrame));
 }
@@ -520,8 +418,8 @@ UniqueAVFrame CpuDeviceInterface::convertTensorToAVFrameForEncoding(
   AVPixelFormat outPixelFormat = codecContext->pix_fmt;
 
   // Initialize and cache scaling context if it does not exist
-  if (!swsContext_) {
-    swsContext_.reset(sws_getContext(
+  if (!encodingSwsContext_) {
+    encodingSwsContext_.reset(sws_getContext(
         inWidth,
         inHeight,
         inPixelFormat,
@@ -532,7 +430,8 @@ UniqueAVFrame CpuDeviceInterface::convertTensorToAVFrameForEncoding(
         nullptr,
         nullptr,
         nullptr));
-    STD_TORCH_CHECK(swsContext_ != nullptr, "Failed to create scaling context");
+    STD_TORCH_CHECK(
+        encodingSwsContext_ != nullptr, "Failed to create scaling context");
   }
 
   UniqueAVFrame avFrame(av_frame_alloc());
@@ -571,7 +470,7 @@ UniqueAVFrame CpuDeviceInterface::convertTensorToAVFrameForEncoding(
   inputFrame->linesize[2] = inWidth;
 
   status = sws_scale(
-      swsContext_.get(),
+      encodingSwsContext_.get(),
       inputFrame->data,
       inputFrame->linesize,
       0,
diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h
index fdbf8ba8a..2ab86d037 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.h
+++ b/src/torchcodec/_core/CpuDeviceInterface.h
@@ -9,6 +9,7 @@
 #include "DeviceInterface.h"
 #include "FFMPEGCommon.h"
 #include "FilterGraph.h"
+#include "SwScale.h"
 
 namespace facebook::torchcodec {
 
@@ -61,11 +62,6 @@ class CpuDeviceInterface : public DeviceInterface {
       FrameOutput& frameOutput,
       std::optional<torch::Tensor> preAllocatedOutputTensor);
 
-  int convertAVFrameToTensorUsingSwScale(
-      const UniqueAVFrame& avFrame,
-      torch::Tensor& outputTensor,
-      const FrameDims& outputDims);
-
   torch::Tensor convertAVFrameToTensorUsingFilterGraph(
       const UniqueAVFrame& avFrame,
       const FrameDims& outputDims);
@@ -85,32 +81,24 @@ class CpuDeviceInterface : public DeviceInterface {
   // resolutions.
   std::optional<FrameDims> resizedOutputDims_;
 
-  // Color-conversion objects. Only one of filterGraph_ and swsContext_ should
+  // Color-conversion objects. Only one of filterGraph_ and swScale_ should
   // be non-null. Which one we use is determined dynamically in
   // getColorConversionLibrary() each time we decode a frame.
   //
-  // Creating both filterGraph_ and swsContext_ is relatively expensive, so we
-  // reuse them across frames. However, it is possbile that subsequent frames
+  // Creating both filterGraph_ and swScale_ is relatively expensive, so we
+  // reuse them across frames. However, it is possible that subsequent frames
   // are different enough (change in dimensions) that we can't reuse the color
-  // conversion object. We store the relevant frame context from the frame used
+  // conversion object. We store the relevant frame config from the frame used
   // to create the object last time. We always compare the current frame's info
   // against the previous one to determine if we need to recreate the color
   // conversion object.
-  //
-  // TODO: The names of these fields is confusing, as the actual color
-  //       conversion object for Sws has "context" in the name, and we use
-  //       "context" for the structs we store to know if we need to recreate a
-  //       color conversion object. We should clean that up.
   std::unique_ptr<FilterGraph> filterGraph_;
-  FiltersContext prevFiltersContext_;
-  UniqueSwsContext swsContext_;
-  SwsFrameContext prevSwsFrameContext_;
-
-  // Cached swscale context for resizing in RGB24 space (used in double swscale
-  // path). Like the color conversion context above, we cache this to avoid
-  // recreating it for every frame.
-  UniqueSwsContext resizeSwsContext_;
-  SwsFrameContext prevResizeSwsFrameContext_;
+  FiltersConfig prevFiltersConfig_;
+  std::unique_ptr<SwScale> swScale_;
+
+  // Cached swscale context for encoding (tensor -> AVFrame pixel format
+  // conversion).
+  UniqueSwsContext encodingSwsContext_;
 
   // We pass these filters to FFmpeg's filtergraph API. It is a simple pipeline
   // of what FFmpeg calls "filters" to apply to decoded frames before returning
diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp
index 031f43009..543a627a0 100644
--- a/src/torchcodec/_core/CudaDeviceInterface.cpp
+++ b/src/torchcodec/_core/CudaDeviceInterface.cpp
@@ -197,7 +197,7 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24(
   enum AVPixelFormat frameFormat =
       static_cast<AVPixelFormat>(avFrame->format);
 
-  auto newContext = std::make_unique<FiltersContext>(
+  auto newConfig = std::make_unique<FiltersConfig>(
       avFrame->width,
       avFrame->height,
       frameFormat,
@@ -209,22 +209,22 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24(
       timeBase_,
       av_buffer_ref(avFrame->hw_frames_ctx));
 
-  if (!nv12Conversion_ || *nv12ConversionContext_ != *newContext) {
+  if (!nv12Conversion_ || *nv12ConversionConfig_ != *newConfig) {
     nv12Conversion_ =
-        std::make_unique<FilterGraph>(*newContext, videoStreamOptions_);
-    nv12ConversionContext_ = std::move(newContext);
+        std::make_unique<FilterGraph>(*newConfig, videoStreamOptions_);
+    nv12ConversionConfig_ = std::move(newConfig);
   }
   auto filteredAVFrame = nv12Conversion_->convert(avFrame);
 
   // If this check fails it means the frame wasn't
   // reshaped to its expected dimensions by filtergraph.
   STD_TORCH_CHECK(
-      (filteredAVFrame->width == nv12ConversionContext_->outputWidth) &&
-      (filteredAVFrame->height == nv12ConversionContext_->outputHeight),
+      (filteredAVFrame->width == nv12ConversionConfig_->outputWidth) &&
+      (filteredAVFrame->height == nv12ConversionConfig_->outputHeight),
       "Expected frame from filter graph of ",
-      nv12ConversionContext_->outputWidth,
+      nv12ConversionConfig_->outputWidth,
       "x",
-      nv12ConversionContext_->outputHeight,
+      nv12ConversionConfig_->outputHeight,
       ", got ",
       filteredAVFrame->width,
       "x",
diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h
index daa371b01..dfd5b2591 100644
--- a/src/torchcodec/_core/CudaDeviceInterface.h
+++ b/src/torchcodec/_core/CudaDeviceInterface.h
@@ -69,7 +69,7 @@ class CudaDeviceInterface : public DeviceInterface {
 
   // This filtergraph instance is only used for NV12 format conversion in
   // maybeConvertAVFrameToNV12().
-  std::unique_ptr<FiltersContext> nv12ConversionContext_;
+  std::unique_ptr<FiltersConfig> nv12ConversionConfig_;
   std::unique_ptr<FilterGraph> nv12Conversion_;
 
   bool usingCPUFallback_ = false;
diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp
index bde9c463e..e7b3efa35 100644
--- a/src/torchcodec/_core/FFMPEGCommon.cpp
+++ b/src/torchcodec/_core/FFMPEGCommon.cpp
@@ -681,39 +681,41 @@ std::optional<Rotation> getRotationFromStream(const AVStream* avStream) {
   return rotation;
 }
 
-SwsFrameContext::SwsFrameContext(
+SwsConfig::SwsConfig(
     int inputWidth,
     int inputHeight,
     AVPixelFormat inputFormat,
+    AVColorSpace inputColorspace,
     int outputWidth,
     int outputHeight)
     : inputWidth(inputWidth),
       inputHeight(inputHeight),
       inputFormat(inputFormat),
+      inputColorspace(inputColorspace),
       outputWidth(outputWidth),
       outputHeight(outputHeight) {}
 
-bool SwsFrameContext::operator==(const SwsFrameContext& other) const {
+bool SwsConfig::operator==(const SwsConfig& other) const {
   return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
-      inputFormat == other.inputFormat && outputWidth == other.outputWidth &&
-      outputHeight == other.outputHeight;
+      inputFormat == other.inputFormat &&
+      inputColorspace == other.inputColorspace &&
+      outputWidth == other.outputWidth && outputHeight == other.outputHeight;
 }
 
-bool SwsFrameContext::operator!=(const SwsFrameContext& other) const {
+bool SwsConfig::operator!=(const SwsConfig& other) const {
   return !(*this == other);
 }
 
 UniqueSwsContext createSwsContext(
-    const SwsFrameContext& swsFrameContext,
-    AVColorSpace colorspace,
+    const SwsConfig& swsConfig,
     AVPixelFormat outputFormat,
     int swsFlags) {
   SwsContext* swsContext = sws_getContext(
-      swsFrameContext.inputWidth,
-      swsFrameContext.inputHeight,
-      swsFrameContext.inputFormat,
-      swsFrameContext.outputWidth,
-      swsFrameContext.outputHeight,
+      swsConfig.inputWidth,
+      swsConfig.inputHeight,
+      swsConfig.inputFormat,
+      swsConfig.outputWidth,
+      swsConfig.outputHeight,
       outputFormat,
       swsFlags,
       nullptr,
@@ -735,7 +737,7 @@ UniqueSwsContext createSwsContext(
       &saturation);
   STD_TORCH_CHECK(ret != -1, "sws_getColorspaceDetails returned -1");
 
-  const int* colorspaceTable = sws_getCoefficients(colorspace);
+  const int* colorspaceTable = sws_getCoefficients(swsConfig.inputColorspace);
   ret = sws_setColorspaceDetails(
       swsContext,
       colorspaceTable,
diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h
index 33113aba7..77738c3c6 100644
--- a/src/torchcodec/_core/FFMPEGCommon.h
+++ b/src/torchcodec/_core/FFMPEGCommon.h
@@ -291,29 +291,30 @@ AVFilterContext* createAVFilterContextWithOptions(
     const AVFilter* buffer,
     const enum AVPixelFormat outputFormat);
 
-struct SwsFrameContext {
+struct SwsConfig {
   int inputWidth = 0;
   int inputHeight = 0;
   AVPixelFormat inputFormat = AV_PIX_FMT_NONE;
+  AVColorSpace inputColorspace = AVCOL_SPC_UNSPECIFIED;
   int outputWidth = 0;
   int outputHeight = 0;
 
-  SwsFrameContext() = default;
-  SwsFrameContext(
+  SwsConfig() = default;
+  SwsConfig(
       int inputWidth,
       int inputHeight,
       AVPixelFormat inputFormat,
+      AVColorSpace inputColorspace,
      int outputWidth,
       int outputHeight);
 
-  bool operator==(const SwsFrameContext& other) const;
-  bool operator!=(const SwsFrameContext& other) const;
+  bool operator==(const SwsConfig& other) const;
+  bool operator!=(const SwsConfig& other) const;
 };
 
 // Utility functions for swscale context management
 UniqueSwsContext createSwsContext(
-    const SwsFrameContext& swsFrameContext,
-    AVColorSpace colorspace,
+    const SwsConfig& swsConfig,
     AVPixelFormat outputFormat,
     int swsFlags);
diff --git a/src/torchcodec/_core/FilterGraph.cpp b/src/torchcodec/_core/FilterGraph.cpp
index 94a6eaabe..904886908 100644
--- a/src/torchcodec/_core/FilterGraph.cpp
+++ b/src/torchcodec/_core/FilterGraph.cpp
@@ -15,7 +15,7 @@ extern "C" {
 
 namespace facebook::torchcodec {
 
-FiltersContext::FiltersContext(
+FiltersConfig::FiltersConfig(
     int inputWidth,
     int inputHeight,
     AVPixelFormat inputFormat,
@@ -41,7 +41,7 @@ bool operator==(const AVRational& lhs, const AVRational& rhs) {
   return lhs.num == rhs.num && lhs.den == rhs.den;
 }
 
-bool FiltersContext::operator==(const FiltersContext& other) const {
+bool FiltersConfig::operator==(const FiltersConfig& other) const {
   return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
       inputFormat == other.inputFormat && outputWidth == other.outputWidth &&
       outputHeight == other.outputHeight &&
@@ -50,12 +50,12 @@ bool FiltersContext::operator==(const FiltersContext& other) const {
       hwFramesCtx.get() == other.hwFramesCtx.get();
 }
 
-bool FiltersContext::operator!=(const FiltersContext& other) const {
+bool FiltersConfig::operator!=(const FiltersConfig& other) const {
   return !(*this == other);
 }
 
 FilterGraph::FilterGraph(
-    const FiltersContext& filtersContext,
+    const FiltersConfig& filtersConfig,
     const VideoStreamOptions& videoStreamOptions) {
   filterGraph_.reset(avfilter_graph_alloc());
   STD_TORCH_CHECK(
@@ -70,13 +70,13 @@ FilterGraph::FilterGraph(
   UniqueAVBufferSrcParameters srcParams(av_buffersrc_parameters_alloc());
   STD_TORCH_CHECK(srcParams, "Failed to allocate buffersrc params");
 
-  srcParams->format = filtersContext.inputFormat;
-  srcParams->width = filtersContext.inputWidth;
-  srcParams->height = filtersContext.inputHeight;
-  srcParams->sample_aspect_ratio = filtersContext.inputAspectRatio;
-  srcParams->time_base = filtersContext.timeBase;
-  if (filtersContext.hwFramesCtx) {
-    srcParams->hw_frames_ctx = av_buffer_ref(filtersContext.hwFramesCtx.get());
+  srcParams->format = filtersConfig.inputFormat;
+  srcParams->width = filtersConfig.inputWidth;
+  srcParams->height = filtersConfig.inputHeight;
+  srcParams->sample_aspect_ratio = filtersConfig.inputAspectRatio;
+  srcParams->time_base = filtersConfig.timeBase;
+  if (filtersConfig.hwFramesCtx) {
+    srcParams->hw_frames_ctx = av_buffer_ref(filtersConfig.hwFramesCtx.get());
   }
 
   sourceContext_ =
@@ -100,7 +100,7 @@ FilterGraph::FilterGraph(
   STD_TORCH_CHECK(bufferSink != nullptr, "Failed to get buffersink filter.");
 
   sinkContext_ = createAVFilterContextWithOptions(
-      filterGraph_.get(), bufferSink, filtersContext.outputFormat);
+      filterGraph_.get(), bufferSink, filtersConfig.outputFormat);
 
   STD_TORCH_CHECK(
       sinkContext_ != nullptr, "Failed to create and configure buffersink");
@@ -124,7 +124,7 @@ FilterGraph::FilterGraph(
   AVFilterInOut* inputsTmp = inputs.release();
   status = avfilter_graph_parse_ptr(
       filterGraph_.get(),
-      filtersContext.filtergraphStr.c_str(),
+      filtersConfig.filtergraphStr.c_str(),
       &inputsTmp,
       &outputsTmp,
       nullptr);
@@ -134,7 +134,7 @@ FilterGraph::FilterGraph(
       status >= 0,
       "Failed to parse filter description: ",
       getFFMPEGErrorStringFromErrorCode(status),
-      ", provided filters: " + filtersContext.filtergraphStr);
+      ", provided filters: " + filtersConfig.filtergraphStr);
 
   // Check filtergraph validity and configure links and formats.
   status = avfilter_graph_config(filterGraph_.get(), nullptr);
@@ -142,7 +142,7 @@ FilterGraph::FilterGraph(
       status >= 0,
       "Failed to configure filter graph: ",
       getFFMPEGErrorStringFromErrorCode(status),
-      ", provided filters: " + filtersContext.filtergraphStr);
+      ", provided filters: " + filtersConfig.filtergraphStr);
 }
 
 UniqueAVFrame FilterGraph::convert(const UniqueAVFrame& avFrame) {
diff --git a/src/torchcodec/_core/FilterGraph.h b/src/torchcodec/_core/FilterGraph.h
index 12475fe47..4e5257d79 100644
--- a/src/torchcodec/_core/FilterGraph.h
+++ b/src/torchcodec/_core/FilterGraph.h
@@ -11,7 +11,7 @@
 
 namespace facebook::torchcodec {
 
-struct FiltersContext {
+struct FiltersConfig {
   int inputWidth = 0;
   int inputHeight = 0;
   AVPixelFormat inputFormat = AV_PIX_FMT_NONE;
@@ -23,10 +23,10 @@ struct FiltersContext {
   AVRational timeBase = {0, 0};
   UniqueAVBufferRef hwFramesCtx;
 
-  FiltersContext() = default;
-  FiltersContext(FiltersContext&&) = default;
-  FiltersContext& operator=(FiltersContext&&) = default;
-  FiltersContext(
+  FiltersConfig() = default;
+  FiltersConfig(FiltersConfig&&) = default;
+  FiltersConfig& operator=(FiltersConfig&&) = default;
+  FiltersConfig(
      int inputWidth,
      int inputHeight,
      AVPixelFormat inputFormat,
@@ -38,14 +38,14 @@
       AVRational timeBase,
       AVBufferRef* hwFramesCtx = nullptr);
 
-  bool operator==(const FiltersContext&) const;
-  bool operator!=(const FiltersContext&) const;
+  bool operator==(const FiltersConfig&) const;
+  bool operator!=(const FiltersConfig&) const;
 };
 
 class FilterGraph {
  public:
  FilterGraph(
-      const FiltersContext& filtersContext,
+      const FiltersConfig& filtersConfig,
       const VideoStreamOptions& videoStreamOptions);
 
   UniqueAVFrame convert(const UniqueAVFrame& avFrame);
diff --git a/src/torchcodec/_core/SwScale.cpp b/src/torchcodec/_core/SwScale.cpp
new file mode 100644
index 000000000..9b85680c2
--- /dev/null
+++ b/src/torchcodec/_core/SwScale.cpp
@@ -0,0 +1,115 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include "SwScale.h"
+#include "Frame.h"
+
+namespace facebook::torchcodec {
+
+SwScale::SwScale(const SwsConfig& config, int swsFlags)
+    : config_(config), swsFlags_(swsFlags) {
+  needsResize_ =
+      (config_.inputHeight != config_.outputHeight ||
+       config_.inputWidth != config_.outputWidth);
+
+  // Create color conversion context (input format -> RGB24).
+  // Color conversion always outputs at the input resolution.
+  // When no resize is needed, input and output resolutions are the same.
+  SwsConfig colorConversionFrameConfig(
+      config_.inputWidth,
+      config_.inputHeight,
+      config_.inputFormat,
+      config_.inputColorspace,
+      config_.inputWidth,
+      config_.inputHeight);
+
+  colorConversionSwsContext_ = createSwsContext(
+      colorConversionFrameConfig,
+      // See [Transform and Format Conversion Order] for more on the output
+      // pixel format.
+      /*outputFormat=*/AV_PIX_FMT_RGB24,
+      // No flags for color conversion. When resizing is needed, we use a
+      // separate swscale context with the appropriate resize flags.
+      /*swsFlags=*/0);
+
+  // Create resize context if needed (RGB24 at input resolution -> RGB24 at
+  // output resolution).
+  if (needsResize_) {
+    SwsConfig resizeFrameConfig(
+        config_.inputWidth,
+        config_.inputHeight,
+        AV_PIX_FMT_RGB24,
+        AVCOL_SPC_RGB,
+        config_.outputWidth,
+        config_.outputHeight);
+
+    resizeSwsContext_ = createSwsContext(
+        resizeFrameConfig,
+        /*outputFormat=*/AV_PIX_FMT_RGB24,
+        /*swsFlags=*/swsFlags_);
+  }
+}
+
+int SwScale::convert(
+    const UniqueAVFrame& avFrame,
+    torch::Tensor& outputTensor) {
+  // When resizing is needed, we do sws_scale twice: first convert to RGB24 at
+  // original resolution, then resize in RGB24 space. This ensures transforms
+  // happen in the output color space (RGB24) rather than the input color space
+  // (YUV).
+  //
+  // When no resize is needed, we do color conversion directly into the output
+  // tensor.
+  torch::Tensor colorConvertedTensor = needsResize_
+      ? allocateEmptyHWCTensor(
+            FrameDims(config_.inputHeight, config_.inputWidth), kStableCPU)
+      : outputTensor;
+
+  uint8_t* colorConvertedPointers[4] = {
+      colorConvertedTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
+  int colorConvertedWidth = static_cast<int>(colorConvertedTensor.sizes()[1]);
+  int colorConvertedLinesizes[4] = {colorConvertedWidth * 3, 0, 0, 0};
+
+  int colorConvertedHeight = sws_scale(
+      colorConversionSwsContext_.get(),
+      avFrame->data,
+      avFrame->linesize,
+      0,
+      avFrame->height,
+      colorConvertedPointers,
+      colorConvertedLinesizes);
+
+  STD_TORCH_CHECK(
+      colorConvertedHeight == avFrame->height,
+      "Color conversion swscale pass failed: colorConvertedHeight != avFrame->height: ",
+      colorConvertedHeight,
+      " != ",
+      avFrame->height);
+
+  if (needsResize_) {
+    uint8_t* srcPointers[4] = {
+        colorConvertedTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
+    int srcLinesizes[4] = {config_.inputWidth * 3, 0, 0, 0};
+
+    uint8_t* dstPointers[4] = {
+        outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
+    int expectedOutputWidth = static_cast<int>(outputTensor.sizes()[1]);
+    int dstLinesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
+
+    colorConvertedHeight = sws_scale(
+        resizeSwsContext_.get(),
+        srcPointers,
+        srcLinesizes,
+        0,
+        config_.inputHeight,
+        dstPointers,
+        dstLinesizes);
+  }
+
+  return colorConvertedHeight;
+}
+
+} // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/SwScale.h b/src/torchcodec/_core/SwScale.h
new file mode 100644
index 000000000..24358f787
--- /dev/null
+++ b/src/torchcodec/_core/SwScale.h
@@ -0,0 +1,48 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <torch/types.h>
+#include "FFMPEGCommon.h"
+
+namespace facebook::torchcodec {
+
+struct FrameDims;
+
+// SwScale uses a double swscale path:
+// 1. Color conversion (e.g., YUV -> RGB24) at the original frame resolution
+// 2. Resize in RGB24 space (if resizing is needed)
+//
+// This approach ensures that transforms happen in the output color space
+// (RGB24) rather than the input color space (YUV).
+//
+// The caller is responsible for caching SwScale instances and recreating them
+// when the config changes, similar to how FilterGraph is managed.
+class SwScale {
+ public:
+  SwScale(const SwsConfig& config, int swsFlags = SWS_BILINEAR);
+
+  int convert(const UniqueAVFrame& avFrame, torch::Tensor& outputTensor);
+
+  const SwsConfig& getConfig() const {
+    return config_;
+  }
+
+ private:
+  SwsConfig config_;
+  int swsFlags_;
+  bool needsResize_;
+
+  // Color conversion context (input format -> RGB24 at original resolution).
+  UniqueSwsContext colorConversionSwsContext_;
+
+  // Resize context (RGB24 -> RGB24 at output resolution).
+  // May be null if no resize is needed.
+  UniqueSwsContext resizeSwsContext_;
+};
+
+} // namespace facebook::torchcodec
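
For reviewers, here is a minimal sketch of the caller-side caching contract that `SwScale.h` documents, mirroring the new code in `CpuDeviceInterface::convertVideoAVFrameToFrameOutput`. The `ExampleCaller` class and its `decodeInto` method are hypothetical illustrations only; `SwsConfig`, `SwScale`, and `FrameDims` are the types this diff introduces or touches, and `Frame.h` is assumed to provide `FrameDims` and the tensor helpers, as it does for `SwScale.cpp`.

```cpp
#include <memory>

#include "Frame.h" // assumed to define FrameDims, as used by SwScale.cpp
#include "SwScale.h"

namespace facebook::torchcodec {

// Hypothetical caller showing the caching contract: the caller owns the
// SwScale instance and rebuilds it only when the per-frame SwsConfig
// (dimensions, pixel format, colorspace) no longer matches.
class ExampleCaller {
 public:
  int decodeInto(
      const UniqueAVFrame& avFrame,
      torch::Tensor& outputTensor,
      const FrameDims& outputDims) {
    SwsConfig swsConfig(
        avFrame->width,
        avFrame->height,
        static_cast<AVPixelFormat>(avFrame->format),
        avFrame->colorspace,
        outputDims.width,
        outputDims.height);

    // Comparison via SwsConfig::operator!= is cheap; recreating SwScale
    // (and its one or two SwsContexts) is the expensive part we avoid.
    if (!swScale_ || swScale_->getConfig() != swsConfig) {
      swScale_ = std::make_unique<SwScale>(swsConfig, SWS_BILINEAR);
    }

    // One sws_scale pass when outputDims match the frame's resolution;
    // two passes (color conversion, then RGB24 resize) otherwise.
    return swScale_->convert(avFrame, outputTensor);
  }

 private:
  std::unique_ptr<SwScale> swScale_;
};

} // namespace facebook::torchcodec
```

This keeps the recreate-on-config-change policy in the caller, consistent with how `FilterGraph` is cached against `prevFiltersConfig_`, while `SwScale` itself stays immutable after construction.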