From bd5417f587c312491e0bb42d48a8ce7768d067d4 Mon Sep 17 00:00:00 2001 From: Sampurna Tuladhar Date: Wed, 3 Dec 2025 00:12:32 -0800 Subject: [PATCH 1/4] feat: working reflect pad --- mlx/ops.cpp | 143 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index fbe6799373..319155233b 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1220,6 +1220,147 @@ array edge_pad( return padded; } +array reflect_pad( + const array& a, + const std::vector& axes, + const Shape& low_pad_size, + const Shape& high_pad_size, + const Shape& out_shape, + StreamOrDevice s /* = {}*/) { + array out = zeros(out_shape, a.dtype(), s); + auto stops = a.shape(); + for (int i = 0; i < stops.size(); i++) { + stops[i] += low_pad_size[i]; + } + // Copy over values from the unpadded array + array padded = slice_update(out, a, low_pad_size, stops, s); + for (int axis = 0; axis < a.ndim(); axis++) { + std::cout << "Processing axis=" << axis << " low_pad=" << low_pad_size[axis] + << " high_pad=" << high_pad_size[axis] << std::endl; + if (low_pad_size[axis] > 0) { + int n = a.shape(axis); + Shape starts(a.ndim(), 0); + Shape stops = padded.shape(); + Shape strides(a.ndim(), 1); + strides[axis] = -1; + + starts[axis] = low_pad_size[axis] + 1; + stops[axis] = low_pad_size[axis] + n; + array forward = slice(padded, starts, stops, s); + + starts[axis] = low_pad_size[axis] + n - 2; + stops[axis] = low_pad_size[axis] - 1; + array backward = slice(padded, starts, stops, strides, s); + + array cycle = concatenate({forward, backward}, axis, s); + int cycle_len = cycle.shape(axis); // how many rows in cycle (4) + int reps_needed = (low_pad_size[axis] + cycle_len - 1) / cycle_len + 1; + std::vector reps(a.ndim(), 1); + reps[axis] = reps_needed; + array tiled = tile(cycle, reps, s); + + Shape slice_starts(a.ndim(), 0); + Shape slice_stops = tiled.shape(); + slice_stops[axis] = low_pad_size[axis]; + + array padding = slice(tiled, slice_starts, 
slice_stops, s); + + starts[axis] = 0; + stops[axis] = low_pad_size[axis]; + + Shape rev_strides(a.ndim(), 1); + rev_strides[axis] = -1; + Shape rev_starts(a.ndim(), 0); + Shape rev_stops = padding.shape(); + rev_starts[axis] = padding.shape(axis) - 1; // start from last + rev_stops[axis] = -padding.shape(axis) - 1; // go before first + padding = slice(padding, rev_starts, rev_stops, rev_strides, s); + // Update values in the padded array + padded = slice_update(padded, padding, starts, stops, s); + } + + if (high_pad_size[axis] > 0) { + int n = a.shape(axis); + Shape starts(a.ndim(), 0); + Shape stops = padded.shape(); + Shape strides(a.ndim(), 1); + strides[axis] = -1; + starts[axis] = padded.shape(axis) - high_pad_size[axis] - n + 1; + stops[axis] = padded.shape(axis) - high_pad_size[axis]; + // Edge + 1 values + array forward = slice(padded, starts, stops, s); + starts[axis] = padded.shape(axis) - high_pad_size[axis] - 2; + stops[axis] = padded.shape(axis) - high_pad_size[axis] - n - 1; + + array backward = slice(padded, starts, stops, strides, s); + array cycle = concatenate({backward, forward}, axis, s); + int cycle_len = cycle.shape(axis); + int reps_needed = (high_pad_size[axis] + cycle_len - 1) / cycle_len + 1; + std::vector reps(a.ndim(), 1); + reps[axis] = reps_needed; + array tiled = tile(cycle, reps, s); + + Shape slice_starts(a.ndim(), 0); + Shape slice_stops = tiled.shape(); + slice_stops[axis] = high_pad_size[axis]; + + array padding = slice(tiled, slice_starts, slice_stops, s); + + starts[axis] = padded.shape(axis) - high_pad_size[axis]; + stops[axis] = padded.shape(axis); + + // Update values in the padded array + padded = slice_update(padded, padding, starts, stops, s); + } + } + return padded; +} + +// array symmetric_pad( +// const array& a, +// const std::vector& axes, +// const Shape& low_pad_size, +// const Shape& high_pad_size, +// const Shape& out_shape, +// StreamOrDevice s /* = {}*/) { +// array out = zeros(out_shape, a.dtype(), s); +// 
auto stops = a.shape(); +// for (int i = 0; i < stops.size(); i++) { +// stops[i] += low_pad_size[i]; +// } +// // Copy over values from the unpadded array +// array padded = slice_update(out, a, low_pad_size, stops, s); +// +// for (int axis = 0; axis < a.ndim(); axis++) { +// if (low_pad_size[axis] > 0) { +// Shape starts(a.ndim(), 0); +// starts[axis] = low_pad_size[axis]; +// auto stops = out.shape(); +// stops[axis] = low_pad_size[axis] + 1; +// // Fetch edge values +// array edge_value = slice(padded, starts, stops, s); +// +// starts[axis] = 0; +// stops[axis] = low_pad_size[axis]; +// // Update edge values in the padded array +// padded = slice_update(padded, edge_value, starts, stops, s); +// } +// +// if (high_pad_size[axis] > 0) { +// Shape starts(a.ndim(), 0); +// starts[axis] = -high_pad_size[axis] - 1; +// auto stops = out.shape(); +// stops[axis] = -high_pad_size[axis]; +// array edge_value = slice(padded, starts, stops, s); +// +// starts[axis] = -high_pad_size[axis]; +// stops[axis] = out.shape(axis); +// padded = slice_update(padded, edge_value, starts, stops, s); +// } +// } +// return padded; +// } + /** Pad an array with a constant value */ array pad( const array& a, @@ -1267,6 +1408,8 @@ array pad( {a, astype(pad_value, a.dtype(), s)}); } else if (mode == "edge") { return edge_pad(a, axes, low_pad_size, high_pad_size, out_shape, s); + } else if (mode == "reflect") { + return reflect_pad(a, axes, low_pad_size, high_pad_size, out_shape, s); } else { std::ostringstream msg; msg << "Invalid padding mode (" << mode << ") passed to pad"; From 0c847bccffb273e202a329a56fe1ec811cb4bb97 Mon Sep 17 00:00:00 2001 From: Sampurna Tuladhar Date: Wed, 3 Dec 2025 01:22:31 -0800 Subject: [PATCH 2/4] feat: add symmetric pad --- mlx/ops.cpp | 173 +++++++++++++++++++++++++-------------- python/src/ops.cpp | 4 +- python/tests/test_ops.py | 20 ++++- tests/ops_tests.cpp | 79 ++++++++++++++++++ 4 files changed, 212 insertions(+), 64 deletions(-) diff --git 
a/mlx/ops.cpp b/mlx/ops.cpp index 319155233b..ee53f103fb 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1232,11 +1232,9 @@ array reflect_pad( for (int i = 0; i < stops.size(); i++) { stops[i] += low_pad_size[i]; } - // Copy over values from the unpadded array array padded = slice_update(out, a, low_pad_size, stops, s); + for (int axis = 0; axis < a.ndim(); axis++) { - std::cout << "Processing axis=" << axis << " low_pad=" << low_pad_size[axis] - << " high_pad=" << high_pad_size[axis] << std::endl; if (low_pad_size[axis] > 0) { int n = a.shape(axis); Shape starts(a.ndim(), 0); @@ -1252,8 +1250,9 @@ array reflect_pad( stops[axis] = low_pad_size[axis] - 1; array backward = slice(padded, starts, stops, strides, s); + // build bounce pattern cycle array cycle = concatenate({forward, backward}, axis, s); - int cycle_len = cycle.shape(axis); // how many rows in cycle (4) + int cycle_len = cycle.shape(axis); int reps_needed = (low_pad_size[axis] + cycle_len - 1) / cycle_len + 1; std::vector reps(a.ndim(), 1); reps[axis] = reps_needed; @@ -1262,20 +1261,113 @@ array reflect_pad( Shape slice_starts(a.ndim(), 0); Shape slice_stops = tiled.shape(); slice_stops[axis] = low_pad_size[axis]; - array padding = slice(tiled, slice_starts, slice_stops, s); + // reverse padding for left side placement + Shape rev_strides(a.ndim(), 1); + rev_strides[axis] = -1; + Shape rev_starts(a.ndim(), 0); + Shape rev_stops = padding.shape(); + rev_starts[axis] = padding.shape(axis) - 1; + rev_stops[axis] = -padding.shape(axis) - 1; + padding = slice(padding, rev_starts, rev_stops, rev_strides, s); + starts[axis] = 0; stops[axis] = low_pad_size[axis]; + padded = slice_update(padded, padding, starts, stops, s); + } + + if (high_pad_size[axis] > 0) { + int n = a.shape(axis); + Shape starts(a.ndim(), 0); + Shape stops = padded.shape(); + Shape strides(a.ndim(), 1); + strides[axis] = -1; + int orig_end = low_pad_size[axis] + n; + + starts[axis] = orig_end - n + 1; + stops[axis] = orig_end; + array 
forward = slice(padded, starts, stops, s); + + starts[axis] = orig_end - 2; + stops[axis] = -(padded.shape(axis) + 1); + array backward = slice(padded, starts, stops, strides, s); + // build bounce pattern cycle + array cycle = concatenate({backward, forward}, axis, s); + int cycle_len = cycle.shape(axis); + int reps_needed = (high_pad_size[axis] + cycle_len - 1) / cycle_len + 1; + std::vector reps(a.ndim(), 1); + reps[axis] = reps_needed; + array tiled = tile(cycle, reps, s); + + Shape slice_starts(a.ndim(), 0); + Shape slice_stops = tiled.shape(); + slice_stops[axis] = high_pad_size[axis]; + array padding = slice(tiled, slice_starts, slice_stops, s); + + starts[axis] = low_pad_size[axis] + n; + stops[axis] = padded.shape(axis); + padded = slice_update(padded, padding, starts, stops, s); + } + } + return padded; +} + +array symmetric_pad( + const array& a, + const std::vector& axes, + const Shape& low_pad_size, + const Shape& high_pad_size, + const Shape& out_shape, + StreamOrDevice s /* = {}*/) { + array out = zeros(out_shape, a.dtype(), s); + auto stops = a.shape(); + for (int i = 0; i < stops.size(); i++) { + stops[i] += low_pad_size[i]; + } + array padded = slice_update(out, a, low_pad_size, stops, s); + + for (int axis = 0; axis < a.ndim(); axis++) { + if (low_pad_size[axis] > 0) { + int n = a.shape(axis); + Shape starts(a.ndim(), 0); + Shape stops = padded.shape(); + Shape strides(a.ndim(), 1); + strides[axis] = -1; + + starts[axis] = low_pad_size[axis]; + stops[axis] = low_pad_size[axis] + n; + array forward = slice(padded, starts, stops, s); + + starts[axis] = low_pad_size[axis] + n - 1; + stops[axis] = (padded.shape(axis) + 1); + array backward = slice(padded, starts, stops, strides, s); + + // build bounce pattern cycle + array cycle = concatenate({forward, backward}, axis, s); + int cycle_len = cycle.shape(axis); + int reps_needed = (low_pad_size[axis] + cycle_len - 1) / cycle_len + 1; + std::vector reps(a.ndim(), 1); + reps[axis] = reps_needed; + array 
tiled = tile(cycle, reps, s); + + Shape slice_starts(a.ndim(), 0); + Shape slice_stops = tiled.shape(); + slice_stops[axis] = low_pad_size[axis]; + array padding = slice(tiled, slice_starts, slice_stops, s); + + // reverse padding for left side placement Shape rev_strides(a.ndim(), 1); rev_strides[axis] = -1; Shape rev_starts(a.ndim(), 0); Shape rev_stops = padding.shape(); - rev_starts[axis] = padding.shape(axis) - 1; // start from last - rev_stops[axis] = -padding.shape(axis) - 1; // go before first + rev_starts[axis] = padding.shape(axis) - 1; + rev_stops[axis] = -padding.shape(axis) - 1; padding = slice(padding, rev_starts, rev_stops, rev_strides, s); - // Update values in the padded array + + starts[axis] = 0; + stops[axis] = low_pad_size[axis]; padded = slice_update(padded, padding, starts, stops, s); } @@ -1285,14 +1377,17 @@ array reflect_pad( Shape stops = padded.shape(); Shape strides(a.ndim(), 1); strides[axis] = -1; - starts[axis] = padded.shape(axis) - high_pad_size[axis] - n + 1; - stops[axis] = padded.shape(axis) - high_pad_size[axis]; - // Edge + 1 values + int orig_end = low_pad_size[axis] + n; + + starts[axis] = orig_end - n; + stops[axis] = orig_end; array forward = slice(padded, starts, stops, s); - starts[axis] = padded.shape(axis) - high_pad_size[axis] - 2; - stops[axis] = padded.shape(axis) - high_pad_size[axis] - n - 1; + starts[axis] = orig_end - 1; + stops[axis] = -(padded.shape(axis) + 1); array backward = slice(padded, starts, stops, strides, s); + + // build bounce pattern cycle array cycle = concatenate({backward, forward}, axis, s); int cycle_len = cycle.shape(axis); int reps_needed = (high_pad_size[axis] + cycle_len - 1) / cycle_len + 1; @@ -1303,64 +1398,16 @@ array reflect_pad( Shape slice_starts(a.ndim(), 0); Shape slice_stops = tiled.shape(); slice_stops[axis] = high_pad_size[axis]; - array padding = slice(tiled, slice_starts, slice_stops, s); - starts[axis] = padded.shape(axis) - high_pad_size[axis]; + starts[axis] = 
low_pad_size[axis] + n; stops[axis] = padded.shape(axis); - - // Update values in the padded array padded = slice_update(padded, padding, starts, stops, s); } } return padded; } -// array symmetric_pad( -// const array& a, -// const std::vector& axes, -// const Shape& low_pad_size, -// const Shape& high_pad_size, -// const Shape& out_shape, -// StreamOrDevice s /* = {}*/) { -// array out = zeros(out_shape, a.dtype(), s); -// auto stops = a.shape(); -// for (int i = 0; i < stops.size(); i++) { -// stops[i] += low_pad_size[i]; -// } -// // Copy over values from the unpadded array -// array padded = slice_update(out, a, low_pad_size, stops, s); -// -// for (int axis = 0; axis < a.ndim(); axis++) { -// if (low_pad_size[axis] > 0) { -// Shape starts(a.ndim(), 0); -// starts[axis] = low_pad_size[axis]; -// auto stops = out.shape(); -// stops[axis] = low_pad_size[axis] + 1; -// // Fetch edge values -// array edge_value = slice(padded, starts, stops, s); -// -// starts[axis] = 0; -// stops[axis] = low_pad_size[axis]; -// // Update edge values in the padded array -// padded = slice_update(padded, edge_value, starts, stops, s); -// } -// -// if (high_pad_size[axis] > 0) { -// Shape starts(a.ndim(), 0); -// starts[axis] = -high_pad_size[axis] - 1; -// auto stops = out.shape(); -// stops[axis] = -high_pad_size[axis]; -// array edge_value = slice(padded, starts, stops, s); -// -// starts[axis] = -high_pad_size[axis]; -// stops[axis] = out.shape(axis); -// padded = slice_update(padded, edge_value, starts, stops, s); -// } -// } -// return padded; -// } - /** Pad an array with a constant value */ array pad( const array& a, @@ -1410,6 +1457,8 @@ array pad( return edge_pad(a, axes, low_pad_size, high_pad_size, out_shape, s); } else if (mode == "reflect") { return reflect_pad(a, axes, low_pad_size, high_pad_size, out_shape, s); + } else if (mode == "symmetric") { + return symmetric_pad(a, axes, low_pad_size, high_pad_size, out_shape, s); } else { std::ostringstream msg; msg << 
"Invalid padding mode (" << mode << ") passed to pad"; diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 7c00ad8a1d..f4810a4e97 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3142,7 +3142,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), + "def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge', 'reflect', 'symmetric'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Pad an array with a constant value @@ -3157,6 +3157,8 @@ void init_ops(nb::module_& m) { mode: Padding mode. One of the following strings: "constant" (default): Pads with a constant value. "edge": Pads with the edge values of array. + "reflect": Pads with the reflection of the array mirrored along the edge, excluding the edge value. + "symmetric": Pads with the reflection of the array mirrored along the edge, including the edge value. constant_value (array or scalar, optional): Optional constant value to pad the edges of the array with. 
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index f00b8047ea..9f5efff4fc 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -5,10 +5,11 @@ import unittest from itertools import permutations, product -import mlx.core as mx import mlx_tests import numpy as np +import mlx.core as mx + def np_wrap_between(x, a): """Wraps `x` between `[-a, a]`.""" @@ -1885,6 +1886,23 @@ def test_pad(self): self.assertEqual(list(b_npy.shape), list(b_mlx.shape)) self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6)) + b_npy = np.pad(a_npy, pw, mode="reflect") + b_mlx = mx.pad(a_mlx, pw, mode="reflect") + + self.assertEqual(list(b_npy.shape), list(b_mlx.shape)) + self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6)) + + b_npy = np.pad(a_npy, pw, mode="symmetric") + b_mlx = mx.pad(a_mlx, pw, mode="symmetric") + + if not np.allclose(b_npy, b_mlx, atol=1e-6): + print(f"\nSymmetric test failed for pad_width: {pw}") + print(f"NumPy result (first 20):", b_npy.flat[:20]) + print(f"MLX result (first 20):", np.array(b_mlx).flat[:20]) + + self.assertEqual(list(b_npy.shape), list(b_mlx.shape)) + self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6)) + a = mx.zeros((1, 1, 1)) self.assertEqual(mx.pad(a, 1).shape, (3, 3, 3)) self.assertEqual(mx.pad(a, (1,)).shape, (3, 3, 3)) diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 62fd8c5923..26c48cc89d 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2743,6 +2743,85 @@ TEST_CASE("test pad") { CHECK(array_equal(padded_x, expected).item()); } +TEST_CASE("test pad reflect") { + auto x1d = array({1.0f, 2.0f, 3.0f}); + auto padded_1d = pad(x1d, {{2, 2}}, array(0), "reflect"); + auto expected_1d = array({3.0f, 2.0f, 1.0f, 2.0f, 3.0f, 2.0f, 1.0f}); + CHECK(array_equal(padded_1d, expected_1d).item()); + + auto x2d = + array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f}, {3, 3}); + auto padded_2d = pad(x2d, {{1, 1}, {2, 2}}, array(0), "reflect"); + auto expected_2d = array( + {6.0f, 5.0f, 4.0f, 
5.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 2.0f, 3.0f, + 2.0f, 1.0f, 6.0f, 5.0f, 4.0f, 5.0f, 6.0f, 5.0f, 4.0f, 9.0f, 8.0f, 7.0f, + 8.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 5.0f, 6.0f, 5.0f, 4.0f}, + {5, 7}); + CHECK(array_equal(padded_2d, expected_2d).item()); + + auto x_small = array({1.0f, 2.0f, 3.0f}); + auto padded_large = pad(x_small, {{5, 5}}, array(0), "reflect"); + auto expected_large = array( + {2.0f, + 1.0f, + 2.0f, + 3.0f, + 2.0f, + 1.0f, + 2.0f, + 3.0f, + 2.0f, + 1.0f, + 2.0f, + 3.0f, + 2.0f}); + CHECK(array_equal(padded_large, expected_large).item()); + + auto x_min = array({1.0f, 2.0f}); + auto padded_min = pad(x_min, {{1, 1}}, array(0), "reflect"); + auto expected_min = array({2.0f, 1.0f, 2.0f, 1.0f}); + CHECK(array_equal(padded_min, expected_min).item()); + + auto x_asym = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}); + auto padded_asym = pad(x_asym, {{0, 2}, {1, 0}}, array(0), "reflect"); + CHECK_EQ(padded_asym.shape(), Shape{4, 3}); +} + +TEST_CASE("test pad symmetric") { + auto x1d = array({1.0f, 2.0f, 3.0f}); + auto padded_1d = pad(x1d, {{2, 2}}, array(0), "symmetric"); + auto expected_1d = array({2.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 2.0f}); + CHECK(array_equal(padded_1d, expected_1d).item()); + + auto x2d = + array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f}, {3, 3}); + auto padded_2d = pad(x2d, {{1, 1}, {2, 2}}, array(0), "symmetric"); + auto expected_2d = array( + {5.0f, 4.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f, 2.0f, 1.0f, 1.0f, 2.0f, 3.0f, + 3.0f, 2.0f, 2.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 2.0f, 5.0f, 4.0f, 4.0f, + 5.0f, 6.0f, 6.0f, 5.0f, 8.0f, 7.0f, 7.0f, 8.0f, 9.0f, 9.0f, 8.0f}, + {5, 7}); + CHECK(array_equal(padded_2d, expected_2d).item()); + + auto x_small = array({1.0f, 2.0f, 3.0f}); + auto padded_large = pad(x_small, {{5, 5}}, array(0), "symmetric"); + auto expected_large = array( + {3.0f, + 2.0f, + 1.0f, + 1.0f, + 2.0f, + 3.0f, + 3.0f, + 2.0f, + 1.0f, + 1.0f, + 2.0f, + 3.0f, + 3.0f}); + CHECK(array_equal(padded_large, 
expected_large).item()); +} + TEST_CASE("test power") { CHECK_EQ(power(array(1), array(2)).item(), 1); CHECK_EQ((power(array(-1), array(2))).item(), 1); From 45c37f584825ac44279c9a3eb5ee5a099d3a60a1 Mon Sep 17 00:00:00 2001 From: Sampurna Tuladhar Date: Wed, 3 Dec 2025 01:48:52 -0800 Subject: [PATCH 3/4] syntax: remove debug --- python/tests/test_ops.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 9f5efff4fc..a128161b36 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -5,11 +5,10 @@ import unittest from itertools import permutations, product +import mlx.core as mx import mlx_tests import numpy as np -import mlx.core as mx - def np_wrap_between(x, a): """Wraps `x` between `[-a, a]`.""" @@ -1895,11 +1894,6 @@ def test_pad(self): b_npy = np.pad(a_npy, pw, mode="symmetric") b_mlx = mx.pad(a_mlx, pw, mode="symmetric") - if not np.allclose(b_npy, b_mlx, atol=1e-6): - print(f"\nSymmetric test failed for pad_width: {pw}") - print(f"NumPy result (first 20):", b_npy.flat[:20]) - print(f"MLX result (first 20):", np.array(b_mlx).flat[:20]) - self.assertEqual(list(b_npy.shape), list(b_mlx.shape)) self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6)) From 37b0c3c1f17d3adedebb740b01001dcc9b9b22f1 Mon Sep 17 00:00:00 2001 From: Sampurna Tuladhar Date: Wed, 3 Dec 2025 01:51:36 -0800 Subject: [PATCH 4/4] docs: add ack --- ACKNOWLEDGMENTS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 186908f09c..c5c3b429ef 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -20,6 +20,7 @@ MLX was developed with contributions from the following individuals: - Paul Paczuski: Improved stability of BCE loss calculation - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops. - Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function. 
+- Sampurna Tuladhar: Added `reflect` and `symmetric` padding modes.