diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 531da6a2a..9a8c74f8a 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -635,13 +635,17 @@ void tryToValidateCodecOption( } void sortCodecOptions( + const AVFormatContext* avFormatContext, const std::map<std::string, std::string>& extraOptions, UniqueAVDictionary& codecDict, UniqueAVDictionary& formatDict) { // Accepts a map of options as input, then sorts them into codec options and // format options. The sorted options are returned into two separate dicts. const AVClass* formatClass = avformat_get_class(); + const AVClass* muxerClass = + avFormatContext->oformat ? avFormatContext->oformat->priv_class : nullptr; for (const auto& [key, value] : extraOptions) { + // Check if the option is a generic format option const AVOption* fmtOpt = av_opt_find2( &formatClass, key.c_str(), @@ -649,10 +653,24 @@ void sortCodecOptions( 0, AV_OPT_SEARCH_CHILDREN | AV_OPT_SEARCH_FAKE_OBJ, nullptr); - if (fmtOpt) { + // Check if the option is a muxer-specific option + // (e.g. one of those listed by `ffmpeg -h muxer=mp4`) + const AVOption* muxerOpt = nullptr; + if (muxerClass) { + muxerOpt = av_opt_find2( + &muxerClass, + key.c_str(), + nullptr, + 0, + AV_OPT_SEARCH_FAKE_OBJ, + nullptr); + } + if (fmtOpt || muxerOpt) { + // Pass container-format options to formatDict to be used in + // avformat_write_header av_dict_set(formatDict.getAddress(), key.c_str(), value.c_str(), 0); } else { - // Default to codec option (includes AVCodecContext + encoder-private) + // By default, pass as a codec option to be used in avcodec_open2 av_dict_set(codecDict.getAddress(), key.c_str(), value.c_str(), 0); } } @@ -834,6 +852,7 @@ void VideoEncoder::initializeEncoder( tryToValidateCodecOption(*avCodec, key.c_str(), value); } sortCodecOptions( + avFormatContext_.get(), videoStreamOptions.extraOptions.value(), avCodecOptions, avFormatOptions_); @@ -913,6 +932,15 @@ void VideoEncoder::encode() { flushBuffers(); status = 
av_write_trailer(avFormatContext_.get()); + // av_write_trailer returns the mfra atom size (positive) for fragmented + // containers, which we'd misinterpret as an error, since all FFmpeg errors + // are negative (see AVERROR definition: + // http://ffmpeg.org/doxygen/8.0/error_8h_source.html). So we replace positive + // values with AVSUCCESS. See: + // https://github.com/FFmpeg/FFmpeg/blob/n8.0/libavformat/movenc.c#L8666 + if (status > 0) { + status = AVSUCCESS; + } STD_TORCH_CHECK( status == AVSUCCESS, "Error in av_write_trailer: ", diff --git a/test/test_encoders.py b/test/test_encoders.py index 3a8c488e1..fbb080998 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -1509,3 +1509,51 @@ def test_nvenc_against_ffmpeg_cli( assert color_range == encoder_metadata["color_range"] if color_space is not None: assert color_space == encoder_metadata["color_space"] + + @pytest.mark.skipif( + ffmpeg_major_version == 4, + reason="On FFmpeg 4 hitting a truncated packet results in AVERROR_INVALIDDATA, which torchcodec does not handle.", + ) + @pytest.mark.parametrize("format", ["mp4", "mov"]) + @pytest.mark.parametrize( + "extra_options", + [ + # frag_keyframe with empty_moov (new fragment every keyframe) + {"movflags": "+frag_keyframe+empty_moov"}, + # frag_duration creates fragments based on duration (in microseconds) + {"movflags": "+empty_moov", "frag_duration": "1000000"}, + ], + ) + def test_fragmented_mp4( + self, + tmp_path, + extra_options, + format, + ): + # Test that VideoEncoder can write fragmented files using movflags. + # Fragmented files store metadata interleaved with data rather than + # all at the end, making them decodable even if writing is interrupted. 
+ source_frames, frame_rate = self.decode_and_get_frame_rate(TEST_SRC_2_720P.path) + encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate) + encoded_path = str(tmp_path / f"fragmented_output.{format}") + encoder.to_file(dest=encoded_path, extra_options=extra_options) + + # Decode the file to get reference frames + reference_decoder = VideoDecoder(encoded_path) + reference_frames = [reference_decoder.get_frame_at(i) for i in range(10)] + + # Truncate the file to simulate interrupted write + with open(encoded_path, "rb") as f: + full_content = f.read() + truncated_size = int(len(full_content) * 0.5) + with open(encoded_path, "wb") as f: + f.write(full_content[:truncated_size]) + + # Decode the truncated file and verify first 10 frames match reference + truncated_decoder = VideoDecoder(encoded_path) + assert len(truncated_decoder) >= 10 + for i in range(10): + truncated_frame = truncated_decoder.get_frame_at(i) + torch.testing.assert_close( + truncated_frame.data, reference_frames[i].data, atol=0, rtol=0 + )