diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 186908f09c..c5c3b429ef 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -20,6 +20,7 @@ MLX was developed with contributions from the following individuals: - Paul Paczuski: Improved stability of BCE loss calculation - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops. - Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function. +- Siddhartha Tuladhar: Added `reflect` and `symmetric` padding modes. diff --git a/mlx/ops.cpp b/mlx/ops.cpp index fbe6799373..ee53f103fb 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1220,6 +1220,194 @@ array edge_pad( return padded; } +array reflect_pad( + const array& a, + const std::vector& axes, + const Shape& low_pad_size, + const Shape& high_pad_size, + const Shape& out_shape, + StreamOrDevice s /* = {}*/) { + array out = zeros(out_shape, a.dtype(), s); + auto stops = a.shape(); + for (int i = 0; i < stops.size(); i++) { + stops[i] += low_pad_size[i]; + } + array padded = slice_update(out, a, low_pad_size, stops, s); + + for (int axis = 0; axis < a.ndim(); axis++) { + if (low_pad_size[axis] > 0) { + int n = a.shape(axis); + Shape starts(a.ndim(), 0); + Shape stops = padded.shape(); + Shape strides(a.ndim(), 1); + strides[axis] = -1; + + starts[axis] = low_pad_size[axis] + 1; + stops[axis] = low_pad_size[axis] + n; + array forward = slice(padded, starts, stops, s); + + starts[axis] = low_pad_size[axis] + n - 2; + stops[axis] = low_pad_size[axis] - 1; + array backward = slice(padded, starts, stops, strides, s); + + // build bounce pattern cycle + array cycle = concatenate({forward, backward}, axis, s); + int cycle_len = cycle.shape(axis); + int reps_needed = (low_pad_size[axis] + cycle_len - 1) / cycle_len + 1; + std::vector reps(a.ndim(), 1); + reps[axis] = reps_needed; + array tiled = tile(cycle, reps, s); + + Shape slice_starts(a.ndim(), 0); + Shape slice_stops = tiled.shape(); + slice_stops[axis] = low_pad_size[axis]; + array padding = slice(tiled, slice_starts, slice_stops, s); + + // reverse padding for left side placement + Shape rev_strides(a.ndim(), 1); + rev_strides[axis] = -1; + Shape rev_starts(a.ndim(), 0); + Shape rev_stops = padding.shape(); + rev_starts[axis] = padding.shape(axis) - 1; + rev_stops[axis] = -padding.shape(axis) - 1; + padding = slice(padding, rev_starts, rev_stops, rev_strides, s); + + starts[axis] = 0; + stops[axis] = low_pad_size[axis]; + padded = slice_update(padded, padding, starts, stops, s); + } + + if (high_pad_size[axis] > 0) { + int n = a.shape(axis); + Shape starts(a.ndim(), 0); + Shape stops = padded.shape(); + Shape strides(a.ndim(), 1); + strides[axis] = -1; + int orig_end = low_pad_size[axis] + n; + + starts[axis] = orig_end - n + 1; + stops[axis] = orig_end; + array forward = slice(padded, starts, stops, s); + + starts[axis] = orig_end - 2; + stops[axis] = -(padded.shape(axis) + 1); + array backward = slice(padded, starts, stops, strides, s); + + // build bounce pattern cycle + array cycle = concatenate({backward, forward}, axis, s); + int cycle_len = cycle.shape(axis); + int reps_needed = (high_pad_size[axis] + cycle_len - 1) / cycle_len + 1; + std::vector reps(a.ndim(), 1); + reps[axis] = reps_needed; + array tiled = tile(cycle, reps, s); + + Shape slice_starts(a.ndim(), 0); + Shape slice_stops = tiled.shape(); + slice_stops[axis] = high_pad_size[axis]; + array padding = slice(tiled, slice_starts, slice_stops, s); + + starts[axis] = low_pad_size[axis] + n; + stops[axis] = padded.shape(axis); + padded = slice_update(padded, padding, starts, stops, s); + } + } + return padded; +} + +array symmetric_pad( + const array& a, + const std::vector& axes, + const Shape& low_pad_size, + const Shape& high_pad_size, + const Shape& out_shape, + StreamOrDevice s /* = {}*/) { + array out = zeros(out_shape, a.dtype(), s); + auto stops = a.shape(); + for (int i = 0; i < stops.size(); i++) { + stops[i] += low_pad_size[i]; + } + array padded = slice_update(out, a, low_pad_size, stops, s); + + for (int axis = 0; axis < a.ndim(); axis++) { + if (low_pad_size[axis] > 0) { + int n = a.shape(axis); + Shape starts(a.ndim(), 0); + Shape stops = padded.shape(); + Shape strides(a.ndim(), 1); + strides[axis] = -1; + + starts[axis] = low_pad_size[axis]; + stops[axis] = low_pad_size[axis] + n; + array forward = slice(padded, starts, stops, s); + + starts[axis] = low_pad_size[axis] + n - 1; + stops[axis] = (padded.shape(axis) + 1); + array backward = slice(padded, starts, stops, strides, s); + + // build bounce pattern cycle + array cycle = concatenate({forward, backward}, axis, s); + int cycle_len = cycle.shape(axis); + int reps_needed = (low_pad_size[axis] + cycle_len - 1) / cycle_len + 1; + std::vector reps(a.ndim(), 1); + reps[axis] = reps_needed; + array tiled = tile(cycle, reps, s); + + Shape slice_starts(a.ndim(), 0); + Shape slice_stops = tiled.shape(); + slice_stops[axis] = low_pad_size[axis]; + array padding = slice(tiled, slice_starts, slice_stops, s); + + // reverse padding for left side placement + Shape rev_strides(a.ndim(), 1); + rev_strides[axis] = -1; + Shape rev_starts(a.ndim(), 0); + Shape rev_stops = padding.shape(); + rev_starts[axis] = padding.shape(axis) - 1; + rev_stops[axis] = -padding.shape(axis) - 1; + padding = slice(padding, rev_starts, rev_stops, rev_strides, s); + + starts[axis] = 0; + stops[axis] = low_pad_size[axis]; + padded = slice_update(padded, padding, starts, stops, s); + } + + if (high_pad_size[axis] > 0) { + int n = a.shape(axis); + Shape starts(a.ndim(), 0); + Shape stops = padded.shape(); + Shape strides(a.ndim(), 1); + strides[axis] = -1; + int orig_end = low_pad_size[axis] + n; + + starts[axis] = orig_end - n; + stops[axis] = orig_end; + array forward = slice(padded, starts, stops, s); + + starts[axis] = orig_end - 1; + stops[axis] = -(padded.shape(axis) + 1); + array backward = slice(padded, starts, stops, strides, s); + + // build bounce pattern cycle + array cycle = concatenate({backward, forward}, axis, s); + int cycle_len = cycle.shape(axis); + int reps_needed = (high_pad_size[axis] + cycle_len - 1) / cycle_len + 1; + std::vector reps(a.ndim(), 1); + reps[axis] = reps_needed; + array tiled = tile(cycle, reps, s); + + Shape slice_starts(a.ndim(), 0); + Shape slice_stops = tiled.shape(); + slice_stops[axis] = high_pad_size[axis]; + array padding = slice(tiled, slice_starts, slice_stops, s); + + starts[axis] = low_pad_size[axis] + n; + stops[axis] = padded.shape(axis); + padded = slice_update(padded, padding, starts, stops, s); + } + } + return padded; +} + /** Pad an array with a constant value */ array pad( const array& a, @@ -1267,6 +1455,10 @@ array pad( {a, astype(pad_value, a.dtype(), s)}); } else if (mode == "edge") { return edge_pad(a, axes, low_pad_size, high_pad_size, out_shape, s); + } else if (mode == "reflect") { + return reflect_pad(a, axes, low_pad_size, high_pad_size, out_shape, s); + } else if (mode == "symmetric") { + return symmetric_pad(a, axes, low_pad_size, high_pad_size, out_shape, s); } else { std::ostringstream msg; msg << "Invalid padding mode (" << mode << ") passed to pad"; diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 7c00ad8a1d..f4810a4e97 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3142,7 +3142,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), + "def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge', 'reflect', 'symmetric'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Pad an array with a constant value @@ -3157,6 +3157,8 @@ void init_ops(nb::module_& m) { mode: Padding mode. One of the following strings: "constant" (default): Pads with a constant value. "edge": Pads with the edge values of array. + "reflect": Pads with the reflection of the array mirrored along the edge, excluding the edge value. + "symmetric": Pads with the reflection of the array mirrored along the edge, including the edge value. constant_value (array or scalar, optional): Optional constant value to pad the edges of the array with. diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index f00b8047ea..a128161b36 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1885,6 +1885,18 @@ def test_pad(self): self.assertEqual(list(b_npy.shape), list(b_mlx.shape)) self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6)) + b_npy = np.pad(a_npy, pw, mode="reflect") + b_mlx = mx.pad(a_mlx, pw, mode="reflect") + + self.assertEqual(list(b_npy.shape), list(b_mlx.shape)) + self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6)) + + b_npy = np.pad(a_npy, pw, mode="symmetric") + b_mlx = mx.pad(a_mlx, pw, mode="symmetric") + + self.assertEqual(list(b_npy.shape), list(b_mlx.shape)) + self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6)) + a = mx.zeros((1, 1, 1)) self.assertEqual(mx.pad(a, 1).shape, (3, 3, 3)) self.assertEqual(mx.pad(a, (1,)).shape, (3, 3, 3)) diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 62fd8c5923..26c48cc89d 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2743,6 +2743,85 @@ TEST_CASE("test pad") { CHECK(array_equal(padded_x, expected).item()); } +TEST_CASE("test pad reflect") { + auto x1d = array({1.0f, 2.0f, 3.0f}); + auto padded_1d = pad(x1d, {{2, 2}}, array(0), "reflect"); + auto expected_1d = array({3.0f, 2.0f, 1.0f, 2.0f, 3.0f, 2.0f, 1.0f}); + CHECK(array_equal(padded_1d, expected_1d).item()); + + auto x2d = + array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f}, {3, 3}); + auto padded_2d = pad(x2d, {{1, 1}, {2, 2}}, array(0), "reflect"); + auto expected_2d = array( + {6.0f, 5.0f, 4.0f, 5.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 2.0f, 3.0f, + 2.0f, 1.0f, 6.0f, 5.0f, 4.0f, 5.0f, 6.0f, 5.0f, 4.0f, 9.0f, 8.0f, 7.0f, + 8.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 5.0f, 6.0f, 5.0f, 4.0f}, + {5, 7}); + CHECK(array_equal(padded_2d, expected_2d).item()); + + auto x_small = array({1.0f, 2.0f, 3.0f}); + auto padded_large = pad(x_small, {{5, 5}}, array(0), "reflect"); + auto expected_large = array( + {2.0f, + 1.0f, + 2.0f, + 3.0f, + 2.0f, + 1.0f, + 2.0f, + 3.0f, + 2.0f, + 1.0f, + 2.0f, + 3.0f, + 2.0f}); + CHECK(array_equal(padded_large, expected_large).item()); + + auto x_min = array({1.0f, 2.0f}); + auto padded_min = pad(x_min, {{1, 1}}, array(0), "reflect"); + auto expected_min = array({2.0f, 1.0f, 2.0f, 1.0f}); + CHECK(array_equal(padded_min, expected_min).item()); + + auto x_asym = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}); + auto padded_asym = pad(x_asym, {{0, 2}, {1, 0}}, array(0), "reflect"); + CHECK_EQ(padded_asym.shape(), Shape{4, 3}); +} + +TEST_CASE("test pad symmetric") { + auto x1d = array({1.0f, 2.0f, 3.0f}); + auto padded_1d = pad(x1d, {{2, 2}}, array(0), "symmetric"); + auto expected_1d = array({2.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 2.0f}); + CHECK(array_equal(padded_1d, expected_1d).item()); + + auto x2d = + array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f}, {3, 3}); + auto padded_2d = pad(x2d, {{1, 1}, {2, 2}}, array(0), "symmetric"); + auto expected_2d = array( + {5.0f, 4.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f, 2.0f, 1.0f, 1.0f, 2.0f, 3.0f, + 3.0f, 2.0f, 2.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 2.0f, 5.0f, 4.0f, 4.0f, + 5.0f, 6.0f, 6.0f, 5.0f, 8.0f, 7.0f, 7.0f, 8.0f, 9.0f, 9.0f, 8.0f}, + {5, 7}); + CHECK(array_equal(padded_2d, expected_2d).item()); + + auto x_small = array({1.0f, 2.0f, 3.0f}); + auto padded_large = pad(x_small, {{5, 5}}, array(0), "symmetric"); + auto expected_large = array( + {3.0f, + 2.0f, + 1.0f, + 1.0f, + 2.0f, + 3.0f, + 3.0f, + 2.0f, + 1.0f, + 1.0f, + 2.0f, + 3.0f, + 3.0f}); + CHECK(array_equal(padded_large, expected_large).item()); +} + TEST_CASE("test power") { CHECK_EQ(power(array(1), array(2)).item(), 1); CHECK_EQ((power(array(-1), array(2))).item(), 1);