Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ACKNOWLEDGMENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ MLX was developed with contributions from the following individuals:
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.
- Siddhartha Tuladhar: Added `reflect` and `symmetric` padding modes.

<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
Expand Down
192 changes: 192 additions & 0 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1220,6 +1220,194 @@ array edge_pad(
return padded;
}

array reflect_pad(
const array& a,
const std::vector<int>& axes,
const Shape& low_pad_size,
const Shape& high_pad_size,
const Shape& out_shape,
StreamOrDevice s /* = {}*/) {
array out = zeros(out_shape, a.dtype(), s);
auto stops = a.shape();
for (int i = 0; i < stops.size(); i++) {
stops[i] += low_pad_size[i];
}
array padded = slice_update(out, a, low_pad_size, stops, s);

for (int axis = 0; axis < a.ndim(); axis++) {
if (low_pad_size[axis] > 0) {
int n = a.shape(axis);
Shape starts(a.ndim(), 0);
Shape stops = padded.shape();
Shape strides(a.ndim(), 1);
strides[axis] = -1;

starts[axis] = low_pad_size[axis] + 1;
stops[axis] = low_pad_size[axis] + n;
array forward = slice(padded, starts, stops, s);

starts[axis] = low_pad_size[axis] + n - 2;
stops[axis] = low_pad_size[axis] - 1;
array backward = slice(padded, starts, stops, strides, s);

// build bounce pattern cycle
array cycle = concatenate({forward, backward}, axis, s);
int cycle_len = cycle.shape(axis);
int reps_needed = (low_pad_size[axis] + cycle_len - 1) / cycle_len + 1;
std::vector<int> reps(a.ndim(), 1);
reps[axis] = reps_needed;
array tiled = tile(cycle, reps, s);

Shape slice_starts(a.ndim(), 0);
Shape slice_stops = tiled.shape();
slice_stops[axis] = low_pad_size[axis];
array padding = slice(tiled, slice_starts, slice_stops, s);

// reverse padding for left side placement
Shape rev_strides(a.ndim(), 1);
rev_strides[axis] = -1;
Shape rev_starts(a.ndim(), 0);
Shape rev_stops = padding.shape();
rev_starts[axis] = padding.shape(axis) - 1;
rev_stops[axis] = -padding.shape(axis) - 1;
padding = slice(padding, rev_starts, rev_stops, rev_strides, s);

starts[axis] = 0;
stops[axis] = low_pad_size[axis];
padded = slice_update(padded, padding, starts, stops, s);
}

if (high_pad_size[axis] > 0) {
int n = a.shape(axis);
Shape starts(a.ndim(), 0);
Shape stops = padded.shape();
Shape strides(a.ndim(), 1);
strides[axis] = -1;
int orig_end = low_pad_size[axis] + n;

starts[axis] = orig_end - n + 1;
stops[axis] = orig_end;
array forward = slice(padded, starts, stops, s);

starts[axis] = orig_end - 2;
stops[axis] = -(padded.shape(axis) + 1);
array backward = slice(padded, starts, stops, strides, s);

// build bounce pattern cycle
array cycle = concatenate({backward, forward}, axis, s);
int cycle_len = cycle.shape(axis);
int reps_needed = (high_pad_size[axis] + cycle_len - 1) / cycle_len + 1;
std::vector<int> reps(a.ndim(), 1);
reps[axis] = reps_needed;
array tiled = tile(cycle, reps, s);

Shape slice_starts(a.ndim(), 0);
Shape slice_stops = tiled.shape();
slice_stops[axis] = high_pad_size[axis];
array padding = slice(tiled, slice_starts, slice_stops, s);

starts[axis] = low_pad_size[axis] + n;
stops[axis] = padded.shape(axis);
padded = slice_update(padded, padding, starts, stops, s);
}
}
return padded;
}

array symmetric_pad(
const array& a,
const std::vector<int>& axes,
const Shape& low_pad_size,
const Shape& high_pad_size,
const Shape& out_shape,
StreamOrDevice s /* = {}*/) {
array out = zeros(out_shape, a.dtype(), s);
auto stops = a.shape();
for (int i = 0; i < stops.size(); i++) {
stops[i] += low_pad_size[i];
}
array padded = slice_update(out, a, low_pad_size, stops, s);

for (int axis = 0; axis < a.ndim(); axis++) {
if (low_pad_size[axis] > 0) {
int n = a.shape(axis);
Shape starts(a.ndim(), 0);
Shape stops = padded.shape();
Shape strides(a.ndim(), 1);
strides[axis] = -1;

starts[axis] = low_pad_size[axis];
stops[axis] = low_pad_size[axis] + n;
array forward = slice(padded, starts, stops, s);

starts[axis] = low_pad_size[axis] + n - 1;
stops[axis] = (padded.shape(axis) + 1);
array backward = slice(padded, starts, stops, strides, s);

// build bounce pattern cycle
array cycle = concatenate({forward, backward}, axis, s);
int cycle_len = cycle.shape(axis);
int reps_needed = (low_pad_size[axis] + cycle_len - 1) / cycle_len + 1;
std::vector<int> reps(a.ndim(), 1);
reps[axis] = reps_needed;
array tiled = tile(cycle, reps, s);

Shape slice_starts(a.ndim(), 0);
Shape slice_stops = tiled.shape();
slice_stops[axis] = low_pad_size[axis];
array padding = slice(tiled, slice_starts, slice_stops, s);

// reverse padding for left side placement
Shape rev_strides(a.ndim(), 1);
rev_strides[axis] = -1;
Shape rev_starts(a.ndim(), 0);
Shape rev_stops = padding.shape();
rev_starts[axis] = padding.shape(axis) - 1;
rev_stops[axis] = -padding.shape(axis) - 1;
padding = slice(padding, rev_starts, rev_stops, rev_strides, s);

starts[axis] = 0;
stops[axis] = low_pad_size[axis];
padded = slice_update(padded, padding, starts, stops, s);
}

if (high_pad_size[axis] > 0) {
int n = a.shape(axis);
Shape starts(a.ndim(), 0);
Shape stops = padded.shape();
Shape strides(a.ndim(), 1);
strides[axis] = -1;
int orig_end = low_pad_size[axis] + n;

starts[axis] = orig_end - n;
stops[axis] = orig_end;
array forward = slice(padded, starts, stops, s);

starts[axis] = orig_end - 1;
stops[axis] = -(padded.shape(axis) + 1);
array backward = slice(padded, starts, stops, strides, s);

// build bounce pattern cycle
array cycle = concatenate({backward, forward}, axis, s);
int cycle_len = cycle.shape(axis);
int reps_needed = (high_pad_size[axis] + cycle_len - 1) / cycle_len + 1;
std::vector<int> reps(a.ndim(), 1);
reps[axis] = reps_needed;
array tiled = tile(cycle, reps, s);

Shape slice_starts(a.ndim(), 0);
Shape slice_stops = tiled.shape();
slice_stops[axis] = high_pad_size[axis];
array padding = slice(tiled, slice_starts, slice_stops, s);

starts[axis] = low_pad_size[axis] + n;
stops[axis] = padded.shape(axis);
padded = slice_update(padded, padding, starts, stops, s);
}
}
return padded;
}

/** Pad an array with a constant value */
array pad(
const array& a,
Expand Down Expand Up @@ -1267,6 +1455,10 @@ array pad(
{a, astype(pad_value, a.dtype(), s)});
} else if (mode == "edge") {
return edge_pad(a, axes, low_pad_size, high_pad_size, out_shape, s);
} else if (mode == "reflect") {
return reflect_pad(a, axes, low_pad_size, high_pad_size, out_shape, s);
} else if (mode == "symmetric") {
return symmetric_pad(a, axes, low_pad_size, high_pad_size, out_shape, s);
} else {
std::ostringstream msg;
msg << "Invalid padding mode (" << mode << ") passed to pad";
Expand Down
4 changes: 3 additions & 1 deletion python/src/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3142,7 +3142,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
"def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge', 'reflect', 'symmetric'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Pad an array with a constant value

Expand All @@ -3157,6 +3157,8 @@ void init_ops(nb::module_& m) {
mode: Padding mode. One of the following strings:
"constant" (default): Pads with a constant value.
"edge": Pads with the edge values of array.
"reflect": Pads with the reflection of the array mirrored along the edge, excluding the edge value.
"symmetric": Pads with the reflection of the array mirrored along the edge, including the edge value.
constant_value (array or scalar, optional): Optional constant value
to pad the edges of the array with.

Expand Down
12 changes: 12 additions & 0 deletions python/tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1885,6 +1885,18 @@ def test_pad(self):
self.assertEqual(list(b_npy.shape), list(b_mlx.shape))
self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6))

b_npy = np.pad(a_npy, pw, mode="reflect")
b_mlx = mx.pad(a_mlx, pw, mode="reflect")

self.assertEqual(list(b_npy.shape), list(b_mlx.shape))
self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6))

b_npy = np.pad(a_npy, pw, mode="symmetric")
b_mlx = mx.pad(a_mlx, pw, mode="symmetric")

self.assertEqual(list(b_npy.shape), list(b_mlx.shape))
self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6))

a = mx.zeros((1, 1, 1))
self.assertEqual(mx.pad(a, 1).shape, (3, 3, 3))
self.assertEqual(mx.pad(a, (1,)).shape, (3, 3, 3))
Expand Down
79 changes: 79 additions & 0 deletions tests/ops_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2743,6 +2743,85 @@ TEST_CASE("test pad") {
CHECK(array_equal(padded_x, expected).item<bool>());
}

TEST_CASE("test pad reflect") {
auto x1d = array({1.0f, 2.0f, 3.0f});
auto padded_1d = pad(x1d, {{2, 2}}, array(0), "reflect");
auto expected_1d = array({3.0f, 2.0f, 1.0f, 2.0f, 3.0f, 2.0f, 1.0f});
CHECK(array_equal(padded_1d, expected_1d).item<bool>());

auto x2d =
array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f}, {3, 3});
auto padded_2d = pad(x2d, {{1, 1}, {2, 2}}, array(0), "reflect");
auto expected_2d = array(
{6.0f, 5.0f, 4.0f, 5.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 2.0f, 3.0f,
2.0f, 1.0f, 6.0f, 5.0f, 4.0f, 5.0f, 6.0f, 5.0f, 4.0f, 9.0f, 8.0f, 7.0f,
8.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 5.0f, 6.0f, 5.0f, 4.0f},
{5, 7});
CHECK(array_equal(padded_2d, expected_2d).item<bool>());

auto x_small = array({1.0f, 2.0f, 3.0f});
auto padded_large = pad(x_small, {{5, 5}}, array(0), "reflect");
auto expected_large = array(
{2.0f,
1.0f,
2.0f,
3.0f,
2.0f,
1.0f,
2.0f,
3.0f,
2.0f,
1.0f,
2.0f,
3.0f,
2.0f});
CHECK(array_equal(padded_large, expected_large).item<bool>());

auto x_min = array({1.0f, 2.0f});
auto padded_min = pad(x_min, {{1, 1}}, array(0), "reflect");
auto expected_min = array({2.0f, 1.0f, 2.0f, 1.0f});
CHECK(array_equal(padded_min, expected_min).item<bool>());

auto x_asym = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
auto padded_asym = pad(x_asym, {{0, 2}, {1, 0}}, array(0), "reflect");
CHECK_EQ(padded_asym.shape(), Shape{4, 3});
}

TEST_CASE("test pad symmetric") {
auto x1d = array({1.0f, 2.0f, 3.0f});
auto padded_1d = pad(x1d, {{2, 2}}, array(0), "symmetric");
auto expected_1d = array({2.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 2.0f});
CHECK(array_equal(padded_1d, expected_1d).item<bool>());

auto x2d =
array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f}, {3, 3});
auto padded_2d = pad(x2d, {{1, 1}, {2, 2}}, array(0), "symmetric");
auto expected_2d = array(
{5.0f, 4.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f, 2.0f, 1.0f, 1.0f, 2.0f, 3.0f,
3.0f, 2.0f, 2.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 2.0f, 5.0f, 4.0f, 4.0f,
5.0f, 6.0f, 6.0f, 5.0f, 8.0f, 7.0f, 7.0f, 8.0f, 9.0f, 9.0f, 8.0f},
{5, 7});
CHECK(array_equal(padded_2d, expected_2d).item<bool>());

auto x_small = array({1.0f, 2.0f, 3.0f});
auto padded_large = pad(x_small, {{5, 5}}, array(0), "symmetric");
auto expected_large = array(
{3.0f,
2.0f,
1.0f,
1.0f,
2.0f,
3.0f,
3.0f,
2.0f,
1.0f,
1.0f,
2.0f,
3.0f,
3.0f});
CHECK(array_equal(padded_large, expected_large).item<bool>());
}

TEST_CASE("test power") {
CHECK_EQ(power(array(1), array(2)).item<int>(), 1);
CHECK_EQ((power(array(-1), array(2))).item<int>(), 1);
Expand Down