From 294eec7fff2dbd3d63466e6eb402d0515cfcc11d Mon Sep 17 00:00:00 2001 From: Harsh Sutaria Date: Tue, 2 Dec 2025 14:02:56 -0500 Subject: [PATCH] Fix memory leak in transforms (#2841) --- mlx/transforms.cpp | 20 ++++++++++++++++++-- python/tests/test_autograd.py | 22 ++++++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 4967c50a8b..cdc488e77c 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -337,7 +337,15 @@ std::pair, std::vector> vjp( for (auto& p : primals) { auto s = p.has_primitive() ? p.primitive().stream() : default_stream(default_device()); - primals_.push_back(copy(p, s)); // Does not do a deep copy + array source = p; + if (!p.is_tracer()) { + while (source.has_primitive() && + typeid(source.primitive()) == typeid(Copy) && + !source.inputs().empty()) { + source = source.inputs()[0]; + } + } + primals_.push_back(copy(source, s)); // Does not do a deep copy primals_.back().set_tracer(true); } @@ -545,7 +553,15 @@ std::pair, std::vector> jvp( for (auto& p : primals) { auto s = p.has_primitive() ? p.primitive().stream() : default_stream(default_device()); - primals_.push_back(copy(p, s)); // Does not do a deep copy + array source = p; + if (!p.is_tracer()) { + while (source.has_primitive() && + typeid(source.primitive()) == typeid(Copy) && + !source.inputs().empty()) { + source = source.inputs()[0]; + } + } + primals_.push_back(copy(source, s)); // Does not do a deep copy primals_.back().set_tracer(true); } auto outputs = fun(primals_); diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index 38bb6089d3..50d4887669 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -798,6 +798,28 @@ def loss_fn(model): grad_fn(model) self.assertEqual(model[1].item(), 2.0) + def test_grad_with_container_reuse(self): + container = [mx.array(1.0)] + + def fn(p, x): + container[0] = p + return x.sum() + + x = mx.ones(shape=(128,)) + grad_fn = mx.grad(fn) + + mx.synchronize() + gc.collect() + mem_pre = mx.get_active_memory() + + for _ in range(20): + mx.eval(grad_fn(container[0], x)) + gc.collect() + + mx.synchronize() + mem_post = mx.get_active_memory() + self.assertLess(mem_post - mem_pre, 1024 * 1024) + if __name__ == "__main__": mlx_tests.MLXTestRunner()