Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/torchcodec/_core/Transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ std::string toFilterGraphInterpolation(
switch (mode) {
case ResizeTransform::InterpolationMode::BILINEAR:
return "bilinear";
case ResizeTransform::InterpolationMode::BICUBIC:
return "bicubic";
default:
TORCH_CHECK(
false,
Expand All @@ -29,6 +31,8 @@ int toSwsInterpolation(ResizeTransform::InterpolationMode mode) {
switch (mode) {
case ResizeTransform::InterpolationMode::BILINEAR:
return SWS_BILINEAR;
case ResizeTransform::InterpolationMode::BICUBIC:
return SWS_BICUBIC;
default:
TORCH_CHECK(
false,
Expand Down
2 changes: 1 addition & 1 deletion src/torchcodec/_core/Transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class Transform {

class ResizeTransform : public Transform {
public:
enum class InterpolationMode { BILINEAR };
enum class InterpolationMode { BILINEAR, BICUBIC };

explicit ResizeTransform(const FrameDims& dims)
: outputDims_(dims), interpolationMode_(InterpolationMode::BILINEAR) {}
Expand Down
33 changes: 27 additions & 6 deletions src/torchcodec/_core/custom_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,18 +257,39 @@ int checkedToNonNegativeInt(const std::string& str) {

// Resize transform specs take the form:
//
// "resize, <height>, <width>"
// "resize, <height>, <width>" or "resize, <height>, <width>, <interpolation>"
//
// Where "resize" is the string literal and <height> and <width> are positive
// integers.
// Where "resize" is the string literal, <height> and <width> are positive
// integers, and <interpolation> is an optional string (bilinear, bicubic, or
// lanczos). If <interpolation> is not specified, it defaults to bilinear.
Transform* makeResizeTransform(
const std::vector<std::string>& resizeTransformSpec) {
TORCH_CHECK(
resizeTransformSpec.size() == 3,
"resizeTransformSpec must have 3 elements including its name");
resizeTransformSpec.size() >= 3 && resizeTransformSpec.size() <= 4,
"resizeTransformSpec must have 3 or 4 elements including its name");
int height = checkedToPositiveInt(resizeTransformSpec[1]);
int width = checkedToPositiveInt(resizeTransformSpec[2]);
return new ResizeTransform(FrameDims(height, width));

auto interpolationMode = ResizeTransform::InterpolationMode::BILINEAR;
if (resizeTransformSpec.size() == 4) {
const std::string& modeStr = resizeTransformSpec[3];
// Trim leading/trailing whitespace
auto trimmed = modeStr;
trimmed.erase(0, trimmed.find_first_not_of(" \t"));
trimmed.erase(trimmed.find_last_not_of(" \t") + 1);

if (trimmed == "bilinear") {
interpolationMode = ResizeTransform::InterpolationMode::BILINEAR;
} else if (trimmed == "bicubic") {
interpolationMode = ResizeTransform::InterpolationMode::BICUBIC;
} else {
TORCH_CHECK(
false,
"Unknown interpolation mode: '" + trimmed +
"'. Supported modes: bilinear, bicubic");
}
}
return new ResizeTransform(FrameDims(height, width), interpolationMode);
}

// Crop transform specs take the form:
Expand Down
31 changes: 25 additions & 6 deletions src/torchcodec/transforms/_decoder_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,23 +92,33 @@ class Resize(DecoderTransform):
"""Resize the decoded frame to a given size.

Complementary TorchVision transform: :class:`~torchvision.transforms.v2.Resize`.
Interpolation is always bilinear. Anti-aliasing is always on.
Anti-aliasing is always on.

Args:
size (Sequence[int]): Desired output size. Must be a sequence of
the form (height, width).
interpolation (str): Interpolation mode. Must be one of "bilinear"
or "bicubic". Default is "bilinear".
"""

def __init__(self, size: Sequence[int]):
_VALID_INTERPOLATIONS = {"bilinear", "bicubic"}

def __init__(self, size: Sequence[int], interpolation: str = "bilinear"):
if len(size) != 2:
raise ValueError(
"Resize transform must have a (height, width) "
f"pair for the size, got {size}."
)
if interpolation not in self._VALID_INTERPOLATIONS:
raise ValueError(
f"Invalid interpolation mode: '{interpolation}'. "
f"Must be one of {self._VALID_INTERPOLATIONS}."
)
self.size = size
self.interpolation = interpolation

def _make_transform_spec(self, input_dims: tuple[int | None, int | None]) -> str:
return f"resize, {self.size[0]}, {self.size[1]}"
return f"resize, {self.size[0]}, {self.size[1]}, {self.interpolation}"

def _get_output_dims(self) -> tuple[int | None, int | None] | None:
return (self.size[0], self.size[1])
Expand All @@ -119,9 +129,16 @@ def _from_torchvision(cls, tv_resize: nn.Module):

assert isinstance(tv_resize, v2.Resize)

if tv_resize.interpolation is not v2.InterpolationMode.BILINEAR:
# Map TorchVision interpolation modes to TorchCodec string modes
interpolation_map = {
v2.InterpolationMode.BILINEAR: "bilinear",
v2.InterpolationMode.BICUBIC: "bicubic",
}

if tv_resize.interpolation not in interpolation_map:
raise ValueError(
"TorchVision Resize transform must use bilinear interpolation."
f"TorchVision Resize transform must use bilinear or bicubic "
f"interpolation. Got: {tv_resize.interpolation}"
)
if tv_resize.antialias is False:
raise ValueError(
Expand All @@ -134,7 +151,9 @@ def _from_torchvision(cls, tv_resize: nn.Module):
"TorchVision Resize transform must have a (height, width) "
f"pair for the size, got {tv_resize.size}."
)
return cls(size=tv_resize.size)

interpolation = interpolation_map[tv_resize.interpolation]
return cls(size=tv_resize.size, interpolation=interpolation)


class CenterCrop(DecoderTransform):
Expand Down
Loading