From 66fcb9fe94125fcf61a2f0fccc255220cc98c669 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Wed, 29 Oct 2025 16:02:31 -0700 Subject: [PATCH 01/30] array: use int or int64_t instead of size_t --- mlx/array.cpp | 15 ++++++++------- mlx/array.h | 14 +++++++------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/mlx/array.cpp b/mlx/array.cpp index a05e8dfa7b..23b322a7bd 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -44,11 +44,11 @@ std::vector array::make_arrays( const std::shared_ptr& primitive, const std::vector& inputs) { std::vector outputs; - for (size_t i = 0; i < shapes.size(); ++i) { + for (int i = 0; i < std::ssize(shapes); ++i) { outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs); } // For each node in |outputs|, its siblings are the other nodes. - for (size_t i = 0; i < outputs.size(); ++i) { + for (int i = 0; i < std::ssize(outputs); ++i) { auto siblings = outputs; siblings.erase(siblings.begin() + i); outputs[i].set_siblings(std::move(siblings), i); @@ -145,8 +145,9 @@ void array::set_data(allocator::Buffer buffer, Deleter d) { array_desc_->data_size = size(); array_desc_->flags.contiguous = true; array_desc_->flags.row_contiguous = true; - auto max_dim = std::max_element(shape().begin(), shape().end()); - array_desc_->flags.col_contiguous = size() <= 1 || size() == *max_dim; + auto max_dim = + static_cast(*std::max_element(shape().begin(), shape().end())); + array_desc_->flags.col_contiguous = size() <= 1 || size() == max_dim; } void array::set_data( @@ -192,7 +193,7 @@ array::~array() { } // Break circular reference for non-detached arrays with siblings - if (auto n = siblings().size(); n > 0) { + if (auto n = std::ssize(siblings()); n > 0) { bool do_detach = true; // If all siblings have siblings.size() references except // the one we are currently destroying (which has siblings.size() + 1) @@ -274,7 +275,7 @@ array::ArrayDesc::~ArrayDesc() { ad.inputs.clear(); for (auto& [_, a] : input_map) { bool is_deletable = - (a.array_desc_.use_count() <= a.siblings().size() + 1); + (a.array_desc_.use_count() <= std::ssize(a.siblings()) + 1); // An array with siblings is deletable only if all of its siblings // are deletable for (auto& s : a.siblings()) { @@ -283,7 +284,7 @@ array::ArrayDesc::~ArrayDesc() { } int is_input = (input_map.find(s.id()) != input_map.end()); is_deletable &= - s.array_desc_.use_count() <= a.siblings().size() + is_input; + s.array_desc_.use_count() <= std::ssize(a.siblings()) + is_input; } if (is_deletable) { for_deletion.push_back(std::move(a.array_desc_)); diff --git a/mlx/array.h b/mlx/array.h index 4e9a5ae63a..4be4835e17 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -81,22 +81,22 @@ class array { } /** The size of the array's datatype in bytes. */ - size_t itemsize() const { + int itemsize() const { return size_of(dtype()); } /** The number of elements in the array. */ - size_t size() const { + int64_t size() const { return array_desc_->size; } /** The number of bytes in the array. */ - size_t nbytes() const { + int64_t nbytes() const { return size() * itemsize(); } /** The number of dimensions of the array. */ - size_t ndim() const { + int ndim() const { return array_desc_->shape.size(); } @@ -329,7 +329,7 @@ class array { * corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``. * Note, ``data_size`` is in units of ``item_size`` (not bytes). **/ - size_t data_size() const { + int64_t data_size() const { return array_desc_->data_size; } @@ -340,7 +340,7 @@ class array { return array_desc_->data->buffer; } - size_t buffer_size() const { + int64_t buffer_size() const { return allocator::allocator().size(buffer()); } @@ -530,7 +530,7 @@ array::array( Shape shape, Dtype dtype /* = TypeToDtype() */) : array_desc_(std::make_shared(std::move(shape), dtype)) { - if (data.size() != size()) { + if (std::ssize(data) != size()) { throw std::invalid_argument( "Data size and provided shape mismatch in array construction."); } From 26f7155537e47cb9c4999dbe33bc6107355ea356 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Wed, 29 Oct 2025 16:06:10 -0700 Subject: [PATCH 02/30] SmallVector: keep sizes small (int) --- mlx/small_vector.h | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/mlx/small_vector.h b/mlx/small_vector.h index 143101c82f..cf3467cbc9 100644 --- a/mlx/small_vector.h +++ b/mlx/small_vector.h @@ -121,10 +121,10 @@ class SmallVector { std::initializer_list init, const Allocator& allocator = Allocator()) : allocator_(allocator) { - if (init.size() > capacity()) { + if (static_cast(init.size()) > capacity()) { grow(init.size()); } - assert(capacity() >= init.size()); // sanity check + assert(capacity() >= static_cast(init.size())); // sanity check std::uninitialized_move(init.begin(), init.end(), begin_); end_ = begin_ + init.size(); } @@ -132,7 +132,7 @@ class SmallVector { template >> SmallVector(Iter begin, Iter end, const Allocator& allocator = Allocator()) : allocator_(allocator) { - size_t size = std::distance(begin, end); + int size = std::distance(begin, end); if (size > capacity()) { grow(size); } @@ -164,7 +164,7 @@ class SmallVector { if (this == &other) { return *this; } - size_t other_size = other.size(); + int other_size = other.size(); if (capacity() < other_size) { // Create large-enough heap-allocated storage. free_storage(); @@ -273,13 +273,13 @@ class SmallVector { return std::make_reverse_iterator(begin_); } - size_t size() const { + int size() const { return end_ - begin_; } bool empty() const { return end_ == begin_; } - size_t capacity() const { + int capacity() const { return end_of_storage_ - begin_; } @@ -301,21 +301,21 @@ class SmallVector { return end_[-1]; } - T& at(size_t index) { + T& at(int index) { if (index >= size()) { throw std::out_of_range("SmallVector out of range."); } return begin_[index]; } - const T& at(size_t index) const { + const T& at(int index) const { return const_cast(this)->at(index); } - T& operator[](size_t index) { + T& operator[](int index) { assert(size() > index); return begin_[index]; } - const T& operator[](size_t index) const { + const T& operator[](int index) const { return const_cast(this)->operator[](index); } @@ -333,7 +333,7 @@ class SmallVector { emplace_back(std::move(x)); } - void pop_back(size_t count = 1) { + void pop_back(int count = 1) { assert(size() >= count); end_ -= count; std::destroy_n(end_, count); @@ -400,7 +400,7 @@ class SmallVector { return erase(pos, pos + 1); } - void resize(size_t new_size) { + void resize(int new_size) { if (new_size > capacity()) { grow(new_size); } @@ -415,7 +415,7 @@ class SmallVector { end_ = new_end; } - void resize(size_t new_size, const T& initial_value) { + void resize(int new_size, const T& initial_value) { if (new_size > capacity()) { grow(new_size); } @@ -428,7 +428,7 @@ class SmallVector { end_ = new_end; } - void reserve(size_t new_capacity) { + void reserve(int new_capacity) { if (new_capacity > capacity()) { grow(new_capacity); } @@ -443,8 +443,8 @@ class SmallVector { private: // Grows the backing store by a factor of two, and at least to {min_capacity}. // TODO: Move to private after removing external code using this method. - MLX_NOINLINE void grow(size_t min_capacity = 0) { - size_t new_capacity = std::max(min_capacity, 2 * capacity()); + MLX_NOINLINE void grow(int min_capacity = 0) { + int new_capacity = std::max(min_capacity, 2 * capacity()); // Round up to power of 2. new_capacity--; new_capacity |= new_capacity >> 1; @@ -452,9 +452,6 @@ class SmallVector { new_capacity |= new_capacity >> 4; new_capacity |= new_capacity >> 8; new_capacity |= new_capacity >> 16; - if constexpr (sizeof(size_t) == sizeof(uint64_t)) { - new_capacity |= new_capacity >> 32; - } new_capacity++; T* new_storage = allocator_.allocate(new_capacity); From 953b2f5be2026cef23b483408ab29c27bccd162b Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Wed, 29 Oct 2025 16:11:32 -0700 Subject: [PATCH 03/30] WIP --- mlx/compile.cpp | 30 +++--- mlx/einsum.cpp | 46 +++++----- mlx/export.cpp | 19 ++-- mlx/fast.cpp | 14 +-- mlx/fft.cpp | 6 +- mlx/ops.cpp | 52 +++++------ mlx/primitives.cpp | 221 +++++++++++++++++++++++---------------------- mlx/primitives.h | 15 +-- mlx/random.h | 1 + mlx/scheduler.h | 4 +- mlx/types/bf16.h | 3 - mlx/utils.h | 4 +- 12 files changed, 208 insertions(+), 207 deletions(-) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index d762c8d15d..b09095896a 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -194,7 +194,7 @@ const char* Compiled::name() const { } std::vector Compiled::output_shapes(const std::vector& inputs) { - size_t nd = 0; + int nd = 0; for (auto& in : inputs) { nd = std::max(nd, in.ndim()); } @@ -256,7 +256,7 @@ void merge(array& dst, array& src, ParentsMap& parents_map) { auto sources = src.outputs(); auto dests = dst.outputs(); // For each src parent, point it to the corresponding dst - for (int i = 0; i < sources.size(); ++i) { + for (int i = 0; i < std::ssize(sources); ++i) { merge_one(dests[i], sources[i], parents_map); } } @@ -327,7 +327,7 @@ class CompilerCache { if (in1.size() != in2.size()) { return false; } - for (size_t i = 0; i < in1.size(); ++i) { + for (int i = 0; i < std::ssize(in1); ++i) { if (in1[i].ndim() != in2[i].ndim()) { return false; } @@ -399,7 +399,7 @@ compile_trace( // Run the function on placeholder inputs // to get compute graph std::vector tracer_inputs; - for (int i = 0; i < inputs.size(); ++i) { + for (int i = 0; i < std::ssize(inputs); ++i) { array in(inputs[i].shape(), inputs[i].dtype(), nullptr, {}); in.set_tracer(true); tracer_inputs.push_back(std::move(in)); @@ -420,7 +420,7 @@ std::pair, ParentsMap> compile_dfs( std::unordered_set original_input_set; std::unordered_map>> parents_map; - for (int i = 0; i < inputs.size(); ++i) { + for (int i = 0; i < std::ssize(inputs); ++i) { input_set.insert(inputs[i].id()); original_input_set.insert(original_inputs[i].id()); } @@ -436,7 +436,7 @@ std::pair, ParentsMap> compile_dfs( if (cache.find(id) != cache.end()) { return; } - for (int i = 0; i < a.inputs().size(); i++) { + for (int i = 0; i < std::ssize(a.inputs()); i++) { auto& in = a.inputs()[i]; parents_map[in.id()].push_back({a, i}); for (auto& s : a.siblings()) { @@ -534,7 +534,7 @@ void compile_simplify( return false; } - for (int i = 0; i < a.inputs().size(); i++) { + for (int i = 0; i < std::ssize(a.inputs()); i++) { if (a.inputs()[i].id() != b.inputs()[i].id()) { return false; } @@ -599,7 +599,7 @@ void compile_simplify( auto maybe_merge_parents = [&](auto& a) { auto parents = parents_map.find(a.id()); if (parents != parents_map.end()) { - auto N = parents->second.size(); + auto N = std::ssize(parents->second); std::vector mask(N, false); auto try_merge = [&](int dst_idx, int src_idx) { @@ -642,11 +642,11 @@ void compile_simplify( it->second.push_back(i); } for (auto& [_, group] : dst_map) { - for (int i = 0; i < group.size(); ++i) { + for (int i = 0; i < std::ssize(group); ++i) { if (mask[group[i]]) { continue; } - for (int j = i + 1; j < group.size(); ++j) { + for (int j = i + 1; j < std::ssize(group); ++j) { if (mask[group[j]]) { continue; } @@ -847,7 +847,7 @@ void compile_fuse( std::vector old_outputs; // Add to global cache and add any global outputs to outputs // of new primitive - for (int j = 0; j < fused_tape.size() - 1; ++j) { + for (int j = 0; j < std::ssize(fused_tape) - 1; ++j) { auto& f = fused_tape[j]; if (output_map.find(f.id()) != output_map.end()) { old_outputs.push_back(f); @@ -903,7 +903,7 @@ void compile_fuse( new_tape.push_back(compiled_outputs.back()); // Replace inputs old parents with compiled_outputs - for (int i = 0; i < inputs.size(); ++i) { + for (int i = 0; i < std::ssize(inputs); ++i) { auto& pairs = parents_map[inputs[i].id()]; pairs.erase( std::remove_if( @@ -918,7 +918,7 @@ void compile_fuse( // - Update outputs parents to point to compiled outputs // - Update any overall graph outputs to be compiled outputs - for (int o = 0; o < old_outputs.size(); ++o) { + for (int o = 0; o < std::ssize(old_outputs); ++o) { merge_one(compiled_outputs[o], old_outputs[o], parents_map); if (auto it = output_map.find(old_outputs[o].id()); it != output_map.end()) { @@ -943,7 +943,7 @@ std::vector compile_replace( const std::vector& inputs, bool shapeless) { std::unordered_map trace_to_real; - for (int i = 0; i < inputs.size(); ++i) { + for (int i = 0; i < std::ssize(inputs); ++i) { trace_to_real.insert({trace_inputs[i].id(), inputs[i]}); } @@ -989,7 +989,7 @@ std::vector compile_replace( } auto real_out = array::make_arrays( std::move(shapes), types, a.primitive_ptr(), real_inputs); - for (int i = 0; i < trace_out.size(); ++i) { + for (int i = 0; i < std::ssize(trace_out); ++i) { trace_to_real.insert({trace_out[i].id(), std::move(real_out[i])}); } } diff --git a/mlx/einsum.cpp b/mlx/einsum.cpp index 6290887728..5dad17ba7c 100644 --- a/mlx/einsum.cpp +++ b/mlx/einsum.cpp @@ -190,8 +190,8 @@ std::tuple, size_t, int> greedy_path( // Start by iterating over all possible combinations std::vector> pos_pairs; - for (int i = 0; i < inputs.size(); ++i) { - for (int j = i + 1; j < inputs.size(); ++j) { + for (int i = 0; i < std::ssize(inputs); ++i) { + for (int j = i + 1; j < std::ssize(inputs); ++j) { pos_pairs.emplace_back(i, j); } } @@ -200,13 +200,13 @@ std::tuple, size_t, int> greedy_path( std::vector possible_contractions; size_t path_cost = 0; int path_scaling = 0; - auto num_in = inputs.size(); + auto num_in = std::ssize(inputs); for (int i = 0; i < num_in - 1; ++i) { auto add_contraction = [&](int p1, int p2) { CharSet new_term; CharSet contractions(inputs[p1].set.begin(), inputs[p1].set.end()); contractions.insert(inputs[p2].set.begin(), inputs[p2].set.end()); - for (int i = 0; i < inputs.size(); i++) { + for (int i = 0; i < std::ssize(inputs); i++) { if (i == p1 || i == p2) { continue; } @@ -321,7 +321,7 @@ std::tuple, size_t, int> greedy_path( } pos_pairs.clear(); - for (int i = 0; i < inputs.size() - 1; ++i) { + for (int i = 0; i < std::ssize(inputs) - 1; ++i) { pos_pairs.emplace_back(i, inputs.size() - 1); } path_cost += best.cost; @@ -360,7 +360,7 @@ array batch_tensordot( { auto a_shape = a.shape(); auto b_shape = b.shape(); - for (int i = 0; i < a_contract.size(); ++i) { + for (int i = 0; i < std::ssize(a_contract); ++i) { auto d = std::max(a.shape(a_contract[i]), b.shape(b_contract[i])); a_shape[a_contract[i]] = d; b_shape[b_contract[i]] = d; @@ -430,7 +430,7 @@ array collapse_repeats(array in, Subscript& subscript, StreamOrDevice s) { std::string repeat_str; std::string no_repeat_str; std::unordered_map counts; - for (int i = 0; i < str.size(); ++i) { + for (int i = 0; i < std::ssize(str); ++i) { auto [it, _] = counts.insert({str[i], 0}); it->second++; } @@ -455,7 +455,7 @@ array collapse_repeats(array in, Subscript& subscript, StreamOrDevice s) { std::vector indices; int n_expand = repeats.size(); for (auto [c, v] : repeats) { - for (int i = 0; i < str.size(); ++i) { + for (int i = 0; i < std::ssize(str); ++i) { if (str[i] == c) { slice_sizes[i] = 1; axes.push_back(i); @@ -494,7 +494,7 @@ void preprocess_einsum_inputs( std::vector& operands, StreamOrDevice s) { // Collapse repeat indices - for (int i = 0; i < inputs.size(); ++i) { + for (int i = 0; i < std::ssize(inputs); ++i) { auto& in = inputs[i]; if (in.set.size() < in.str.size()) { operands[positions[i]] = collapse_repeats(operands[positions[i]], in, s); @@ -514,10 +514,10 @@ void preprocess_einsum_inputs( auto inserted = counts.insert({c, 0}); inserted.first->second++; } - for (int i = 0; i < inputs.size(); ++i) { + for (int i = 0; i < std::ssize(inputs); ++i) { auto& in = inputs[i]; std::vector sum_axes; - for (int ax = 0; ax < in.str.size(); ++ax) { + for (int ax = 0; ax < std::ssize(in.str); ++ax) { if (counts[in.str[ax]] == 1) { sum_axes.push_back(ax); } @@ -549,12 +549,12 @@ array einsum_naive( } // Expand and transpose inputs as needed - for (int i = 0; i < inputs.size(); ++i) { + for (int i = 0; i < std::ssize(inputs); ++i) { int pos = positions[i]; auto& op = operands[pos]; // Add missing dimensions at the end - if (op.ndim() != char_to_ax.size()) { + if (op.ndim() != std::ssize(char_to_ax)) { auto shape = op.shape(); shape.insert(shape.end(), char_to_ax.size() - shape.size(), 1); op = reshape(op, std::move(shape), s); @@ -597,7 +597,7 @@ array einsum_naive( // Multiply and sum auto out = operands[positions[0]]; - for (int i = 1; i < positions.size(); ++i) { + for (int i = 1; i < std::ssize(positions); ++i) { out = multiply(out, operands[positions[i]], s); } std::vector sum_axes; @@ -675,9 +675,9 @@ std::pair, PathInfo> einsum_path_helper( int operand_idx) { bool have_ellipsis = false; int cnt_before = 0, cnt_after = 0; - for (int i = 0; i < subscript.size(); i++) { + for (int i = 0; i < std::ssize(subscript); i++) { if (!isalpha(subscript[i])) { - if (i + 2 >= subscript.size() || subscript[i] != '.' || + if (i + 2 >= std::ssize(subscript) || subscript[i] != '.' || subscript[i + 1] != '.' || subscript[i + 2] != '.') { std::ostringstream msg; msg << "[" << fn_name << "] Subscripts must be letters, but got '" @@ -732,7 +732,7 @@ std::pair, PathInfo> einsum_path_helper( } }; - for (int i = 0; i < operands.size(); i++) { + for (int i = 0; i < std::ssize(operands); i++) { check_letters_and_expand_ellipsis(in_subscripts[i], &operands[i], i); } check_letters_and_expand_ellipsis(out_subscript, nullptr, -1); @@ -747,12 +747,12 @@ std::pair, PathInfo> einsum_path_helper( std::unordered_map dim_map; std::vector inputs; - for (int i = 0; i < in_subscripts.size(); ++i) { + for (int i = 0; i < std::ssize(in_subscripts); ++i) { auto& in = in_subscripts[i]; CharSet in_set(in.begin(), in.end()); inputs.emplace_back(in, in_set); - if (in.size() != operands[i].ndim()) { + if (std::ssize(in) != operands[i].ndim()) { std::ostringstream msg; msg << "[" << fn_name << "] Invalid number of subscripts " << in.size() << " for input " << i << " with " << operands[i].ndim() @@ -763,7 +763,7 @@ std::pair, PathInfo> einsum_path_helper( // Check repeat subscripts are valid if (in_set.size() < in.size()) { std::unordered_map local_dims; - for (int j = 0; j < in.size(); ++j) { + for (int j = 0; j < std::ssize(in); ++j) { auto dim = operands[i].shape(j); auto inserted = local_dims.insert({in[j], dim}); if (!inserted.second) { @@ -778,7 +778,7 @@ std::pair, PathInfo> einsum_path_helper( } } - for (int j = 0; j < in.size(); j++) { + for (int j = 0; j < std::ssize(in); j++) { auto c = in[j]; auto dim = operands[i].shape(j); auto inserted = dim_map.insert({c, dim}); @@ -864,7 +864,7 @@ array einsum( std::vector a_contract; std::vector a_batch; std::vector a_concat; - for (int i = 0; i < in_a.str.size(); ++i) { + for (int i = 0; i < std::ssize(in_a.str); ++i) { auto c = in_a.str[i]; if (out.set.find(c) == out.set.end()) { // Not in the output, contraction @@ -887,7 +887,7 @@ array einsum( for (auto a_i : a_batch) { b_batch.push_back(in_b.str.find(in_a.str[a_i])); } - for (int i = 0; i < in_b.str.size(); ++i) { + for (int i = 0; i < std::ssize(in_b.str); ++i) { auto c = in_b.str[i]; if (out.set.find(c) != out.set.end() && in_a.set.find(c) == in_a.set.end()) { diff --git a/mlx/export.cpp b/mlx/export.cpp index 3448178e24..ce26141cd2 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -138,7 +138,7 @@ T deserialize(Reader& is) { T v; auto size = deserialize(is); v.reserve(size); - for (int i = 0; i < size; ++i) { + for (size_t i = 0; i < size; ++i) { v.push_back(deserialize(is)); } return v; @@ -487,11 +487,11 @@ struct FunctionTable { int n = 1; for (auto& [_, vec] : table) { for (auto& fun : vec) { - auto npos = fun.inputs.size() - fun.kwarg_keys.size(); + auto npos = std::ssize(fun.inputs) - std::ssize(fun.kwarg_keys); os << " " << n++ << ". Function with " << npos - << " positional inputs and " << fun.kwarg_keys.size() + << " positional inputs and " << std::ssize(fun.kwarg_keys) << " keyword inputs:\n"; - for (int j = 0; j < fun.inputs.size(); ++j) { + for (int j = 0; j < std::ssize(fun.inputs); ++j) { auto& in = fun.inputs[j]; if (j < npos) { os << " " << j + 1 << ": "; @@ -536,7 +536,7 @@ bool FunctionTable::match( }; int i = 0; - for (; i < args.size(); ++i) { + for (; i < std::ssize(args); ++i) { if (!match_inputs(args[i], fun.inputs[i])) { return false; } @@ -627,7 +627,8 @@ void FunctionExporter::export_with_callback( // Callback on the inputs callback({{"type", "inputs"}, {"inputs", to_vector_data(inputs)}}); std::vector> keyword_inputs; - for (int i = inputs.size() - kwarg_keys.size(), j = 0; i < inputs.size(); + for (int i = std::ssize(inputs) - std::ssize(kwarg_keys), j = 0; + i < std::ssize(inputs); ++i, ++j) { keyword_inputs.emplace_back(kwarg_keys[j], namer.get_name(inputs[i])); } @@ -928,7 +929,7 @@ std::vector ImportedFunction::operator()( ftable->print_functions(msg); msg << "\nCalled with " << args.size() << " positional inputs and " << kwargs.size() << " keyword inputs:\n"; - for (int i = 0; i < args.size(); ++i) { + for (int i = 0; i < std::ssize(args); ++i) { auto& in = args[i]; msg << " " << i + 1 << ": " << in.shape() << " " << in.dtype() << "\n"; } @@ -970,7 +971,7 @@ ImportedFunction::ImportedFunction(const std::string& file) std::unordered_map array_map; auto trace_input_ids = deserialize>(is); auto trace_inputs = deserialize>(is); - for (int i = 0; i < trace_inputs.size(); ++i) { + for (int i = 0; i < std::ssize(trace_inputs); ++i) { array_map.emplace(trace_input_ids[i], trace_inputs[i]); } auto trace_output_ids = deserialize>(is); @@ -1006,7 +1007,7 @@ ImportedFunction::ImportedFunction(const std::string& file) std::move(types), std::move(prim), std::move(inputs)); - for (int i = 0; i < arrays.size(); ++i) { + for (int i = 0; i < std::ssize(arrays); ++i) { auto sid = ids[i]; if (sid == id) { tape.push_back(arrays[i]); diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 0f34aec939..e88527a8e9 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -13,11 +13,11 @@ std::vector Custom::vjp( const std::vector& primals, const std::vector& cotangents, const std::vector& argnums, - const std::vector& outputs) { + const std::vector& /* outputs */) { auto [_, vjps] = mlx::core::vjp(fallback_, primals, cotangents); std::vector vjp_outs; - for (int i = 0, j = 0; i < vjps.size(); ++i) { - if (j < argnums.size() && i == argnums[j]) { + for (int i = 0, j = 0; i < std::ssize(vjps); ++i) { + if (j < std::ssize(argnums) && i == argnums[j]) { vjp_outs.push_back(vjps[i]); j++; } @@ -30,8 +30,8 @@ std::vector Custom::jvp( const std::vector& tangents, const std::vector& argnums) { std::vector all_tangents; - for (int i = 0, j = 0; i < primals.size(); i++) { - if (j < argnums.size() && i == argnums[j]) { + for (int i = 0, j = 0; i < std::ssize(primals); i++) { + if (j < std::ssize(argnums) && i == argnums[j]) { all_tangents.emplace_back(tangents[j++]); } else { all_tangents.emplace_back(zeros_like(primals[i])); @@ -536,7 +536,7 @@ std::vector RoPE::vjp( const std::vector& primals, const std::vector& cotangents, const std::vector& argnums, - const std::vector& outputs) { + const std::vector& /* outputs */) { auto s = stream(); auto fallback = [dims = dims_, traditional = traditional_, @@ -635,7 +635,7 @@ array scaled_dot_product_attention( throw std::invalid_argument(msg.str()); } - const size_t batch_dim = queries.shape(0); + const int batch_dim = queries.shape(0); for (const auto& tensor : {keys, values}) { if (tensor.shape(0) != batch_dim) { std::ostringstream msg; diff --git a/mlx/fft.cpp b/mlx/fft.cpp index 6510faec1e..33f5e763a6 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -20,14 +20,14 @@ array fft_impl( throw std::invalid_argument( "[fftn] Requires array with at least one dimension."); } - if (n.size() != axes.size()) { + if (n.size() != std::ssize(axes)) { throw std::invalid_argument("[fftn] Shape and axes have different sizes."); } if (axes.empty()) { return a; } - std::vector valid_axes; + std::vector valid_axes; for (int ax : axes) { valid_axes.push_back(ax < 0 ? ax + a.ndim() : ax); } @@ -59,7 +59,7 @@ array fft_impl( } auto in_shape = a.shape(); - for (int i = 0; i < valid_axes.size(); ++i) { + for (int i = 0; i < std::ssize(valid_axes); ++i) { in_shape[valid_axes[i]] = n[i]; } if (real && inverse) { diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 30e934f826..e1f2645745 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -390,7 +390,7 @@ array unflatten( throw std::invalid_argument(msg.str()); } - size_t size = 1; + int64_t size = 1; int infer_idx = -1; for (int i = 0; i < shape.size(); ++i) { if (shape[i] == -1) { @@ -687,10 +687,10 @@ void normalize_dynamic_slice_inputs( << "."; throw std::invalid_argument(msg.str()); } - if (start.size() != axes.size()) { + if (start.size() != std::ssize(axes)) { std::ostringstream msg; msg << prefix << " Number of starting indices " << start.size() - << " does not match number of axes " << axes.size() << "."; + << " does not match number of axes " << std::ssize(axes) << "."; throw std::invalid_argument(msg.str()); } if (!issubdtype(start.dtype(), integer)) { @@ -847,7 +847,7 @@ array slice_update( // Broadcast update with unspecified axes auto up_shape = update.shape(); - auto dim_diff = std::max(src.ndim() - update.ndim(), size_t(0)); + auto dim_diff = std::max(src.ndim() - update.ndim(), 0); up_shape.insert( up_shape.begin(), src.shape().begin(), src.shape().begin() + dim_diff); for (int d = dim_diff; d < src.ndim(); ++d) { @@ -957,7 +957,7 @@ std::vector meshgrid( "[meshgrid] Invalid indexing value. Valid values are 'xy' and 'ij'."); } - auto ndim = arrays.size(); + auto ndim = std::ssize(arrays); std::vector outputs; for (int i = 0; i < ndim; ++i) { Shape shape(ndim, 1); @@ -1135,10 +1135,10 @@ array tile( std::vector reps, StreamOrDevice s /* = {} */) { auto shape = arr.shape(); - if (reps.size() < shape.size()) { + if (std::ssize(reps) < shape.size()) { reps.insert(reps.begin(), shape.size() - reps.size(), 1); } - if (reps.size() > shape.size()) { + if (std::ssize(reps) > shape.size()) { shape.insert(shape.begin(), reps.size() - shape.size(), 1); } @@ -1162,7 +1162,7 @@ array tile( array edge_pad( const array& a, - const std::vector& axes, + const std::vector& /* axes */, const Shape& low_pad_size, const Shape& high_pad_size, const Shape& out_shape, @@ -1214,17 +1214,17 @@ array pad( const array& pad_value /*= array(0)*/, const std::string& mode /*= "constant"*/, StreamOrDevice s /* = {}*/) { - if (axes.size() != low_pad_size.size() || - axes.size() != high_pad_size.size()) { + if (std::ssize(axes) != low_pad_size.size() || + std::ssize(axes) != high_pad_size.size()) { std::ostringstream msg; msg << "Invalid number of padding sizes passed to pad " - << "with axes of size " << axes.size(); + << "with axes of size " << std::ssize(axes); throw std::invalid_argument(msg.str()); } auto out_shape = a.shape(); - for (int i = 0; i < axes.size(); i++) { + for (int i = 0; i < std::ssize(axes); i++) { if (low_pad_size[i] < 0) { std::ostringstream msg; msg << "Invalid low padding size (" << low_pad_size[i] @@ -1365,7 +1365,7 @@ array transpose( for (auto& ax : axes) { ax = ax < 0 ? ax + a.ndim() : ax; } - if (axes.size() != a.ndim()) { + if (std::ssize(axes) != a.ndim()) { std::ostringstream msg; msg << "[transpose] Recived " << axes.size() << " axes for array with " << a.ndim() << " dimensions."; @@ -1387,7 +1387,7 @@ array transpose( shape[ax] = 1; } - for (int i = 0; i < axes.size(); ++i) { + for (int i = 0; i < std::ssize(axes); ++i) { shape[i] = a.shape()[axes[i]]; } return array( @@ -1444,7 +1444,7 @@ std::vector broadcast_arrays( auto shape = BroadcastAxes::output_shape(inputs, ignore_axes); auto check_and_get_shape = [&shape, &ignore_axes](const array& in) { auto out_shape = shape; - for (int i = 0; i < ignore_axes.size(); ++i) { + for (int i = 0; i < std::ssize(ignore_axes); ++i) { auto ax = ignore_axes[i]; auto pos_ax = in.ndim() + ax; if (pos_ax < 0 || pos_ax > in.ndim() || @@ -1478,7 +1478,7 @@ std::vector broadcast_arrays( stop_grad_inputs.push_back(stop_gradient(in, s)); } - for (int i = 0; i < inputs.size(); ++i) { + for (int i = 0; i < std::ssize(inputs); ++i) { auto& in = inputs[i]; auto out_shape = check_and_get_shape(in); if (in.shape() == out_shape) { @@ -1486,7 +1486,7 @@ std::vector broadcast_arrays( } else { // broadcasted array goes first followed by other stopgrad inputs std::vector p_inputs = {in}; - for (int j = 0; j < inputs.size(); ++j) { + for (int j = 0; j < std::ssize(inputs); ++j) { if (j == i) { continue; } @@ -1530,14 +1530,14 @@ std::vector broadcast_arrays( for (auto& in : inputs) { stop_grad_inputs.push_back(stop_gradient(in, s)); } - for (int i = 0; i < inputs.size(); ++i) { + for (int i = 0; i < std::ssize(inputs); ++i) { auto& in = inputs[i]; if (in.shape() == shape) { outputs.push_back(in); } else { // broadcasted array goes first followed by other stopgrad inputs std::vector p_inputs = {in}; - for (int j = 0; j < inputs.size(); ++j) { + for (int j = 0; j < std::ssize(inputs); ++j) { if (j == i) { continue; } @@ -1961,7 +1961,7 @@ array median( auto dtype = at_least_float(a.dtype()); std::vector transpose_axes; for (int i = 0, j = 0; i < a.ndim(); ++i) { - if (j < sorted_axes.size() && i == sorted_axes[j]) { + if (j < std::ssize(sorted_axes) && i == sorted_axes[j]) { j++; continue; } @@ -3010,7 +3010,7 @@ array gather( const Shape& slice_sizes, StreamOrDevice s /* = {} */) { // Checks that indices, dimensions, and slice_sizes are all valid - if (indices.size() > a.ndim()) { + if (std::ssize(indices) > a.ndim()) { std::ostringstream msg; msg << "[gather] Too many index arrays. Got " << indices.size() << " index arrays for input with " << a.ndim() << " dimensions."; @@ -3312,7 +3312,7 @@ array scatter( Scatter::ReduceType mode, StreamOrDevice s) { // Checks that indices, dimensions, and slice_sizes are all valid - if (indices.size() > a.ndim()) { + if (std::ssize(indices) > a.ndim()) { std::ostringstream msg; msg << "[scatter] Too many index arrays. Got " << indices.size() << " index arrays for input with " << a.ndim() << " dimensions."; @@ -3820,7 +3820,7 @@ array conv_transpose_general( StreamOrDevice s) { std::vector padding_lo(padding.size()); std::vector padding_hi(padding.size()); - for (int i = 0; i < padding.size(); ++i) { + for (int i = 0; i < std::ssize(padding); ++i) { int wt_size = 1 + dilation[i] * (weight.shape(1 + i) - 1); padding_lo[i] = wt_size - padding[i] - 1; @@ -4632,7 +4632,7 @@ array tensordot( int csize = 1; auto x = a; auto y = b; - for (int i = 0; i < axes_a.size(); i++) { + for (int i = 0; i < std::ssize(axes_a); i++) { if (x.shape(axes_a.at(i)) == y.shape(axes_b.at(i))) { csize *= x.shape(axes_a.at(i)); } else { @@ -5560,7 +5560,7 @@ array roll( return a; } - if (shift.size() < axes.size()) { + if (shift.size() < std::ssize(axes)) { std::ostringstream msg; msg << "[roll] At least one shift value per axis is required, " << shift.size() << " provided for " << axes.size() << " axes."; @@ -5568,7 +5568,7 @@ array roll( } array result = a; - for (int i = 0; i < axes.size(); i++) { + for (int i = 0; i < std::ssize(axes); i++) { int ax = axes[i]; if (ax < 0) { ax += a.ndim(); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 0b335e7652..152487de63 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -242,16 +242,16 @@ std::pair, std::vector> Abs::vmap( } std::vector Add::jvp( - const std::vector& primals, + const std::vector& /* primals */, const std::vector& tangents, - const std::vector& argnums) { + const std::vector& /* argnums */) { return { tangents.size() > 1 ? add(tangents[0], tangents[1], stream()) : tangents[0]}; } std::vector Add::vjp( - const std::vector& primals, + const std::vector& /* primals */, const std::vector& cotangents, const std::vector& argnums, const std::vector&) { @@ -315,7 +315,7 @@ std::vector AddMM::jvp( const std::vector& tangents, const std::vector& argnums) { std::vector jvp; - for (int i = 0; i < argnums.size(); ++i) { + for (int i = 0; i < std::ssize(argnums); ++i) { auto arg = argnums[i]; if (arg == 0) { if (jvp.empty()) { @@ -692,7 +692,7 @@ std::vector ArgSort::jvp( std::vector AsType::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums, + const std::vector& /* argnums */, const std::vector&) { if (cotangents[0].dtype() != dtype_) { throw std::invalid_argument( @@ -702,9 +702,9 @@ std::vector AsType::vjp( } std::vector AsType::jvp( - const std::vector& primals, + const std::vector& /* primals */, const std::vector& tangents, - const std::vector& argnums) { + const std::vector& /* argnums */) { return {astype(tangents[0], dtype_, stream())}; } @@ -752,7 +752,7 @@ std::vector AsStrided::vjp( std::vector AsStrided::jvp( const std::vector& primals, const std::vector& tangents, - const std::vector& argnums) { + const std::vector& /* argnums */) { assert(primals.size() == 1); return {as_strided(tangents[0], shape_, strides_, offset_, stream())}; @@ -827,9 +827,9 @@ std::vector Broadcast::vjp( } std::vector Broadcast::jvp( - const std::vector& primals, + const std::vector& /* primals */, const std::vector& tangents, - const std::vector& argnums) { + const std::vector& /* argnums */) { return {array( shape_, tangents[0].dtype(), @@ -858,7 +858,7 @@ bool Broadcast::is_equivalent(const Primitive& other) const { Shape Broadcast::output_shape(const std::vector& inputs) { auto shape = inputs[0].shape(); - for (int i = 1; i < inputs.size(); ++i) { + for (int i = 1; i < std::ssize(inputs); ++i) { shape = broadcast_shapes(shape, inputs[i].shape()); } return shape; @@ -886,7 +886,7 @@ std::vector BroadcastAxes::vjp( std::vector BroadcastAxes::jvp( const std::vector& primals, const std::vector& tangents, - const std::vector& argnums) { + const std::vector& /* argnums */) { return {array( output_shape(primals, ignore_axes_), tangents[0].dtype(), @@ -895,8 +895,8 @@ std::vector BroadcastAxes::jvp( } std::pair, std::vector> BroadcastAxes::vmap( - const std::vector& inputs, - const std::vector& axes) { + const std::vector& /* inputs */, + const std::vector& /* axes */) { throw std::invalid_argument("[BroadcastAxes] VMAP NYI"); } @@ -938,7 +938,7 @@ std::vector Ceil::vjp( std::vector Ceil::jvp( const std::vector& primals, - const std::vector& tangents, + const std::vector& /* tangents */, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); @@ -1072,8 +1072,8 @@ std::vector Concatenate::jvp( }); std::vector vals; - for (int i = 0, j = 0; i < primals.size(); ++i) { - if (j < argnums.size() && argnums[argidx[j]] == i) { + for (int i = 0, j = 0; i < std::ssize(primals); ++i) { + if (j < std::ssize(argnums) && argnums[argidx[j]] == i) { vals.push_back(tangents[argidx[j++]]); } else { vals.push_back(zeros_like(primals[i], stream())); @@ -1089,7 +1089,7 @@ std::pair, std::vector> Concatenate::vmap( int first_vmap = -1; // Find the first vmapped input - for (int i = 0; i < axes.size(); i++) { + for (int i = 0; i < std::ssize(axes); i++) { if (axes[i] >= 0) { out_ax = axes[i]; first_vmap = i; @@ -1107,7 +1107,7 @@ std::pair, std::vector> Concatenate::vmap( std::vector t_inputs; int axis = axis_ + (axis_ >= out_ax); auto cat_shape = inputs[first_vmap].shape(); - for (int i = 0; i < axes.size(); i++) { + for (int i = 0; i < std::ssize(axes); i++) { if (axes[i] >= 0) { if (out_ax != axes[i]) { t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream())); @@ -1132,7 +1132,7 @@ bool Concatenate::is_equivalent(const Primitive& other) const { std::vector Concatenate::output_shapes( const std::vector& inputs) { auto shape = inputs[0].shape(); - for (int i = 1; i < inputs.size(); ++i) { + for (int i = 1; i < std::ssize(inputs); ++i) { shape[axis_] += inputs[i].shape(axis_); } return {std::move(shape)}; @@ -1272,28 +1272,29 @@ Shape Convolution::conv_out_shape( int spatial_dims = in_shape.size() - 2; - if (strides.size() != spatial_dims) { + if (std::ssize(strides) != spatial_dims) { std::ostringstream msg; msg << "[conv] Invalid strides " << strides << " for " << spatial_dims << "D convolution."; throw std::invalid_argument(msg.str()); } - if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) { + if (std::ssize(pads_lo) != spatial_dims || + std::ssize(pads_hi) != spatial_dims) { std::ostringstream msg; msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << " for " << spatial_dims << "D convolution."; throw std::invalid_argument(msg.str()); } - if (kernel_dilation.size() != spatial_dims) { + if (std::ssize(kernel_dilation) != spatial_dims) { std::ostringstream msg; msg << "[conv] Invalid kernel dilation " << kernel_dilation << " for " << spatial_dims << "D convolution."; throw std::invalid_argument(msg.str()); } - if (input_dilation.size() != spatial_dims) { + if (std::ssize(input_dilation) != spatial_dims) { std::ostringstream msg; msg << "[conv] Invalid input dilation " << input_dilation << " for " << spatial_dims << "D convolution."; @@ -1386,7 +1387,7 @@ std::vector Convolution::vjp( std::vector padding_lo = padding_lo_; std::vector padding_hi = padding_hi_; - for (int i = 0; i < padding_lo.size(); ++i) { + for (int i = 0; i < std::ssize(padding_lo); ++i) { int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); padding_lo[i] = wt_size - padding_lo_[i] - 1; @@ -1440,7 +1441,7 @@ std::vector Convolution::vjp( else if (a == 1) { bool no_dilation = true; - for (int i = 0; i < input_dilation_.size(); i++) { + for (int i = 0; i < std::ssize(input_dilation_); i++) { no_dilation &= (input_dilation_[i] == 1) && (kernel_dilation_[i] == 1); } @@ -1451,7 +1452,7 @@ std::vector Convolution::vjp( } else { auto padding_hi = padding_lo_; - for (int i = 0; i < padding_hi.size(); ++i) { + for (int i = 0; i < std::ssize(padding_hi); ++i) { int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1); int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1); int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); @@ -1699,11 +1700,11 @@ std::vector Depends::vjp( const std::vector& primals, const std::vector& cotangents, const std::vector& argnums, - const std::vector& outputs) { + const std::vector& /* outputs */) { std::vector vjps; for (auto arg : argnums) { - if (arg < cotangents.size()) { + if (arg < std::ssize(cotangents)) { vjps.push_back(cotangents[arg]); } else { vjps.push_back(zeros_like(primals[arg])); @@ -1737,7 +1738,7 @@ std::vector Divide::vjp( std::vector DivMod::vjp( const std::vector& primals, - const std::vector& cotangents, + const std::vector& /* cotangents */, const std::vector& argnums, const std::vector&) { std::vector vjps; @@ -1749,8 +1750,8 @@ std::vector DivMod::vjp( std::vector DivMod::jvp( const std::vector& primals, - const std::vector& tangents, - const std::vector& argnums) { + const std::vector& /* tangents */, + const std::vector& /* argnums */) { return {zeros_like(primals[0], stream())}; } @@ -1848,7 +1849,7 @@ std::pair, std::vector> Equal::vmap( std::vector Equal::vjp( const std::vector& primals, - const std::vector& cotangents, + const std::vector& /* cotangents */, const std::vector& argnums, const std::vector&) { std::vector vjps; @@ -1860,8 +1861,8 @@ std::vector Equal::vjp( std::vector Equal::jvp( const std::vector& primals, - const std::vector& tangents, - const std::vector& argnums) { + const std::vector& /* tangents */, + const std::vector& /* argnums */) { auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape()); return {zeros(shape, bool_, stream())}; } @@ -1899,7 +1900,7 @@ std::pair, std::vector> Erf::vmap( std::vector ErfInv::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums, + const std::vector& /* argnums */, const std::vector& outputs) { auto dtype = primals[0].dtype(); auto scale = @@ -1931,9 +1932,9 @@ std::pair, std::vector> ErfInv::vmap( } std::vector Exp::vjp( - const std::vector& primals, + const std::vector& /* primals */, const std::vector& cotangents, - const std::vector& argnums, + const std::vector& /* argnums */, const std::vector& outputs) { return {multiply(cotangents[0], outputs[0], stream())}; } @@ -1956,9 +1957,9 @@ std::pair, std::vector> Exp::vmap( } std::vector Expm1::vjp( - const std::vector& primals, + const std::vector& /* primals */, const std::vector& cotangents, - const std::vector& argnums, + const std::vector& /* argnums */, const std::vector& outputs) { return {multiply( cotangents[0], @@ -2281,7 +2282,7 @@ std::vector Floor::vjp( std::vector Floor::jvp( const std::vector& primals, - const std::vector& tangents, + const std::vector& /* tangents */, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); @@ -2346,7 +2347,7 @@ std::pair, std::vector> Gather::vmap( // Reorder all the index arrays so the vmap axis is in the same spot. if (indices_vmapped) { - for (int i = 1; i < axes.size(); ++i) { + for (int i = 1; i < std::ssize(axes); ++i) { if (out_ax != axes[i] && axes[i] >= 0) { indices[i - 1] = moveaxis(indices[i - 1], axes[i], out_ax, stream()); } else if (axes[i] < 0) { @@ -2515,7 +2516,7 @@ std::pair, std::vector> Greater::vmap( std::vector Greater::vjp( const std::vector& primals, - const std::vector& cotangents, + const std::vector& /* cotangents */, const std::vector& argnums, const std::vector&) { std::vector vjps; @@ -2527,8 +2528,8 @@ std::vector Greater::vjp( std::vector Greater::jvp( const std::vector& primals, - const std::vector& tangents, - const std::vector& argnums) { + const std::vector& /* tangents */, + const std::vector& /* argnums */) { auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape()); return {zeros(shape, bool_, stream())}; } @@ -2542,7 +2543,7 @@ std::pair, std::vector> GreaterEqual::vmap( std::vector GreaterEqual::vjp( const std::vector& primals, - const std::vector& cotangents, + const std::vector& /* cotangents */, const std::vector& argnums, const std::vector&) { std::vector vjps; @@ -2554,8 +2555,8 @@ std::vector GreaterEqual::vjp( std::vector GreaterEqual::jvp( const std::vector& primals, - const std::vector& tangents, - const std::vector& argnums) { + const std::vector& /* tangents */, + const std::vector& /* argnums */) { auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape()); return {zeros(shape, bool_, stream())}; } @@ -2599,7 +2600,7 @@ std::pair, std::vector> Less::vmap( std::vector Less::vjp( const std::vector& primals, - const std::vector& cotangents, + const std::vector& /* cotangents */, const std::vector& argnums, const std::vector&) { std::vector vjps; @@ -2611,8 +2612,8 @@ std::vector Less::vjp( std::vector Less::jvp( const std::vector& primals, - const std::vector& tangents, - const std::vector& argnums) { + const std::vector& /* tangents */, + const std::vector& /* argnums */) { auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape()); return {zeros(shape, bool_, stream())}; } @@ -2626,7 +2627,7 @@ std::pair, std::vector> LessEqual::vmap( std::vector LessEqual::vjp( const std::vector& primals, - const std::vector& cotangents, + const std::vector& /* cotangents */, const std::vector& argnums, const std::vector&) { std::vector vjps; @@ -2638,8 +2639,8 @@ std::vector LessEqual::vjp( std::vector LessEqual::jvp( const std::vector& primals, - const std::vector& tangents, - const std::vector& argnums) { + const std::vector& /* tangents */, + const std::vector& /* argnums */) { auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape()); return {zeros(shape, bool_, stream())}; } @@ -2748,7 +2749,7 @@ std::vector LogicalAnd::vjp( std::vector LogicalAnd::jvp( const std::vector& primals, - const std::vector& tangents, + const std::vector& /* tangents */, const std::vector& argnums) { assert(primals.size() == 2); assert(argnums.size() <= 2); @@ -2780,7 +2781,7 @@ std::vector LogicalOr::vjp( std::vector LogicalOr::jvp( const std::vector& primals, - const std::vector& tangents, + const std::vector& /* tangents */, const std::vector& argnums) { assert(primals.size() == 2); assert(argnums.size() <= 2); @@ -2859,7 +2860,7 @@ std::pair, std::vector> LogSumExp::vmap( std::vector LogSumExp::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums, + const std::vector& /* argnums */, const std::vector&) { assert(primals.size() == 1); assert(cotangents.size() == 1); @@ -2872,7 +2873,7 @@ std::vector LogSumExp::vjp( std::vector LogSumExp::jvp( const std::vector& primals, const std::vector& tangents, - const std::vector& argnums) { + const std::vector& /* argnums */) { assert(primals.size() == 1); assert(tangents.size() == 1); return {multiply( @@ -2920,7 +2921,7 @@ std::vector Matmul::jvp( const std::vector& tangents, const std::vector& argnums) { std::vector jvp; - for (int i = 0; i < argnums.size(); ++i) { + for (int i = 0; i < std::ssize(argnums); ++i) { auto arg = argnums[i]; if (arg == 0 && i == 0) { jvp.push_back(matmul(tangents[0], primals[1], stream())); @@ -3096,7 +3097,7 @@ std::vector Select::jvp( }; array jvp = jvp_fun(argnums[0]); - for (int i = 1; i < argnums.size(); i++) { + for (int i = 1; i < std::ssize(argnums); i++) { jvp = add(jvp, jvp_fun(argnums[i])); } return {jvp}; @@ -3173,7 +3174,7 @@ std::pair, std::vector> NotEqual::vmap( std::vector NotEqual::vjp( const std::vector& primals, - const std::vector& cotangents, + const std::vector& /* cotangents */, const std::vector& argnums, const std::vector&) { std::vector vjps; @@ -3185,14 +3186,14 @@ std::vector NotEqual::vjp( std::vector NotEqual::jvp( const std::vector& primals, - const std::vector& tangents, - const std::vector& argnums) { + const std::vector& /* tangents */, + const std::vector& /* argnums */) { auto shape = broadcast_shapes(primals[0].shape(), primals[1].shape()); return {zeros(shape, bool_, stream())}; } std::vector Pad::vjp( - const std::vector& primals, + const std::vector& /* primals */, const std::vector& cotangents, const std::vector& argnums, const std::vector&) { @@ -3213,7 +3214,7 @@ std::vector Pad::vjp( } std::vector Pad::jvp( - const std::vector& primals, + const std::vector& /* primals */, const std::vector& tangents, const std::vector& argnums) { assert(argnums.size() == 1 && argnums[0] == 0); @@ -3229,8 +3230,8 @@ std::vector Pad::jvp( } std::pair, std::vector> Pad::vmap( - const std::vector& inputs, - const std::vector& axes) { + const std::vector& /* inputs */, + const std::vector& /* axes */) { throw std::runtime_error("Pad vmap is NYI."); } @@ -3244,7 +3245,7 @@ bool Pad::is_equivalent(const Primitive& other) const { std::vector Partition::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums, + const std::vector& /* argnums */, const std::vector&) { auto sort_idx = argpartition(primals[0], kth_, axis_, stream()); return {put_along_axis( @@ -3258,7 +3259,7 @@ std::vector Partition::vjp( std::vector Partition::jvp( const std::vector& primals, const std::vector& tangents, - const std::vector& argnums) { + const std::vector& /* argnums */) { assert(primals.size() == 1); assert(tangents.size() == 1); auto sort_idx = argpartition(primals[0], kth_, axis_, stream()); @@ -3344,8 +3345,8 @@ QuantizationMode string_to_quantization_mode(const std::string& mode) { } std::pair, std::vector> QuantizedMatmul::vmap( - const std::vector& inputs, - const std::vector& axes) { + const std::vector& /* inputs */, + const std::vector& /* axes */) { throw std::runtime_error("[QuantizedMatmul::vmap] NYI"); } @@ -3450,8 +3451,8 @@ std::vector QuantizedMatmul::output_shapes( } std::pair, std::vector> GatherQMM::vmap( - const std::vector& inputs, - const std::vector& axes) { + const std::vector& /* inputs */, + const std::vector& /* axes */) { throw std::runtime_error("GatherQMM::vmap NYI"); } @@ -3573,9 +3574,9 @@ std::vector GatherQMM::vjp( } std::vector GatherQMM::jvp( - const std::vector& primals, - const std::vector& tangents, - const std::vector& argnums) { + const std::vector& /* primals */, + const std::vector& /* tangents */, + const std::vector& /* argnums */) { throw std::runtime_error("GatherQMM::jvp NYI"); } @@ -3706,7 +3707,7 @@ bool Reshape::is_equivalent(const Primitive& other) const { } Shape Reshape::output_shape(const array& input, Shape shape) { - size_t size = 1; + int64_t size = 1; int infer_idx = -1; for (int i = 0; i < shape.size(); ++i) { if (shape[i] == -1) { @@ -3730,7 +3731,7 @@ Shape Reshape::output_shape(const array& input, Shape shape) { } // Check that the reshaping is valid - if (input.size() != size) { + if (std::ssize(input) != size) { std::ostringstream msg; msg << "[reshape] Cannot reshape array of size " << input.size() << " into shape " << shape << "."; @@ -3746,7 +3747,7 @@ std::vector Reshape::output_shapes(const std::vector& inputs) { std::vector Reduce::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums, + const std::vector& /* argnums */, const std::vector& outputs) { auto in = primals[0]; @@ -3776,7 +3777,7 @@ std::vector Reduce::vjp( // except the reduced over axes. int j = 0; for (int i = 0; i < in.ndim(); i++) { - if (j < axes_.size() && axes_[j] == i) { + if (j < std::ssize(axes_) && axes_[j] == i) { j++; } else { transpose_to.push_back(i); @@ -3788,7 +3789,7 @@ std::vector Reduce::vjp( } shape_flat.push_back(-1); transpose_back.resize(transpose_to.size()); - for (int i = 0; i < transpose_to.size(); i++) { + for (int i = 0; i < std::ssize(transpose_to); i++) { transpose_back[transpose_to[i]] = i; } } @@ -3886,7 +3887,7 @@ std::vector Round::vjp( std::vector Round::jvp( const std::vector& primals, - const std::vector& tangents, + const std::vector& /* tangents */, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); @@ -4021,7 +4022,7 @@ std::vector Scan::vjp( } std::vector Scan::jvp( - const std::vector& primals, + const std::vector& /* primals */, const std::vector& tangents, const std::vector& argnums) { assert(tangents.size() == 1); @@ -4099,7 +4100,7 @@ std::vector Scatter::vjp( // Should never reach here throw std::invalid_argument(""); } - } else if (num == primals.size() - 1) { + } else if (num == std::ssize(primals) - 1) { switch (reduce_type_) { case Scatter::None: case Scatter::Sum: { @@ -4140,9 +4141,9 @@ std::vector Scatter::vjp( } std::vector Scatter::jvp( - const std::vector& primals, - const std::vector& tangents, - const std::vector& argnums) { + const std::vector& /* primals */, + const std::vector& /* tangents */, + const std::vector& /* argnums */) { throw std::runtime_error("[scatter] JVP not yet implemented"); } @@ -4167,7 +4168,7 @@ std::pair, std::vector> Scatter::vmap( inputs[0] = repeat(expand_dims(inputs[0], 0, stream()), vmap_size, 0, stream()); } - for (int i = 1; i < vmap_axes.size() - 1; ++i) { + for (int i = 1; i < std::ssize(vmap_axes) - 1; ++i) { // vmap axis for indices goes to 0 if (vmap_axes[i] >= 0) { inputs[i] = moveaxis(inputs[i], vmap_axes[i], 0, stream()); @@ -4303,7 +4304,7 @@ std::pair, std::vector> ScatterAxis::vmap( } auto v_in = inputs; - for (int i = 0; i < axes.size(); ++i) { + for (int i = 0; i < std::ssize(axes); ++i) { if (axes[i] >= 0) { // if out_ax >= 0 move axis o/w set out_ax if (out_ax != axes[i]) { @@ -4329,9 +4330,9 @@ bool ScatterAxis::is_equivalent(const Primitive& other) const { } std::vector Sigmoid::vjp( - const std::vector& primals, + const std::vector& /* primals */, const std::vector& cotangents, - const std::vector& argnums, + const std::vector& /* argnums */, const std::vector& outputs) { auto& s = outputs[0]; auto sprime = @@ -4369,7 +4370,7 @@ std::vector Sign::vjp( std::vector Sign::jvp( const std::vector& primals, - const std::vector& tangents, + const std::vector& /* tangents */, const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); @@ -4453,7 +4454,7 @@ std::pair, std::vector> Slice::vmap( std::vector Slice::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums, + const std::vector& /* argnums */, const std::vector&) { // Check inputs assert(primals.size() == 1); @@ -4465,7 +4466,7 @@ std::vector Slice::vjp( std::vector Slice::jvp( const std::vector& primals, const std::vector& tangents, - const std::vector& argnums) { + const std::vector& /* argnums */) { // Check inputs assert(primals.size() == 1); return {slice(tangents[0], start_indices_, end_indices_, strides_, stream())}; @@ -4562,7 +4563,7 @@ std::vector SliceUpdate::vjp( std::vector SliceUpdate::jvp( const std::vector& primals, const std::vector& tangents, - const std::vector& argnums) { + const std::vector& /* argnums */) { // Check inputs assert(primals.size() == 2); return {slice_update( @@ -4743,7 +4744,7 @@ std::pair, std::vector> Softmax::vmap( std::vector Softmax::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums, + const std::vector& /* argnums */, const std::vector& outputs) { assert(primals.size() == 1); assert(cotangents.size() == 1); @@ -4757,7 +4758,7 @@ std::vector Softmax::vjp( std::vector Softmax::jvp( const std::vector& primals, const std::vector& tangents, - const std::vector& argnums) { + const std::vector& /* argnums */) { assert(primals.size() == 1); assert(tangents.size() == 1); auto s = softmax(primals[0], std::vector{-1}, precise_, stream()); @@ -4793,7 +4794,7 @@ std::vector Sort::vjp( std::vector Sort::jvp( const std::vector& primals, const std::vector& tangents, - const std::vector& argnums) { + const std::vector& /* argnums */) { assert(primals.size() == 1); assert(tangents.size() == 1); auto sort_idx = argsort(primals[0], axis_, stream()); @@ -4816,17 +4817,17 @@ std::pair, std::vector> Split::vmap( } std::vector Split::vjp( - const std::vector& primals, + const std::vector& /* primals */, const std::vector& cotangents, - const std::vector& argnums, + const std::vector& /* argnums */, const std::vector&) { return {concatenate(cotangents, axis_, stream())}; } std::vector Split::jvp( - const std::vector& primals, + const std::vector& /* primals */, const std::vector& tangents, - const std::vector& argnums) { + const std::vector& /* argnums */) { return split(tangents[0], indices_, axis_, stream()); } @@ -4846,7 +4847,7 @@ std::vector Square::vjp( std::vector Square::jvp( const std::vector& primals, const std::vector& tangents, - const std::vector& argnums) { + const std::vector& /* argnums */) { assert(primals.size() == 1); assert(tangents.size() == 1); return {multiply( @@ -4866,7 +4867,7 @@ std::pair, std::vector> Square::vmap( std::vector Sqrt::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums, + const std::vector& /* argnums */, const std::vector& outputs) { assert(primals.size() == 1); assert(cotangents.size() == 1); @@ -4919,7 +4920,7 @@ std::pair, std::vector> StopGradient::vmap( } std::vector Subtract::vjp( - const std::vector& primals, + const std::vector& /* primals */, const std::vector& cotangents, const std::vector& argnums, const std::vector&) { @@ -4935,7 +4936,7 @@ std::vector Subtract::vjp( } std::vector Subtract::jvp( - const std::vector& primals, + const std::vector& /* primals */, const std::vector& tangents, const std::vector& argnums) { auto jvp_fun = [&](int i) { @@ -4994,7 +4995,7 @@ bool Squeeze::is_equivalent(const Primitive& other) const { Shape Squeeze::output_shape(const array& input, const std::vector& axes) { Shape shape; for (int i = 0, j = 0; i < input.ndim(); ++i) { - if (j < axes.size() && i == axes[j]) { + if (j < std::ssize(axes) && i == axes[j]) { j++; } else { shape.push_back(input.shape(i)); @@ -5409,7 +5410,7 @@ std::vector Transpose::vjp( assert(primals.size() == 1); assert(argnums.size() == 1); std::vector iaxes(axes_.size()); - for (int i = 0; i < axes_.size(); ++i) { + for (int i = 0; i < std::ssize(axes_); ++i) { iaxes[axes_[i]] = i; } return {transpose(cotangents[0], iaxes, stream())}; @@ -5418,7 +5419,7 @@ std::vector Transpose::vjp( std::vector Transpose::jvp( const std::vector& primals, const std::vector& tangents, - const std::vector& argnums) { + const std::vector& /* argnums */) { assert(primals.size() == 1); assert(tangents.size() == 1); return {transpose(tangents[0], axes_, stream())}; @@ -5449,7 +5450,7 @@ bool Transpose::is_equivalent(const Primitive& other) const { std::vector Transpose::output_shapes(const std::vector& inputs) { auto& in = inputs[0]; Shape shape(in.ndim(), 0); - for (int i = 0; i < axes_.size(); ++i) { + for (int i = 0; i < std::ssize(axes_); ++i) { shape[i] = in.shape()[axes_[i]]; } return {std::move(shape)}; diff --git a/mlx/primitives.h b/mlx/primitives.h index 2a843a0e4a..a37124db92 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -31,9 +31,9 @@ return #PRIMITIVE; \ } -#define DEFINE_DEFAULT_IS_EQUIVALENT() \ - bool is_equivalent(const Primitive& other) const override { \ - return true; \ +#define DEFINE_DEFAULT_IS_EQUIVALENT() \ + bool is_equivalent(const Primitive& /* other */) const override { \ + return true; \ } #define DEFINE_INPUT_OUTPUT_SHAPE() \ @@ -104,7 +104,7 @@ class Primitive { virtual const char* name() const = 0; /** Equivalence check defaults to false unless overridden by the primitive */ - virtual bool is_equivalent(const Primitive& other) const { + virtual bool is_equivalent(const Primitive& /* other */) const { return false; } @@ -1071,7 +1071,7 @@ class FFT : public UnaryPrimitive { public: explicit FFT( Stream stream, - const std::vector& axes, + const std::vector& axes, bool inverse, bool real) : UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {} @@ -1089,7 +1089,7 @@ class FFT : public UnaryPrimitive { } private: - std::vector axes_; + std::vector axes_; bool inverse_; bool real_; }; @@ -1526,7 +1526,8 @@ class NumberOfElements : public UnaryPrimitive { DEFINE_VMAP() DEFINE_NAME(NumberOfElements) bool is_equivalent(const Primitive& other) const override; - std::vector output_shapes(const std::vector& inputs) override { + std::vector output_shapes( + const std::vector& /* inputs */) override { return {{}}; } std::tuple, bool, Dtype> state() const { diff --git a/mlx/random.h b/mlx/random.h index 0dfdab7a14..c707889b95 100644 --- a/mlx/random.h +++ b/mlx/random.h @@ -89,6 +89,7 @@ inline array uniform( const Shape& shape, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { + (void)s; return uniform(shape, float32, key); } diff --git a/mlx/scheduler.h b/mlx/scheduler.h index 877fdd5f6a..65286bd677 100644 --- a/mlx/scheduler.h +++ b/mlx/scheduler.h @@ -103,7 +103,7 @@ class Scheduler { default_streams_.at(s.device.type) = s; } - void notify_new_task(const Stream& stream) { + void notify_new_task(const Stream& /* stream */) { { std::lock_guard lk(mtx); n_active_tasks_++; @@ -111,7 +111,7 @@ class Scheduler { completion_cv.notify_all(); } - void notify_task_completion(const Stream& stream) { + void notify_task_completion(const Stream& /* stream */) { { std::lock_guard lk(mtx); n_active_tasks_--; diff --git a/mlx/types/bf16.h b/mlx/types/bf16.h index 5951941747..3e1f9d9a72 100644 --- a/mlx/types/bf16.h +++ b/mlx/types/bf16.h @@ -24,9 +24,6 @@ struct _MLX_BFloat16 { // Default constructor _MLX_BFloat16() = default; - // Default copy constructor - _MLX_BFloat16(_MLX_BFloat16 const&) = default; - // Appease std::vector for being special _MLX_BFloat16& operator=(std::vector::reference x) { bits_ = x; diff --git a/mlx/utils.h b/mlx/utils.h index dbf79a71f5..a71e92a356 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -138,8 +138,8 @@ namespace env { int get_var(const char* name, int default_value); -inline int bfs_max_width() { - static int bfs_max_width_ = get_var("MLX_BFS_MAX_WIDTH", 20); +inline unsigned int bfs_max_width() { + static unsigned int bfs_max_width_ = get_var("MLX_BFS_MAX_WIDTH", 20); return bfs_max_width_; } From 3d67b717a05764a510481a22d29c5931a5c19d06 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Wed, 29 Oct 2025 16:43:18 -0700 Subject: [PATCH 04/30] the cpu simd case --- mlx/backend/cpu/simd/accelerate_simd.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlx/backend/cpu/simd/accelerate_simd.h b/mlx/backend/cpu/simd/accelerate_simd.h index 9148310553..21c9d590d5 100644 --- a/mlx/backend/cpu/simd/accelerate_simd.h +++ b/mlx/backend/cpu/simd/accelerate_simd.h @@ -253,12 +253,12 @@ Simd pow(Simd base, Simd exp) { } else { Simd res = 1; // Raising an integer to a negative power is undefined - if (any(exp < 0)) { + if (any(exp < static_cast(0))) { return 0; } - while (any(exp > 0)) { + while (any(exp > static_cast(0))) { res = select((exp & 1) != 0, res * base, res); - base = select(exp > 0, base * base, base); + base = select(exp > static_cast(0), base * base, base); exp = exp >> 1; } return res; From 53525cba2394cd4351c20b6e746780d8f9415916 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Wed, 29 Oct 2025 16:51:05 -0700 Subject: [PATCH 05/30] WIP --- mlx/fast_primitives.h | 35 ++++++++++++++++----------- mlx/transforms.cpp | 56 +++++++++++++++++++++---------------------- 2 files changed, 49 insertions(+), 42 deletions(-) diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index a8000485a0..52662158bd 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -46,8 +46,9 @@ class RMSNorm : public Custom { static bool use_fallback(Stream stream); - void eval_cpu(const std::vector& inputs, std::vector& outputs) - override { + void eval_cpu( + const std::vector& /* inputs */, + std::vector& /* outputs */) override { throw std::runtime_error("NYI"); } void eval_gpu(const std::vector& inputs, std::vector& outputs) @@ -79,8 +80,9 @@ class RMSNormVJP : public Custom { float eps) : Custom(stream, fallback), eps_(eps) {} - void eval_cpu(const std::vector& inputs, std::vector& outputs) - override { + void eval_cpu( + const std::vector& /* inputs */, + std::vector& /* outputs */) override { throw std::runtime_error("NYI"); } void eval_gpu(const std::vector& inputs, std::vector& outputs) @@ -106,8 +108,9 @@ class LayerNorm : public Custom { static bool use_fallback(Stream s); - void eval_cpu(const std::vector& inputs, std::vector& outputs) - override { + void eval_cpu( + const std::vector& /* inputs */, + std::vector& /* outputs */) override { throw std::runtime_error("NYI"); } void eval_gpu(const std::vector& inputs, std::vector& outputs) @@ -138,8 +141,9 @@ class LayerNormVJP : public Custom { float eps) : Custom(stream, fallback), eps_(eps) {} - void eval_cpu(const std::vector& inputs, std::vector& outputs) - override { + void eval_cpu( + const std::vector& /* inputs */, + std::vector& /* outputs */) override { throw std::runtime_error("NYI"); } void eval_gpu(const std::vector& inputs, std::vector& outputs) @@ -174,8 +178,9 @@ class RoPE : public Custom { static bool use_fallback(Stream s); - void eval_cpu(const std::vector& inputs, std::vector& outputs) - override { + void eval_cpu( + const std::vector& /* inputs */, + std::vector& /* outputs */) override { throw std::runtime_error("NYI"); } void eval_gpu(const std::vector& inputs, std::vector& outputs) @@ -225,8 +230,9 @@ class ScaledDotProductAttention : public Custom { bool do_causal, Stream s); - void eval_cpu(const std::vector& inputs, std::vector& outputs) - override { + void eval_cpu( + const std::vector& /* inputs */, + std::vector& /* outputs */) override { throw std::runtime_error("NYI"); } @@ -320,8 +326,9 @@ class CustomKernel : public Primitive { is_precompiled_(is_precompiled), shared_memory_(shared_memory) {} - void eval_cpu(const std::vector& inputs, std::vector& outputs) - override { + void eval_cpu( + const std::vector& /* inputs */, + std::vector& /* outputs */) override { throw std::runtime_error("Custom kernels only run on GPU."); } diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 3ec64feea8..ec58fd2bd6 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -89,7 +89,7 @@ array eval_impl(std::vector outputs, bool async) { auto& [a_ref, idx] = dfs.top(); auto& a = a_ref.get(); - if (idx < a.inputs().size()) { + if (idx < std::ssize(a.inputs())) { // Add an input, and continue auto& in = a.inputs()[idx++]; @@ -146,16 +146,16 @@ array eval_impl(std::vector outputs, bool async) { int max_width = env::bfs_max_width(); dfs = std::stack, int>>(); tape.push_back(synchronizer); - for (int i = 0; !cache.empty() && (i < tape.size() || !dfs.empty());) { - auto& a = (i >= tape.size()) ? dfs.top().first.get() : tape[i]; + for (int i = 0; !cache.empty() && (i < std::ssize(tape) || !dfs.empty());) { + auto& a = (i >= std::ssize(tape)) ? dfs.top().first.get() : tape[i]; int j = 0; - if (i >= tape.size()) { + if (i >= std::ssize(tape)) { j = dfs.top().second; dfs.pop(); } else { i++; } - for (; j < a.inputs().size(); ++j) { + for (; j < std::ssize(a.inputs()); ++j) { auto& in = a.inputs()[j]; if (in.status() != array::Status::unscheduled) { continue; @@ -163,7 +163,7 @@ array eval_impl(std::vector outputs, bool async) { // If the width limit is exceeded, push the array on the stack // and go down a level - if ((tape.size() - i) >= max_width) { + if ((std::ssize(tape) - i) >= max_width) { dfs.emplace(a, j); break; } @@ -343,14 +343,14 @@ std::pair, std::vector> vjp( // that have stop_gradient called on them int cotan_index = 0; std::vector> output_cotan_pairs; - for (int i = 0; i < outputs.size(); ++i) { + for (int i = 0; i < std::ssize(outputs); ++i) { auto& out = outputs[i]; if (out.has_primitive()) { if (auto& p = out.primitive(); typeid(p) == typeid(StopGradient)) { continue; } } - if (cotan_index >= cotans.size()) { + if (cotan_index >= std::ssize(cotans)) { std::ostringstream msg; msg << "[vjp] Number of outputs to compute gradients for (" << outputs.size() << ") does not match number of cotangents (" @@ -374,11 +374,11 @@ std::pair, std::vector> vjp( // to the tape which need a gradient. std::unordered_set cache; std::unordered_set calc_grad; - for (int i = 0, j = 0; i < primals_.size(); ++i) { + for (int i = 0, j = 0; i < std::ssize(primals_); ++i) { auto& primal = primals_[i]; primal.set_tracer(false); cache.insert(primal.id()); - if (j < argnums.size() && argnums[j] == i) { + if (j < std::ssize(argnums) && argnums[j] == i) { j++; calc_grad.insert(primal.id()); } @@ -440,7 +440,7 @@ std::pair, std::vector> vjp( // Get the arguments whose gradients are needed std::vector argnums; - for (int i = 0; i < a.inputs().size(); ++i) { + for (int i = 0; i < std::ssize(a.inputs()); ++i) { if (calc_grad.find(a.inputs()[i].id()) != calc_grad.end()) { argnums.push_back(i); } @@ -473,7 +473,7 @@ std::pair, std::vector> vjp( vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs); } // Accumulate the vector-jacobian products for each input - for (int i = 0; i < argnums.size(); ++i) { + for (int i = 0; i < std::ssize(argnums); ++i) { auto in_id = a.inputs()[argnums[i]].id(); if (auto cotan_it = cotan_map.find(in_id); cotan_it != cotan_map.end()) { cotan_it->second = add(cotan_it->second, vjps[i], s); @@ -528,7 +528,7 @@ std::pair, std::vector> jvp( throw std::invalid_argument( "[jvp] Number of inputs does not match number of tangents."); } - for (int i = 0; i < primals.size(); ++i) { + for (int i = 0; i < std::ssize(primals); ++i) { if (primals[i].shape() != tangents[i].shape()) { throw std::invalid_argument( "[jvp] Input shape does not match shape of tangent."); @@ -597,7 +597,7 @@ std::pair, std::vector> jvp( } std::unordered_map tan_map; - for (int i = 0; i < primals_.size(); ++i) { + for (int i = 0; i < std::ssize(primals_); ++i) { tan_map.insert({primals_[i].id(), tangents[i]}); } @@ -605,7 +605,7 @@ std::pair, std::vector> jvp( // Get the arguments used in the jvp std::vector argnums; std::vector tangents; - for (int i = 0; i < a.inputs().size(); ++i) { + for (int i = 0; i < std::ssize(a.inputs()); ++i) { if (auto it = tan_map.find(a.inputs()[i].id()); it != tan_map.end()) { argnums.push_back(i); tangents.push_back(it->second); @@ -614,7 +614,7 @@ std::pair, std::vector> jvp( auto jvps = a.primitive().jvp(a.inputs(), tangents, argnums); auto outputs = a.outputs(); - for (int i = 0; i < jvps.size(); ++i) { + for (int i = 0; i < std::ssize(jvps); ++i) { tan_map.insert({outputs[i].id(), jvps[i]}); } } @@ -658,7 +658,7 @@ ValueAndGradFn value_and_grad( throw std::invalid_argument( "[grad] Repeat argument number not allowed in grad."); } - if (*args.begin() < 0 || *args.rbegin() >= inputs.size()) { + if (*args.begin() < 0 || *args.rbegin() >= std::ssize(inputs)) { std::ostringstream msg; msg << "[grad] Invalid argument number for function with " << inputs.size() << " inputs."; @@ -668,7 +668,7 @@ ValueAndGradFn value_and_grad( auto gfun = [&fun](const std::vector& inputs) { auto outputs = fun(inputs); - for (int i = 1; i < outputs.size(); i++) { + for (int i = 1; i < std::ssize(outputs); i++) { auto& out = outputs[i]; auto s = out.has_primitive() ? out.primitive().stream() : default_stream(default_device()); @@ -701,7 +701,7 @@ std::pair, std::vector> vmap_trace( // Some error checking and get the vmap axis size size_t vmap_ax_size; - for (int i = 0; i < inputs.size(); ++i) { + for (int i = 0; i < std::ssize(inputs); ++i) { if (in_axes[i] != -1) { if (inputs[i].ndim() == 0) { throw std::invalid_argument( @@ -717,7 +717,7 @@ std::pair, std::vector> vmap_trace( } } // Check that all vmapped axes have the same size - for (int i = 0; i < inputs.size(); ++i) { + for (int i = 0; i < std::ssize(inputs); ++i) { if (in_axes[i] != -1) { if (size_t in_ax = inputs[i].shape(in_axes[i]); vmap_ax_size != in_ax) { std::ostringstream msg; @@ -731,7 +731,7 @@ std::pair, std::vector> vmap_trace( // Run the function on placeholder inputs // to get the original graph std::vector s_inputs; - for (int i = 0; i < inputs.size(); ++i) { + for (int i = 0; i < std::ssize(inputs); ++i) { if (in_axes[i] != -1) { auto shape = inputs[i].shape(); shape.erase(shape.begin() + in_axes[i]); @@ -759,7 +759,7 @@ std::vector vmap_replace( } int vmap_size = -1; - for (int i = 0; i < inputs.size(); ++i) { + for (int i = 0; i < std::ssize(inputs); ++i) { if (in_axes[i] >= 0) { vmap_size = inputs[i].shape(in_axes[i]); break; @@ -772,7 +772,7 @@ std::vector vmap_replace( std::unordered_map> tmap; std::unordered_set needs_vmap; std::unordered_set cache; - for (int i = 0; i < s_inputs.size(); ++i) { + for (int i = 0; i < std::ssize(s_inputs); ++i) { auto in = s_inputs[i]; if (in_axes[i] != -1) { tmap.insert({in.id(), {inputs[i], in_axes[i]}}); @@ -843,7 +843,7 @@ std::vector vmap_replace( // For each primitive's outputs add its id, the vout id and the vax auto outputs = a.outputs(); - for (int i = 0; i < v_outputs.size(); ++i) { + for (int i = 0; i < std::ssize(v_outputs); ++i) { tmap.insert({outputs[i].id(), {v_outputs[i], v_out_axes[i]}}); } } @@ -851,7 +851,7 @@ std::vector vmap_replace( // Populate the outputs and make sure all the output axes are // in the right place std::vector outputs; - for (int i = 0; i < s_outputs.size(); ++i) { + for (int i = 0; i < std::ssize(s_outputs); ++i) { if (auto map_it = tmap.find(s_outputs[i].id()); map_it != tmap.end()) { auto& [out, vdim] = map_it->second; if (vdim != out_axes[i]) { @@ -995,7 +995,7 @@ std::function(const std::vector&)> custom_function( // using `fun` directly because we may not be able to fully reuse // the outputs of the forward pass. fun_vjp.value_or( - [fun](auto primals, auto cotangents, auto outputs) { + [fun](auto primals, auto cotangents, auto /* outputs */) { auto [__, vjps] = vjp(fun, primals, cotangents); return vjps; }), @@ -1009,8 +1009,8 @@ std::function(const std::vector&)> custom_function( // waste computation. fun_jvp.value_or([fun](auto primals, auto tangents, auto argnums) { std::vector all_tangents; - for (int i = 0, j = 0; i < primals.size(); i++) { - if (j < argnums.size() && i == argnums[j]) { + for (int i = 0, j = 0; i < std::ssize(primals); i++) { + if (j < std::ssize(argnums) && i == argnums[j]) { all_tangents.emplace_back(tangents[j++]); } else { all_tangents.emplace_back(zeros_like(primals[i])); From cacc3ab7fd8c9e1cbda978b1019c523c06e6104a Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Wed, 29 Oct 2025 16:51:42 -0700 Subject: [PATCH 06/30] WIP (common) --- mlx/backend/common/common.cpp | 23 ++++++++++++----------- mlx/backend/common/compiled.cpp | 16 ++++++++-------- mlx/backend/common/load.cpp | 2 +- mlx/backend/common/utils.h | 2 +- 4 files changed, 22 insertions(+), 21 deletions(-) diff --git a/mlx/backend/common/common.cpp b/mlx/backend/common/common.cpp index 2cda88a311..ea6f24a031 100644 --- a/mlx/backend/common/common.cpp +++ b/mlx/backend/common/common.cpp @@ -21,8 +21,8 @@ void AsStrided::eval(const std::vector& inputs, array& out) { // Compute the flags given the shape and strides bool row_contiguous = true, col_contiguous = true; - size_t r = 1, c = 1; - for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) { + int64_t r = 1, c = 1; + for (int i = std::ssize(strides_) - 1, j = 0; i >= 0; i--, j++) { row_contiguous &= (r == strides_[i]) || (shape_[i] == 1); col_contiguous &= (c == strides_[j]) || (shape_[j] == 1); r *= shape_[i]; @@ -60,7 +60,8 @@ void CustomTransforms::eval( const std::vector& inputs, std::vector& outputs) { assert(inputs.size() > outputs.size()); - for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size(); + for (int i = 0, j = std::ssize(inputs) - std::ssize(outputs); + i < std::ssize(outputs); i++, j++) { outputs[i].copy_shared_buffer(inputs[j]); } @@ -70,7 +71,7 @@ void Depends::eval( const std::vector& inputs, std::vector& outputs) { assert(inputs.size() > outputs.size()); - for (int i = 0; i < outputs.size(); i++) { + for (int i = 0; i < std::ssize(outputs); i++) { outputs[i].copy_shared_buffer(inputs[i]); } } @@ -206,11 +207,11 @@ void Split::eval( auto compute_new_flags = [](const auto& shape, const auto& strides, - size_t in_data_size, + int64_t in_data_size, auto flags) { - size_t data_size = 1; - size_t f_stride = 1; - size_t b_stride = 1; + int64_t data_size = 1; + int64_t f_stride = 1; + int64_t b_stride = 1; flags.row_contiguous = true; flags.col_contiguous = true; for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) { @@ -240,7 +241,7 @@ void Split::eval( std::vector indices(1, 0); indices.insert(indices.end(), indices_.begin(), indices_.end()); - for (int i = 0; i < indices.size(); i++) { + for (int i = 0; i < std::ssize(indices); i++) { size_t offset = indices[i] * in.strides()[axis_]; auto [new_flags, data_size] = compute_new_flags( outputs[i].shape(), in.strides(), in.data_size(), in.flags()); @@ -254,7 +255,7 @@ void Squeeze::eval(const std::vector& inputs, array& out) { const auto& in = inputs[0]; Strides strides; for (int i = 0, j = 0; i < in.ndim(); ++i) { - if (j < axes_.size() && i == axes_[j]) { + if (j < std::ssize(axes_) && i == axes_[j]) { j++; } else { strides.push_back(in.strides(i)); @@ -272,7 +273,7 @@ void Transpose::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); Strides out_strides(out.ndim()); auto& in = inputs[0]; - for (int ax = 0; ax < axes_.size(); ++ax) { + for (int ax = 0; ax < std::ssize(axes_); ++ax) { out_strides[ax] = in.strides()[axes_[ax]]; } diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index 44e2a432bb..9864b4fe68 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -120,7 +120,7 @@ void compiled_allocate_outputs( Strides strides; size_t data_size; array::Flags flags; - for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) { + for (int i = 0; i < std::ssize(inputs) && o < std::ssize(outputs); ++i) { auto& in = inputs[i]; // Conditions for donation // - Correct size @@ -138,7 +138,7 @@ void compiled_allocate_outputs( data_size = in.data_size(); } } - for (; o < outputs.size(); ++o) { + for (; o < std::ssize(outputs); ++o) { outputs[o].set_data( allocator::malloc(data_size * outputs[o].itemsize()), data_size, @@ -147,7 +147,7 @@ void compiled_allocate_outputs( } } else { int o = 0; - for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) { + for (int i = 0; i < std::ssize(inputs) && o < std::ssize(outputs); ++i) { auto& in = inputs[i]; // Conditions for donation // - Row contiguous @@ -162,7 +162,7 @@ void compiled_allocate_outputs( o++; } } - for (; o < outputs.size(); ++o) { + for (; o < std::ssize(outputs); ++o) { outputs[o].set_data(allocator::malloc(outputs[o].nbytes())); } } @@ -193,7 +193,7 @@ std::tuple> compiled_collapse_contiguous_dims( // Broadcast the inputs to the output shape. Strides xstrides; - size_t j = 0; + int j = 0; for (; j < shape.size() - x.ndim(); ++j) { if (shape[j] == 1) { xstrides.push_back(out.strides()[j]); @@ -201,7 +201,7 @@ std::tuple> compiled_collapse_contiguous_dims( xstrides.push_back(0); } } - for (size_t i = 0; i < x.ndim(); ++i, ++j) { + for (int i = 0; i < x.ndim(); ++i, ++j) { if (x.shape(i) == 1) { if (shape[j] == 1) { xstrides.push_back(out.strides()[j]); @@ -224,13 +224,13 @@ bool compiled_use_large_index( const std::vector& outputs, bool contiguous) { if (contiguous) { - size_t max_size = 0; + int64_t max_size = 0; for (const auto& in : inputs) { max_size = std::max(max_size, in.data_size()); } return max_size > UINT32_MAX; } else { - size_t max_size = 0; + int64_t max_size = 0; for (const auto& o : outputs) { max_size = std::max(max_size, o.size()); } diff --git a/mlx/backend/common/load.cpp b/mlx/backend/common/load.cpp index ce41963de7..cb2fbacb8b 100644 --- a/mlx/backend/common/load.cpp +++ b/mlx/backend/common/load.cpp @@ -27,7 +27,7 @@ void swap_endianness(uint8_t* data_bytes, size_t N) { namespace mlx::core { -void Load::eval_cpu(const std::vector& inputs, array& out) { +void Load::eval_cpu(const std::vector& /* inputs */, array& out) { out.set_data(allocator::malloc(out.nbytes())); auto read_task = [out_ptr = out.data(), size = out.size(), diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index 1b6902ff33..0a4760fdce 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -183,7 +183,7 @@ inline auto check_contiguity(const Shape& shape, const Strides& strides) { } inline bool is_donatable(const array& in, const array& out) { - constexpr size_t donation_extra = 16384; + constexpr int64_t donation_extra = 16384; return in.is_donatable() && in.itemsize() == out.itemsize() && in.buffer_size() <= out.nbytes() + donation_extra; From 310e501e6a3860a92980372ba66719546c8a9a10 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Wed, 29 Oct 2025 16:52:25 -0700 Subject: [PATCH 07/30] WIP (cpu) --- mlx/backend/cpu/arg_reduce.cpp | 8 ++++---- mlx/backend/cpu/binary.cpp | 8 ++++---- mlx/backend/cpu/binary_two.h | 8 ++++---- mlx/backend/cpu/conv.cpp | 14 +++++++------- mlx/backend/cpu/eig.cpp | 2 +- mlx/backend/cpu/eigh.cpp | 2 +- mlx/backend/cpu/encoder.h | 4 ++-- 7 files changed, 23 insertions(+), 23 deletions(-) diff --git a/mlx/backend/cpu/arg_reduce.cpp b/mlx/backend/cpu/arg_reduce.cpp index 66468912d1..41ab9fb609 100644 --- a/mlx/backend/cpu/arg_reduce.cpp +++ b/mlx/backend/cpu/arg_reduce.cpp @@ -17,14 +17,14 @@ void arg_reduce(const array& in, array& out, const OpT& op, int axis) { Strides strides = remove_index(in.strides(), axis); Shape shape = remove_index(in.shape(), axis); auto in_ptr = in.data(); - auto out_ptr = out.data(); + auto out_ptr = out.data(); - for (uint32_t i = 0; i < out.size(); ++i) { + for (int64_t i = 0; i < out.size(); ++i) { auto loc = elem_to_loc(i, shape, strides); auto local_in_ptr = in_ptr + loc; - uint32_t ind_v = 0; + int64_t ind_v = 0; InT v = (*local_in_ptr); - for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) { + for (int64_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) { op(j, (*local_in_ptr), &ind_v, &v); } out_ptr[i] = ind_v; diff --git a/mlx/backend/cpu/binary.cpp b/mlx/backend/cpu/binary.cpp index 35aa2a3e0d..94dac14352 100644 --- a/mlx/backend/cpu/binary.cpp +++ b/mlx/backend/cpu/binary.cpp @@ -17,7 +17,7 @@ namespace mlx::core { namespace { template -void binary(const array& a, const array& b, array& out, Op op, Stream stream) { +void binary(const array& a, const array& b, array& out, Op /* op */, Stream stream) { auto bopt = get_binary_op_type(a, b); set_binary_op_output_data(a, b, out, bopt); @@ -81,7 +81,7 @@ void comparison_op( const array& a, const array& b, array& out, - Op op, + Op /* op */, Stream stream) { auto bopt = get_binary_op_type(a, b); set_binary_op_output_data(a, b, out, bopt); @@ -146,7 +146,7 @@ void binary_float( const array& a, const array& b, array& out, - Op op, + Op /* op */, Stream stream) { auto bopt = get_binary_op_type(a, b); set_binary_op_output_data(a, b, out, bopt); @@ -187,7 +187,7 @@ void binary_int( const array& a, const array& b, array& out, - Op op, + Op /* op */, Stream stream) { auto bopt = get_binary_op_type(a, b); set_binary_op_output_data(a, b, out, bopt); diff --git a/mlx/backend/cpu/binary_two.h b/mlx/backend/cpu/binary_two.h index fa0ca7996e..7038310e3f 100644 --- a/mlx/backend/cpu/binary_two.h +++ b/mlx/backend/cpu/binary_two.h @@ -99,7 +99,7 @@ void binary_op_dispatch_dims( ContiguousIterator a_it(shape, a_strides, ndim - 2); ContiguousIterator b_it(shape, b_strides, ndim - 2); auto stride = out_strides[ndim - 3]; - for (size_t elem = 0; elem < a.size(); elem += stride) { + for (int64_t elem = 0; elem < std::ssize(a); elem += stride) { binary_op_dims( a_ptr + a_it.loc, b_ptr + b_it.loc, @@ -137,21 +137,21 @@ void binary_op( if (bopt == BinaryOpType::ScalarScalar) { std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); } else if (bopt == BinaryOpType::ScalarVector) { - for (size_t i = 0; i < b.data_size(); ++i) { + for (int64_t i = 0; i < b.data_size(); ++i) { std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); out_a_ptr++; out_b_ptr++; b_ptr++; } } else if (bopt == BinaryOpType::VectorScalar) { - for (size_t i = 0; i < a.data_size(); ++i) { + for (int64_t i = 0; i < a.data_size(); ++i) { std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); out_a_ptr++; out_b_ptr++; a_ptr++; } } else { // VectorVector - for (size_t i = 0; i < a.size(); ++i) { + for (int64_t i = 0; i < a.size(); ++i) { std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); out_a_ptr++; out_b_ptr++; diff --git a/mlx/backend/cpu/conv.cpp b/mlx/backend/cpu/conv.cpp index 5d4638adeb..e278a2b7fa 100644 --- a/mlx/backend/cpu/conv.cpp +++ b/mlx/backend/cpu/conv.cpp @@ -860,7 +860,7 @@ void explicit_gemm_conv_1D_cpu( const std::vector& padding_lo, const std::vector& padding_hi, const std::vector& wt_strides, - const std::vector& wt_dilation, + const std::vector& /* wt_dilation */, Stream stream) { const int N = in.shape(0); // Batch size, should be the same as out.shape(0) const int iH = in.shape(1); // Input spatial dim @@ -1003,7 +1003,7 @@ void explicit_gemm_conv_ND_cpu( const std::vector& padding_lo, const std::vector& padding_hi, const std::vector& wt_strides, - const std::vector& wt_dilation, + const std::vector& /* wt_dilation */, const bool flip, Stream stream) { const int N = in.shape(0); // Batch size, should be the same as out.shape(0) @@ -1023,7 +1023,7 @@ void explicit_gemm_conv_ND_cpu( // Pad input Shape padded_shape(in.shape().size()); padded_shape.front() = N; - for (size_t i = 0; i < iDim.size(); i++) { + for (int i = 0; i < iDim.size(); i++) { padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i]; } padded_shape.back() = C; @@ -1054,20 +1054,20 @@ void explicit_gemm_conv_ND_cpu( // Make strided view Shape strided_shape(oDim.size() + wDim.size() + 2); strided_shape.front() = N; - for (size_t i = 0; i < oDim.size(); i++) { + for (int i = 0; i < oDim.size(); i++) { strided_shape[i + 1] = oDim[i]; } - for (size_t i = 0; i < wDim.size(); i++) { + for (int i = 0; i < wDim.size(); i++) { strided_shape[i + 1 + oDim.size()] = wDim[i]; } strided_shape.back() = C; Strides strided_strides(in.shape().size() * 2 - 2); strided_strides[0] = in_padded.strides()[0]; - for (size_t i = 0; i < wt_strides.size(); i++) { + for (int i = 0; i < std::ssize(wt_strides); i++) { strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i]; } - for (size_t i = 1; i < in_padded.strides().size(); i++) { + for (int i = 1; i < std::ssize(in_padded.strides()); i++) { strided_strides[i + wt_strides.size()] = in_padded.strides()[i]; } diff --git a/mlx/backend/cpu/eig.cpp b/mlx/backend/cpu/eig.cpp index 0d1f95a57d..ea94a324d3 100644 --- a/mlx/backend/cpu/eig.cpp +++ b/mlx/backend/cpu/eig.cpp @@ -70,7 +70,7 @@ void eig_impl( auto eig_tmp = static_cast(eig_tmp_data.buffer.raw_ptr()); auto vec_tmp = static_cast(vec_tmp_data.buffer.raw_ptr()); auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)}; - for (size_t i = 0; i < size / (N * N); ++i) { + for (int64_t i = 0; i < size / (N * N); ++i) { geev( &jobl, &jobr, diff --git a/mlx/backend/cpu/eigh.cpp b/mlx/backend/cpu/eigh.cpp index d457c1fd99..8b586656d0 100644 --- a/mlx/backend/cpu/eigh.cpp +++ b/mlx/backend/cpu/eigh.cpp @@ -165,7 +165,7 @@ void eigh_impl( EighWork work(jobz, uplo, N); // Work loop - for (size_t i = 0; i < size / (N * N); ++i) { + for (int64_t i = 0; i < size / (N * N); ++i) { work.run(vec_ptr, eig_ptr); vec_ptr += N * N; eig_ptr += N; diff --git a/mlx/backend/cpu/encoder.h b/mlx/backend/cpu/encoder.h index b8e33ca810..df65e53670 100644 --- a/mlx/backend/cpu/encoder.h +++ b/mlx/backend/cpu/encoder.h @@ -20,8 +20,8 @@ struct CommandEncoder { CommandEncoder(CommandEncoder&&) = delete; CommandEncoder& operator=(CommandEncoder&&) = delete; - void set_input_array(const array& a) {} - void set_output_array(array& a) {} + void set_input_array(const array& /* a */) {} + void set_output_array(array& /* a */) {} // Hold onto a temporary until any already scheduled tasks which use it as // an input are complete. From 63d91557e0a6266cc687e1e3a10842074156b9f8 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Wed, 29 Oct 2025 17:05:48 -0700 Subject: [PATCH 08/30] fix FFT (PocketFFT requires size_t for axis) --- mlx/fft.cpp | 2 +- mlx/primitives.cpp | 2 +- mlx/primitives.h | 5 +++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mlx/fft.cpp b/mlx/fft.cpp index 33f5e763a6..69fcf64791 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -27,7 +27,7 @@ array fft_impl( return a; } - std::vector valid_axes; + std::vector valid_axes; for (int ax : axes) { valid_axes.push_back(ax < 0 ? ax + a.ndim() : ax); } diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 152487de63..0faa9f4074 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -2156,7 +2156,7 @@ std::pair, std::vector> FFT::vmap( auto out_shape = in.shape(); if (ax >= 0) { for (auto& fft_ax : fft_axes) { - if (fft_ax >= ax) { + if (static_cast(fft_ax) >= ax) { fft_ax++; } if (real_) { diff --git a/mlx/primitives.h b/mlx/primitives.h index a37124db92..9ff82ca8ac 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1071,7 +1071,8 @@ class FFT : public UnaryPrimitive { public: explicit FFT( Stream stream, - const std::vector& axes, + // Note: PocketFFT requires size_t + const std::vector& axes, bool inverse, bool real) : UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {} @@ -1089,7 +1090,7 @@ class FFT : public UnaryPrimitive { } private: - std::vector axes_; + std::vector axes_; bool inverse_; bool real_; }; From 76ef1e98f368f423e0652864a99b343fd8a729be Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Thu, 30 Oct 2025 16:18:59 -0700 Subject: [PATCH 09/30] WIP (common) --- mlx/backend/common/reduce.cpp | 4 ++-- mlx/backend/common/utils.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mlx/backend/common/reduce.cpp b/mlx/backend/common/reduce.cpp index ceef464007..aa4c735609 100644 --- a/mlx/backend/common/reduce.cpp +++ b/mlx/backend/common/reduce.cpp @@ -28,7 +28,7 @@ std::pair shapes_without_reduction_axes( ReductionPlan get_reduction_plan(const array& x, const std::vector& axes) { // The data is all there and we are reducing over everything - if (x.size() == x.data_size() && axes.size() == x.ndim() && + if (x.size() == x.data_size() && std::ssize(axes) == x.ndim() && x.flags().contiguous) { return ContiguousAllReduce; } @@ -38,7 +38,7 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector& axes) { // Merge consecutive axes Shape shape = {x.shape(axes[0])}; Strides strides = {x.strides()[axes[0]]}; - for (int i = 1; i < axes.size(); i++) { + for (int i = 1; i < std::ssize(axes); i++) { if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) { shape.back() *= x.shape(axes[i]); strides.back() = x.strides()[axes[i]]; diff --git a/mlx/backend/common/utils.cpp b/mlx/backend/common/utils.cpp index ae169e35e2..e2c11d6691 100644 --- a/mlx/backend/common/utils.cpp +++ b/mlx/backend/common/utils.cpp @@ -28,7 +28,7 @@ std::tuple> collapse_contiguous_dims( if (shape[0] != 1) { to_collapse.push_back(0); } - size_t size = shape[0]; + int64_t size = shape[0]; for (int i = 1; i < shape.size(); i++) { bool contiguous = true; size *= shape[i]; @@ -64,7 +64,7 @@ std::tuple> collapse_contiguous_dims( current_shape *= shape[to_collapse[k]]; } out_shape.push_back(current_shape); - for (int j = 0; j < strides.size(); j++) { + for (int j = 0; j < std::ssize(strides); j++) { const auto& st = strides[j]; out_strides[j].push_back(st[to_collapse[k - 1]]); } From 45a8b226af38f9d334a8b887c929168856212d2c Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Thu, 30 Oct 2025 16:24:51 -0700 Subject: [PATCH 10/30] WIP (cpu) --- mlx/backend/cpu/arange.h | 2 +- mlx/backend/cpu/binary.cpp | 7 +++++- mlx/backend/cpu/cholesky.cpp | 4 ++-- mlx/backend/cpu/gemm.h | 8 +++---- mlx/backend/cpu/gemms/bnns.cpp | 22 ++++++++--------- mlx/backend/cpu/gemms/cblas.cpp | 42 ++++++++++++++++----------------- mlx/backend/cpu/hadamard.cpp | 10 ++++---- mlx/backend/cpu/indexing.cpp | 30 +++++++++++------------ mlx/backend/cpu/masked_mm.cpp | 28 +++++++++++----------- mlx/backend/cpu/primitives.cpp | 22 ++++++++--------- mlx/backend/cpu/qrf.cpp | 10 ++++---- mlx/backend/cpu/simd/math.h | 7 +++--- mlx/backend/cpu/sort.cpp | 20 ++++++++-------- mlx/backend/cpu/svd.cpp | 12 +++++----- mlx/backend/cpu/ternary.h | 2 +- mlx/backend/cpu/unary.h | 10 ++++---- 16 files changed, 121 insertions(+), 115 deletions(-) diff --git a/mlx/backend/cpu/arange.h b/mlx/backend/cpu/arange.h index 9e9b03bd77..96b1e6e04a 100644 --- a/mlx/backend/cpu/arange.h +++ b/mlx/backend/cpu/arange.h @@ -10,7 +10,7 @@ namespace mlx::core { namespace { template -void arange(T start, T next, array& out, size_t size, Stream stream) { +void arange(T start, T next, array& out, int64_t size, Stream stream) { auto ptr = out.data(); auto step_size = next - start; auto& encoder = cpu::get_command_encoder(stream); diff --git a/mlx/backend/cpu/binary.cpp b/mlx/backend/cpu/binary.cpp index 94dac14352..d98b7332b7 100644 --- a/mlx/backend/cpu/binary.cpp +++ b/mlx/backend/cpu/binary.cpp @@ -17,7 +17,12 @@ namespace mlx::core { namespace { template -void binary(const array& a, const array& b, array& out, Op /* op */, Stream stream) { +void binary( + const array& a, + const array& b, + array& out, + Op /* op */, + Stream stream) { auto bopt = get_binary_op_type(a, b); set_binary_op_output_data(a, b, out, bopt); diff --git a/mlx/backend/cpu/cholesky.cpp b/mlx/backend/cpu/cholesky.cpp index 3c5bbbc93d..2446423406 100644 --- a/mlx/backend/cpu/cholesky.cpp +++ b/mlx/backend/cpu/cholesky.cpp @@ -33,8 +33,8 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) { N = a.shape(-1), size = a.size()]() mutable { char uplo = (upper) ? 'L' : 'U'; - size_t num_matrices = size / (N * N); - for (int i = 0; i < num_matrices; i++) { + int64_t num_matrices = size / (N * N); + for (int64_t i = 0; i < num_matrices; i++) { // Compute Cholesky factorization. int info; potrf( diff --git a/mlx/backend/cpu/gemm.h b/mlx/backend/cpu/gemm.h index d665cb91ff..93aa3ee313 100644 --- a/mlx/backend/cpu/gemm.h +++ b/mlx/backend/cpu/gemm.h @@ -12,12 +12,12 @@ void matmul( T* out, bool a_transposed, bool b_transposed, - size_t lda, - size_t ldb, - size_t ldc, + int64_t lda, + int64_t ldb, + int64_t ldc, float alpha, float beta, - size_t batch_size, + int64_t batch_size, const Shape& a_shape, const Strides& a_strides, const Shape& b_shape, diff --git a/mlx/backend/cpu/gemms/bnns.cpp b/mlx/backend/cpu/gemms/bnns.cpp index 2ec0fd4e23..545d79d4bf 100644 --- a/mlx/backend/cpu/gemms/bnns.cpp +++ b/mlx/backend/cpu/gemms/bnns.cpp @@ -34,7 +34,7 @@ void matmul_bnns( bool b_transposed, size_t lda, size_t ldb, - size_t ldc, + size_t /* ldc */, float alpha, float beta, size_t batch_size, @@ -52,7 +52,7 @@ void matmul_bnns( #pragma GCC diagnostic ignored "-Wdeprecated-declarations" if (beta != 1.0 && beta != 0.0) { // scale the output - for (auto i = 0; i < batch_size * M * N; ++i) { + for (size_t i = 0; i < batch_size * M * N; ++i) { out[i] *= beta; } beta = 1.0; @@ -127,7 +127,7 @@ void matmul_bnns( auto bnns_filter = BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr); - for (int i = 0; i < batch_size; ++i) { + for (size_t i = 0; i < batch_size; ++i) { BNNSFilterApplyTwoInput( bnns_filter, reinterpret_cast( @@ -148,12 +148,12 @@ void matmul( float16_t* out, bool a_transposed, bool b_transposed, - size_t lda, - size_t ldb, - size_t ldc, + int64_t lda, + int64_t ldb, + int64_t ldc, float alpha, float beta, - size_t batch_size, + int64_t batch_size, const Shape& a_shape, const Strides& a_strides, const Shape& b_shape, @@ -183,12 +183,12 @@ void matmul( bfloat16_t* out, bool a_transposed, bool b_transposed, - size_t lda, - size_t ldb, - size_t ldc, + int64_t lda, + int64_t ldb, + int64_t ldc, float alpha, float beta, - size_t batch_size, + int64_t batch_size, const Shape& a_shape, const Strides& a_strides, const Shape& b_shape, diff --git a/mlx/backend/cpu/gemms/cblas.cpp b/mlx/backend/cpu/gemms/cblas.cpp index 765e9f539a..3277b7a784 100644 --- a/mlx/backend/cpu/gemms/cblas.cpp +++ b/mlx/backend/cpu/gemms/cblas.cpp @@ -13,20 +13,20 @@ void matmul( float* out, bool a_transposed, bool b_transposed, - size_t lda, - size_t ldb, - size_t ldc, + int64_t lda, + int64_t ldb, + int64_t ldc, float alpha, float beta, - size_t batch_size, + int64_t batch_size, const Shape& a_shape, const Strides& a_strides, const Shape& b_shape, const Strides& b_strides) { auto ndim = a_shape.size(); - size_t M = a_shape[ndim - 2]; - size_t N = b_shape[ndim - 1]; - size_t K = a_shape[ndim - 1]; + int64_t M = a_shape[ndim - 2]; + int64_t N = b_shape[ndim - 1]; + int64_t K = a_shape[ndim - 1]; for (int i = 0; i < batch_size; ++i) { cblas_sgemm( @@ -54,20 +54,20 @@ void matmul( double* out, bool a_transposed, bool b_transposed, - size_t lda, - size_t ldb, - size_t ldc, + int64_t lda, + int64_t ldb, + int64_t ldc, float alpha, float beta, - size_t batch_size, + int64_t batch_size, const Shape& a_shape, const Strides& a_strides, const Shape& b_shape, const Strides& b_strides) { auto ndim = a_shape.size(); - size_t M = a_shape[ndim - 2]; - size_t N = b_shape[ndim - 1]; - size_t K = a_shape[ndim - 1]; + int64_t M = a_shape[ndim - 2]; + int64_t N = b_shape[ndim - 1]; + int64_t K = a_shape[ndim - 1]; for (int i = 0; i < batch_size; ++i) { cblas_dgemm( @@ -95,20 +95,20 @@ void matmul( complex64_t* out, bool a_transposed, bool b_transposed, - size_t lda, - size_t ldb, - size_t ldc, + int64_t lda, + int64_t ldb, + int64_t ldc, float alpha, float beta, - size_t batch_size, + int64_t batch_size, const Shape& a_shape, const Strides& a_strides, const Shape& b_shape, const Strides& b_strides) { auto ndim = a_shape.size(); - size_t M = a_shape[ndim - 2]; - size_t N = b_shape[ndim - 1]; - size_t K = a_shape[ndim - 1]; + int64_t M = a_shape[ndim - 2]; + int64_t N = b_shape[ndim - 1]; + int64_t K = a_shape[ndim - 1]; auto calpha = static_cast(alpha); auto cbeta = static_cast(beta); diff --git a/mlx/backend/cpu/hadamard.cpp b/mlx/backend/cpu/hadamard.cpp index bf7e1dc261..aa1b164bb0 100644 --- a/mlx/backend/cpu/hadamard.cpp +++ b/mlx/backend/cpu/hadamard.cpp @@ -11,9 +11,9 @@ namespace mlx::core { // n = 2^k component template -void hadamard_n(T* out, int n, int m, float scale, size_t size) { +void hadamard_n(T* out, int n, int /* m */, float scale, int64_t size) { for (int b = 0; b < size / n; b++) { - size_t loc = b * n; + int64_t loc = b * n; T* data_ptr = out + loc; int h = 1; int n_over_2 = n / 2; @@ -37,7 +37,7 @@ void hadamard_n(T* out, int n, int m, float scale, size_t size) { // m component template -void hadamard_m(T* out, int n, int m, float scale, size_t size) { +void hadamard_m(T* out, int n, int m, float scale, int64_t size) { auto h_matrices = hadamard_matrices(); auto& matrix = h_matrices[m]; auto start = 1; @@ -45,7 +45,7 @@ void hadamard_m(T* out, int n, int m, float scale, size_t size) { std::vector hmat_vec; while (end != std::string_view::npos) { auto row = matrix.substr(start, end - start); - for (int i = 0; i < row.length(); i++) { + for (int i = 0; i < std::ssize(row); i++) { hmat_vec.push_back(row[i] == '+'); } start = end + 1; @@ -53,7 +53,7 @@ void hadamard_m(T* out, int n, int m, float scale, size_t size) { } for (int b = 0; b < size / m / n; b++) { - size_t loc = b * n * m; + int64_t loc = b * n * m; T* data_ptr = out + loc; for (int i = 0; i < n; i++) { std::vector out(m); diff --git a/mlx/backend/cpu/indexing.cpp b/mlx/backend/cpu/indexing.cpp index 6daced6fae..b743550fd8 100644 --- a/mlx/backend/cpu/indexing.cpp +++ b/mlx/backend/cpu/indexing.cpp @@ -78,7 +78,7 @@ void gather( can_copy = true; // Ignore leading 1s - int i = 0; + int64_t i = 0; for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i) ; @@ -91,7 +91,7 @@ void gather( can_copy = true; // Ignore trailing 1s - int i = slice_sizes.size() - 1; + int64_t i = slice_sizes.size() - 1; for (; i >= 0 && slice_sizes[i] == 1; --i) ; @@ -101,11 +101,11 @@ void gather( can_copy = (src.shape(i) == slice_sizes[i]); } } - size_t slice_size = 1; + int64_t slice_size = 1; for (auto s : slice_sizes) { slice_size *= s; } - size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size; + int64_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size; const T* src_ptr = src.data(); T* dst_ptr = out.data(); @@ -115,10 +115,10 @@ void gather( src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim()); } - size_t out_idx = 0; - for (int idx = 0; idx < ind_size; idx++) { - size_t src_idx = 0; - for (int ii = 0; ii < inds.size(); ++ii) { + int64_t out_idx = 0; + for (int64_t idx = 0; idx < ind_size; idx++) { + int64_t src_idx = 0; + for (int ii = 0; ii < std::ssize(inds); ++ii) { auto ax = axes[ii]; auto idx_loc = its[ii].loc; its[ii].step(); @@ -134,7 +134,7 @@ void gather( src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx); out_idx += slice_size; } else { - for (int jj = 0; jj < slice_size; jj++) { + for (int64_t jj = 0; jj < slice_size; jj++) { dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc]; src_it.step(); } @@ -403,11 +403,11 @@ void scatter( const std::vector& axes) { int nind = inds.size(); auto inds_ndim = updates.ndim() - out.ndim(); - size_t n_updates = nind ? inds[0].size() : 1; + int64_t n_updates = nind ? inds[0].size() : 1; Shape update_shape( updates.shape().begin() + inds_ndim, updates.shape().end()); - size_t update_size = 1; + int64_t update_size = 1; for (auto us : update_shape) { update_size *= us; } @@ -418,9 +418,9 @@ void scatter( auto out_ptr = out.data(); auto upd_ptr = updates.data(); - for (int i = 0; i < n_updates; ++i) { - size_t out_offset = 0; - for (int j = 0; j < inds.size(); ++j) { + for (int64_t i = 0; i < n_updates; ++i) { + int64_t out_offset = 0; + for (int j = 0; j < std::ssize(inds); ++j) { auto ax = axes[j]; auto idx_loc = its[j].loc; its[j].step(); @@ -429,7 +429,7 @@ void scatter( out_offset += (idx_val * out.strides()[ax]); } update_it.seek(i * update_size); - for (int j = 0; j < update_size; ++j) { + for (int64_t j = 0; j < update_size; ++j) { OpT{}(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc); update_it.step(); out_it.step(); diff --git a/mlx/backend/cpu/masked_mm.cpp b/mlx/backend/cpu/masked_mm.cpp index 688479c602..81012a84de 100644 --- a/mlx/backend/cpu/masked_mm.cpp +++ b/mlx/backend/cpu/masked_mm.cpp @@ -25,7 +25,7 @@ inline void mask_matrix( const int64_t Y_data_str, const int64_t X_mask_str, const int64_t Y_mask_str, - const size_t mask_offset) { + const int64_t mask_offset) { int tX = (X + block_size - 1) / block_size; int tY = (Y + block_size - 1) / block_size; @@ -61,13 +61,13 @@ inline void segmented_mm( T* out, bool a_transposed, bool b_transposed, - size_t lda, - size_t ldb, + int64_t lda, + int64_t ldb, const Shape& a_shape, const Strides& a_strides, const Shape& b_shape, const Strides& b_strides, - size_t num_segments, + int64_t num_segments, const Shape& segments_shape, const Strides& segments_strides) { int ndim = a_shape.size(); @@ -149,9 +149,9 @@ void BlockMaskedMM::eval_cpu(const std::vector& inputs, array& out) { auto [b_transposed, ldb, b, b_copied] = check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_); - size_t M = a.shape(-2); - size_t N = b.shape(-1); - size_t K = a.shape(-1); + int64_t M = a.shape(-2); + int64_t N = b.shape(-1); + int64_t K = a.shape(-1); if (M == 0 || N == 0) { return; @@ -172,8 +172,8 @@ void BlockMaskedMM::eval_cpu(const std::vector& inputs, array& out) { int batch_idx, int X, int Y, - size_t X_data_str, - size_t Y_data_str, + int64_t X_data_str, + int64_t Y_data_str, const Shape& mask_shape, const Strides& mask_strides, bool is_bool) { @@ -253,7 +253,7 @@ void BlockMaskedMM::eval_cpu(const std::vector& inputs, array& out) { auto a_ptr = a.data(); auto b_ptr = b.data(); auto out_ptr = out.data(); - size_t num_matrices = out.size() / (M * size_t(N)); + int64_t num_matrices = out.size() / (M * int64_t(N)); auto ldc = out.shape(-1); encoder.dispatch([a_ptr, @@ -394,9 +394,9 @@ void GatherMM::eval_cpu(const std::vector& inputs, array& out) { auto [a_transposed, lda, a] = check_transpose(a_pre); auto [b_transposed, ldb, b] = check_transpose(b_pre); - size_t M = a.shape(-2); - size_t N = b.shape(-1); - size_t K = a.shape(-1); + int64_t M = a.shape(-2); + int64_t N = b.shape(-1); + int64_t K = a.shape(-1); if (M == 0 || N == 0) { return; @@ -413,7 +413,7 @@ void GatherMM::eval_cpu(const std::vector& inputs, array& out) { // Get batch dims auto batch_size_out = out.size() / (M * N); - size_t matrix_stride_out = M * N; + int64_t matrix_stride_out = M * N; auto get_batch_dims = [](const auto& v) { return decltype(v){v.begin(), v.end() - 2}; diff --git a/mlx/backend/cpu/primitives.cpp b/mlx/backend/cpu/primitives.cpp index f2cb12fdd5..18db2b3ddd 100644 --- a/mlx/backend/cpu/primitives.cpp +++ b/mlx/backend/cpu/primitives.cpp @@ -48,7 +48,7 @@ static std::pair compute_dynamic_offset( auto compute_offset = [strides, axes, offset = offset.data()](const auto* indices) { int64_t offset_ = 0; - for (int i = 0; i < axes.size(); ++i) { + for (int i = 0; i < std::ssize(axes); ++i) { offset_ += indices[i] * strides[axes[i]]; } offset[0] = offset_; @@ -193,9 +193,9 @@ void Concatenate::eval_cpu(const std::vector& inputs, array& out) { flags.row_contiguous = false; flags.col_contiguous = false; flags.contiguous = false; - for (int i = 0; i < inputs.size(); i++) { + for (int i = 0; i < std::ssize(inputs); i++) { array out_slice(inputs[i].shape(), out.dtype(), nullptr, {}); - size_t data_offset = strides[axis_] * sizes[i]; + int64_t data_offset = strides[axis_] * sizes[i]; out_slice.copy_shared_buffer( out, strides, flags, out_slice.size(), data_offset); copy_cpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream()); @@ -205,7 +205,7 @@ void Concatenate::eval_cpu(const std::vector& inputs, array& out) { void Contiguous::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - constexpr size_t extra_bytes = 16384; + constexpr int64_t extra_bytes = 16384; if (in.buffer_size() <= out.nbytes() + extra_bytes && (in.flags().row_contiguous || (allow_col_major_ && in.flags().col_contiguous))) { @@ -254,8 +254,8 @@ void Pad::eval_cpu(const std::vector& inputs, array& out) { copy_cpu(val, out, CopyType::Scalar, stream()); // Find offset for start of input values - size_t data_offset = 0; - for (int i = 0; i < axes_.size(); i++) { + int64_t data_offset = 0; + for (int i = 0; i < std::ssize(axes_); i++) { auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i]; data_offset += out.strides()[ax] * low_pad_size_[i]; } @@ -274,10 +274,10 @@ void RandomBits::eval_cpu(const std::vector& inputs, array& out) { // keys has shape (N1, ..., NK, 2) // out has shape (N1, ..., NK, M1, M2, ...) auto& keys = inputs[0]; - size_t num_keys = keys.size() / 2; + int64_t num_keys = keys.size() / 2; - size_t elems_per_key = out.size() / num_keys; - size_t bytes_per_key = out.itemsize() * elems_per_key; + int64_t elems_per_key = out.size() / num_keys; + int64_t bytes_per_key = out.itemsize() * elems_per_key; out.set_data(allocator::malloc(out.nbytes())); auto kptr = inputs[0].data(); @@ -291,8 +291,8 @@ void RandomBits::eval_cpu(const std::vector& inputs, array& out) { num_keys, kshape = keys.shape(), kstrides = keys.strides()]() mutable { - size_t out_skip = (bytes_per_key + 4 - 1) / 4; - auto half_size = out_skip / 2; + int64_t out_skip = (bytes_per_key + 4 - 1) / 4; + uintptr_t half_size = out_skip / 2; bool even = out_skip % 2 == 0; for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) { auto ptr = reinterpret_cast(cptr); diff --git a/mlx/backend/cpu/qrf.cpp b/mlx/backend/cpu/qrf.cpp index 13c7e11321..d3d6717e81 100644 --- a/mlx/backend/cpu/qrf.cpp +++ b/mlx/backend/cpu/qrf.cpp @@ -13,7 +13,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) { const int M = a.shape(-2); const int N = a.shape(-1); const int lda = M; - size_t num_matrices = a.size() / (M * N); + int64_t num_matrices = a.size() / (M * N); // Copy A to inplace input and make it col-contiguous array in(a.shape(), a.dtype(), nullptr, {}); @@ -54,7 +54,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) { auto work = allocator::malloc(sizeof(T) * lwork); // Loop over matrices - for (int i = 0; i < num_matrices; ++i) { + for (int64_t i = 0; i < num_matrices; ++i) { // Solve geqrf( &M, @@ -68,7 +68,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) { } allocator::free(work); - for (int i = 0; i < num_matrices; ++i) { + for (int64_t i = 0; i < num_matrices; ++i) { /// num_reflectors x N for (int j = 0; j < num_reflectors; ++j) { for (int k = 0; k < j; ++k) { @@ -97,7 +97,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) { work = allocator::malloc(sizeof(T) * lwork); // Loop over matrices - for (int i = 0; i < num_matrices; ++i) { + for (int64_t i = 0; i < num_matrices; ++i) { // Compute Q orgqr( &M, @@ -111,7 +111,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) { &info); } - for (int i = 0; i < num_matrices; ++i) { + for (int64_t i = 0; i < num_matrices; ++i) { // M x num_reflectors for (int j = 0; j < M; ++j) { for (int k = 0; k < num_reflectors; ++k) { diff --git a/mlx/backend/cpu/simd/math.h b/mlx/backend/cpu/simd/math.h index f9fc8317a5..9854f7e91d 100644 --- a/mlx/backend/cpu/simd/math.h +++ b/mlx/backend/cpu/simd/math.h @@ -79,7 +79,8 @@ Simd sincos(Simd in) { // Get the polynom selection mask. There is one polynom for 0 <= x <= Pi/4 // and another one for Pi/4(2)) != static_cast(0); // The magic pass: "Extended precision modular arithmetic" // x = ((x - y * DP1) - y * DP2) - y * DP3 @@ -87,8 +88,8 @@ Simd sincos(Simd in) { x = fma(y, Simd(-2.4187564849853515625e-4f), x); x = fma(y, Simd(-3.77489497744594108e-8f), x); - sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != 0); - auto sign_mask_cos = ((emm2 - 2) & 4) != 0; + sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != static_cast(0)); + auto sign_mask_cos = ((emm2 - 2) & 4) != static_cast(0); // Evaluate the first polynom (0 <= x <= Pi/4) in y1, // and the second polynom (Pi/4 <= x <= 0) in y2 diff --git a/mlx/backend/cpu/sort.cpp b/mlx/backend/cpu/sort.cpp index fcf12d7ad7..8e05951aaa 100644 --- a/mlx/backend/cpu/sort.cpp +++ b/mlx/backend/cpu/sort.cpp @@ -120,8 +120,8 @@ template void sort(array& out, int axis) { // Get axis, shape and stride info axis = axis < 0 ? axis + out.ndim() : axis; - size_t in_size = out.size(); - size_t n_rows = in_size / out.shape(axis); + int64_t in_size = out.size(); + int64_t n_rows = in_size / out.shape(axis); auto remaining_shape = out.shape(); remaining_shape.erase(remaining_shape.begin() + axis); @@ -136,7 +136,7 @@ void sort(array& out, int axis) { ContiguousIterator src_it( remaining_shape, remaining_strides, remaining_shape.size()); auto out_ptr = out.data(); - for (int i = 0; i < n_rows; i++) { + for (int64_t i = 0; i < n_rows; i++) { T* data_ptr = out_ptr + src_it.loc; StridedIterator st(data_ptr, axis_stride, 0); @@ -151,7 +151,7 @@ template void argsort(const array& in, array& out, int axis) { // Get axis, shape and stride info axis = axis < 0 ? axis + in.ndim() : axis; - size_t n_rows = in.size() / in.shape(axis); + int64_t n_rows = in.size() / in.shape(axis); auto in_remaining_shape = in.shape(); in_remaining_shape.erase(in_remaining_shape.begin() + axis); @@ -176,7 +176,7 @@ void argsort(const array& in, array& out, int axis) { out_remaining_shape, out_remaining_strides, out_remaining_shape.size()); auto in_ptr = in.data(); auto out_ptr = out.data(); - for (int i = 0; i < n_rows; i++) { + for (int64_t i = 0; i < n_rows; i++) { const T* data_ptr = in_ptr + in_it.loc; IdxT* idx_ptr = out_ptr + out_it.loc; @@ -214,8 +214,8 @@ template void partition(array& out, int axis, int kth) { // Get axis, shape and stride info axis = axis < 0 ? axis + out.ndim() : axis; - size_t in_size = out.size(); - size_t n_rows = in_size / out.shape(axis); + int64_t in_size = out.size(); + int64_t n_rows = in_size / out.shape(axis); auto remaining_shape = out.shape(); remaining_shape.erase(remaining_shape.begin() + axis); @@ -232,7 +232,7 @@ void partition(array& out, int axis, int kth) { ContiguousIterator src_it( remaining_shape, remaining_strides, remaining_shape.size()); auto out_ptr = out.data(); - for (int i = 0; i < n_rows; i++) { + for (int64_t i = 0; i < n_rows; i++) { T* data_ptr = out_ptr + src_it.loc; src_it.step(); @@ -248,7 +248,7 @@ template void argpartition(const array& in, array& out, int axis, int kth) { // Get axis, shape and stride info axis = axis < 0 ? axis + in.ndim() : axis; - size_t n_rows = in.size() / in.shape(axis); + int64_t n_rows = in.size() / in.shape(axis); auto in_remaining_shape = in.shape(); in_remaining_shape.erase(in_remaining_shape.begin() + axis); @@ -277,7 +277,7 @@ void argpartition(const array& in, array& out, int axis, int kth) { auto in_ptr = in.data(); auto out_ptr = out.data(); - for (int i = 0; i < n_rows; i++) { + for (int64_t i = 0; i < n_rows; i++) { const T* data_ptr = in_ptr + in_it.loc; IdxT* idx_ptr = out_ptr + out_it.loc; in_it.step(); diff --git a/mlx/backend/cpu/svd.cpp b/mlx/backend/cpu/svd.cpp index 1fc94c382c..54d15fabc5 100644 --- a/mlx/backend/cpu/svd.cpp +++ b/mlx/backend/cpu/svd.cpp @@ -27,7 +27,7 @@ void svd_impl( const int N = a.shape(-1); const int K = std::min(M, N); - size_t num_matrices = a.size() / (M * N); + int64_t num_matrices = a.size() / (M * N); // lapack clobbers the input, so we have to make a copy. array in(a.shape(), a.dtype(), nullptr, {}); @@ -121,7 +121,7 @@ void svd_impl( auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)}; // Loop over matrices. - for (int i = 0; i < num_matrices; i++) { + for (int64_t i = 0; i < num_matrices; i++) { gesdd( /* jobz = */ jobz, // M and N are swapped since lapack expects column-major. @@ -153,10 +153,10 @@ void svd_impl( template void compute_svd( - const array& a, - bool compute_uv, - std::vector& outputs, - Stream stream) {} + const array& /* a */, + bool /* compute_uv */, + std::vector& /* outputs */, + Stream /* stream */) {} void SVD::eval_cpu( const std::vector& inputs, diff --git a/mlx/backend/cpu/ternary.h b/mlx/backend/cpu/ternary.h index a27a7f2a9f..4674d9fef5 100644 --- a/mlx/backend/cpu/ternary.h +++ b/mlx/backend/cpu/ternary.h @@ -136,7 +136,7 @@ void ternary_op( if (topt == TernaryOpType::ScalarScalarScalar) { *out_ptr = op(*a_ptr, *b_ptr, *c_ptr); } else if (topt == TernaryOpType::VectorVectorVector) { - for (size_t i = 0; i < out.size(); ++i) { + for (int64_t i = 0; i < out.size(); ++i) { *out_ptr = op(*a_ptr, *b_ptr, *c_ptr); a_ptr++; b_ptr++; diff --git a/mlx/backend/cpu/unary.h b/mlx/backend/cpu/unary.h index 14c1dd479a..8a4c64e69e 100644 --- a/mlx/backend/cpu/unary.h +++ b/mlx/backend/cpu/unary.h @@ -10,8 +10,8 @@ namespace mlx::core { template -void unary_op(const T* a, U* out, size_t shape, size_t stride) { - for (size_t i = 0; i < shape; i += 1) { +void unary_op(const T* a, U* out, int64_t shape, int64_t stride) { + for (int64_t i = 0; i < shape; i += 1) { out[i] = Op{}(*a); a += stride; } @@ -38,14 +38,14 @@ void unary_op(const array& a, array& out, Op) { src++; } } else { - size_t shape = ndim > 0 ? a.shape().back() : 1; - size_t stride = ndim > 0 ? a.strides().back() : 1; + int64_t shape = ndim > 0 ? a.shape().back() : 1; + int64_t stride = ndim > 0 ? a.strides().back() : 1; if (ndim <= 1) { unary_op(src, dst, shape, stride); return; } auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1); - for (size_t elem = 0; elem < a.size(); elem += shape) { + for (int64_t elem = 0; elem < a.size(); elem += shape) { unary_op(src + it.loc, dst + elem, shape, stride); it.step(); } From a1212b4e44515af9eee049823a50e4d189914fd5 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Thu, 30 Oct 2025 16:25:11 -0700 Subject: [PATCH 11/30] WIP (distributed) --- mlx/distributed/primitives.cpp | 8 +-- mlx/distributed/ring/ring.cpp | 121 +++++++++++++++++---------------- 2 files changed, 67 insertions(+), 62 deletions(-) diff --git a/mlx/distributed/primitives.cpp b/mlx/distributed/primitives.cpp index 5e8d5327a1..0c87172be7 100644 --- a/mlx/distributed/primitives.cpp +++ b/mlx/distributed/primitives.cpp @@ -27,7 +27,7 @@ std::pair, std::vector> AllReduce::vmap( } std::vector AllReduce::jvp( - const std::vector& primals, + const std::vector& /* primals */, const std::vector& tangents, const std::vector&) { switch (reduce_type_) { @@ -44,10 +44,10 @@ std::vector AllReduce::jvp( } std::vector AllReduce::vjp( - const std::vector& primals, + const std::vector& /* primals */, const std::vector& cotangents, const std::vector&, - const std::vector& outputs) { + const std::vector& /* outputs */) { return cotangents; } @@ -58,7 +58,7 @@ std::pair, std::vector> AllGather::vmap( } std::vector AllGather::jvp( - const std::vector& primals, + const std::vector& /* primals */, const std::vector& tangents, const std::vector&) { return {all_gather(tangents[0], group(), stream())}; diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp index ac55ea30b0..702cf7a4c6 100644 --- a/mlx/distributed/ring/ring.cpp +++ b/mlx/distributed/ring/ring.cpp @@ -90,8 +90,8 @@ namespace mlx::core::distributed::ring { -constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024; -constexpr const size_t ALL_SUM_BUFFERS = 2; +constexpr const int64_t ALL_SUM_SIZE = 8 * 1024 * 1024; +constexpr const int64_t ALL_SUM_BUFFERS = 2; constexpr const int CONN_ATTEMPTS = 5; constexpr const int CONN_WAIT = 1000; @@ -141,27 +141,27 @@ class SocketThread { } template - std::future send(const T* buffer, size_t size) { + std::future send(const T* buffer, int64_t size) { return send_impl(reinterpret_cast(buffer), size * sizeof(T)); } template - std::future recv(T* buffer, size_t size) { + std::future recv(T* buffer, int64_t size) { return recv_impl(reinterpret_cast(buffer), size * sizeof(T)); } private: struct SocketTask { - SocketTask(void* b, size_t s, std::promise&& p) + SocketTask(void* b, int64_t s, std::promise&& p) : buffer(b), size(s), promise(std::move(p)) {} SocketTask(SocketTask&& t) : buffer(t.buffer), size(t.size), promise(std::move(t.promise)) {} void* buffer; - size_t size; + int64_t size; std::promise promise; }; - std::future send_impl(const char* buffer, size_t size) { + std::future send_impl(const char* buffer, int64_t size) { std::promise send_completed_promise; auto send_completed_future = send_completed_promise.get_future(); if (size == 0) { @@ -178,7 +178,7 @@ class SocketThread { return send_completed_future; } - std::future recv_impl(char* buffer, size_t size) { + std::future recv_impl(char* buffer, int64_t size) { std::promise recv_completed_promise; auto recv_completed_future = recv_completed_promise.get_future(); if (size == 0) { @@ -232,7 +232,7 @@ class SocketThread { if (!recvs_.empty()) { auto& task = recvs_.front(); - ssize_t r = ::recv(fd_, task.buffer, task.size, 0); + int64_t r = ::recv(fd_, task.buffer, task.size, 0); if (r > 0) { task.buffer = static_cast(task.buffer) + r; task.size -= r; @@ -246,7 +246,7 @@ class SocketThread { } if (!sends_.empty()) { auto& task = sends_.front(); - ssize_t r = ::send(fd_, task.buffer, task.size, 0); + int64_t r = ::send(fd_, task.buffer, task.size, 0); if (r > 0) { task.buffer = static_cast(task.buffer) + r; task.size -= r; @@ -283,12 +283,12 @@ class CommunicationThreads { } template - std::future send(int socket, T* buffer, size_t size) { + std::future send(int socket, T* buffer, int64_t size) { return threads_.at(socket).send(buffer, size); } template - std::future recv(int socket, T* buffer, size_t size) { + std::future recv(int socket, T* buffer, int64_t size) { return threads_.at(socket).recv(buffer, size); } @@ -505,7 +505,7 @@ std::vector make_connections( } template struct SumOp { - void operator()(const T* input, T* output, size_t N) { + void operator()(const T* input, T* output, int64_t N) { while (N-- > 0) { *output += *input; input++; @@ -516,7 +516,7 @@ struct SumOp { template struct MaxOp { - void operator()(const T* input, T* output, size_t N) { + void operator()(const T* input, T* output, int64_t N) { while (N-- > 0) { *output = std::max(*output, *input); input++; @@ -527,7 +527,7 @@ struct MaxOp { template struct MinOp { - void operator()(const T* input, T* output, size_t N) { + void operator()(const T* input, T* output, int64_t N) { while (N-- > 0) { *output = std::min(*output, *input); input++; @@ -542,7 +542,7 @@ class RingGroup : public GroupImpl { public: RingGroup(int rank, std::vector> nodes, bool verbose) : rank_(rank), verbose_(verbose), pool_(0) { - if (rank_ > 0 && rank_ >= nodes.size()) { + if (rank_ > 0 && rank_ >= std::ssize(nodes)) { throw std::runtime_error( "[ring] Rank cannot be larger than the size of the group"); } @@ -589,7 +589,7 @@ class RingGroup : public GroupImpl { // Configure all sockets to use TCP no delay. int one = 1; - for (int i = 0; i < sockets_right_.size(); i++) { + for (int64_t i = 0; i < std::ssize(sockets_right_); i++) { setsockopt(sockets_right_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one)); setsockopt(sockets_left_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one)); } @@ -646,7 +646,8 @@ class RingGroup : public GroupImpl { output, all_reduce>(input, output, stream, MinOp())); } - std::shared_ptr split(int color, int key = -1) override { + std::shared_ptr split(int /* color */, int /* key */ = -1) + override { throw std::runtime_error("[ring] Group split not supported."); } @@ -658,15 +659,15 @@ class RingGroup : public GroupImpl { nbytes = input.nbytes(), output_ptr = output.data(), this]() { - constexpr size_t min_send_size = 262144; - size_t n_gathers = std::max( - std::min( + constexpr int64_t min_send_size = 262144; + int64_t n_gathers = std::max( + std::min( sockets_right_.size() + sockets_left_.size(), nbytes / min_send_size), - size_t(1)); - size_t bytes_per_gather = ceildiv(nbytes, n_gathers); + 1); + int64_t bytes_per_gather = ceildiv(nbytes, n_gathers); std::vector> all_gathers; - for (int i = 0; i < n_gathers; i++) { + for (int64_t i = 0; i < n_gathers; i++) { auto offset = i * bytes_per_gather; all_gathers.emplace_back(pool_.enqueue(std::bind( &RingGroup::all_gather_impl, @@ -742,10 +743,14 @@ class RingGroup : public GroupImpl { auto out_ptr = output.data(); auto& encoder = cpu::get_command_encoder(stream); encoder.set_output_array(output); - encoder.dispatch([in_ptr, out_ptr, size = input.size(), this, reduce_op]() { + encoder.dispatch([in_ptr, + out_ptr, + size = static_cast(input.size()), + this, + reduce_op]() { // If the input data cannot be split into size_ segments then copy it and // all reduce a local buffer prefilled with 0s. - size_t nbytes = size * sizeof(T); + int64_t nbytes = size * sizeof(T); if (size < size_) { // TODO: Maybe allocate dynamically so we don't have the constraint // below? @@ -778,16 +783,16 @@ class RingGroup : public GroupImpl { // Split the all reduces so that each member has at least 1 buffer to // send/recv per segment. - constexpr size_t min_send_size = 262144; - size_t n_reduces = std::max( - std::min( + constexpr int64_t min_send_size = 262144; + int64_t n_reduces = std::max( + std::min( sockets_right_.size() + sockets_left_.size(), nbytes / (size_ * min_send_size)), - size_t(1)); - size_t step = ceildiv(size, n_reduces); + 1); + int64_t step = ceildiv(size, n_reduces); std::vector> all_sums; - for (int i = 0; i < n_reduces; i++) { + for (int64_t i = 0; i < n_reduces; i++) { all_sums.emplace_back(pool_.enqueue(std::bind( &RingGroup::all_reduce_impl, this, @@ -810,7 +815,7 @@ class RingGroup : public GroupImpl { void all_reduce_impl( T* buffer, T* data, - size_t data_size, + int64_t data_size, int socket_right, int socket_left, int direction, @@ -821,10 +826,10 @@ class RingGroup : public GroupImpl { // We split the data into `size_` segments of size `segment_size` and each // of these in smaller segments of ALL_SUM_SIZE which we 'll call packets. - size_t segment_size = ceildiv(data_size, size_); - size_t BUFFER_SIZE = std::max( - size_t(32768), std::min(ALL_SUM_SIZE / sizeof(T), segment_size / 2)); - size_t n_packets = ceildiv(segment_size, BUFFER_SIZE); + int64_t segment_size = ceildiv(data_size, size_); + int64_t BUFFER_SIZE = std::max( + 32768, std::min(ALL_SUM_SIZE / sizeof(T), segment_size / 2)); + int64_t n_packets = ceildiv(segment_size, BUFFER_SIZE); // Initial segments int send_segment = rank_; @@ -833,21 +838,21 @@ class RingGroup : public GroupImpl { // Plan the whole reduce in terms of sends and recvs as indices in data. // It makes the actual async send and recv a bit simpler to follow when // there are less offset calculations around. - std::vector> send_plan; - std::vector> recv_plan; + std::vector> send_plan; + std::vector> recv_plan; // Two times the same send/recv operations, first scatter reduce and then // gather. for (int k = 0; k < 2; k++) { for (int i = 0; i < size_ - 1; i++) { - size_t send_start = send_segment * segment_size; - size_t send_stop = + int64_t send_start = send_segment * segment_size; + int64_t send_stop = std::min((send_segment + 1) * segment_size, data_size); - size_t recv_start = recv_segment * segment_size; - size_t recv_stop = + int64_t recv_start = recv_segment * segment_size; + int64_t recv_stop = std::min((recv_segment + 1) * segment_size, data_size); - for (size_t j = 0; j < n_packets; j++) { + for (int64_t j = 0; j < n_packets; j++) { send_plan.emplace_back( std::min(send_start + j * BUFFER_SIZE, send_stop), std::min(send_start + (j + 1) * BUFFER_SIZE, send_stop)); @@ -864,18 +869,18 @@ class RingGroup : public GroupImpl { // Running the plan is fairly simple, we keep a send and a recv in flight // while doing the summation. T* recv_buffers[ALL_SUM_BUFFERS]; - for (int i = 0; i < ALL_SUM_BUFFERS; i++) { + for (int64_t i = 0; i < ALL_SUM_BUFFERS; i++) { recv_buffers[i] = buffer + i * BUFFER_SIZE; } std::future sends[2], recvs[2]; int a = 0; int b = (n_packets > 1) ? 1 : 0; - for (int i = 0, j = -b; i < send_plan.size(); j++, i++) { + for (int i = 0, j = -b; i < std::ssize(send_plan); j++, i++) { sends[a] = comm_.send( socket_send, data + send_plan[i].first, send_plan[i].second - send_plan[i].first); - if (2 * i < send_plan.size()) { + if (2 * i < std::ssize(send_plan)) { recvs[a] = comm_.recv( socket_recv, recv_buffers[i % ALL_SUM_BUFFERS], @@ -890,7 +895,7 @@ class RingGroup : public GroupImpl { if (j >= 0) { sends[b].wait(); recvs[b].wait(); - if (2 * j < send_plan.size()) { + if (2 * j < std::ssize(send_plan)) { reduce_op( recv_buffers[j % ALL_SUM_BUFFERS], data + recv_plan[j].first, @@ -907,8 +912,8 @@ class RingGroup : public GroupImpl { void all_gather_impl( const char* input, char* output, - size_t input_size, - size_t data_size, + int64_t input_size, + int64_t data_size, int socket_right, int socket_left, int direction) { @@ -941,11 +946,11 @@ class RingGroup : public GroupImpl { } void - send(const std::vector& sockets, const char* data, size_t data_size) { - size_t segment_size = - std::max(size_t(1024), ceildiv(data_size, sockets.size())); + send(const std::vector& sockets, const char* data, int64_t data_size) { + int64_t segment_size = + std::max(1024, ceildiv(data_size, std::ssize(sockets))); std::vector> sends; - for (int i = 0; i < sockets.size(); i++) { + for (int i = 0; i < std::ssize(sockets); i++) { if (i * segment_size >= data_size) { break; } @@ -959,11 +964,11 @@ class RingGroup : public GroupImpl { } } - void recv(const std::vector& sockets, char* data, size_t data_size) { - size_t segment_size = - std::max(size_t(1024), ceildiv(data_size, sockets.size())); + void recv(const std::vector& sockets, char* data, int64_t data_size) { + int64_t segment_size = + std::max(1024, ceildiv(data_size, std::ssize(sockets))); std::vector> recvs; - for (int i = 0; i < sockets.size(); i++) { + for (int i = 0; i < std::ssize(sockets); i++) { if (i * segment_size >= data_size) { break; } From 1bac0db7e303aad915da159f65b9793b21627a47 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Thu, 30 Oct 2025 16:25:36 -0700 Subject: [PATCH 12/30] WIP --- mlx/dtype.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/dtype.cpp b/mlx/dtype.cpp index 429fa42020..86a5da7c4b 100644 --- a/mlx/dtype.cpp +++ b/mlx/dtype.cpp @@ -166,7 +166,7 @@ bool issubdtype(const Dtype& a, const Dtype& b) { return a == b; } -bool issubdtype(const Dtype::Category& cat, const Dtype& type) { +bool issubdtype(const Dtype::Category& /* cat */, const Dtype& /* type */) { return false; } From 5baa361779e70a35af2033d14db18a0fc792a15a Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Fri, 31 Oct 2025 09:39:38 -0700 Subject: [PATCH 13/30] WIP (tests) --- tests/autograd_tests.cpp | 4 ++-- tests/export_import_tests.cpp | 2 +- tests/load_tests.cpp | 4 ++-- tests/ops_tests.cpp | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp index 3a373fb18b..352e379f45 100644 --- a/tests/autograd_tests.cpp +++ b/tests/autograd_tests.cpp @@ -145,7 +145,7 @@ TEST_CASE("test jvp") { // No dependence between input and output { - auto fun = [](array in) { return array({1.0, 1.0}); }; + auto fun = [](array /* in */) { return array({1.0, 1.0}); }; auto out = jvp(fun, array(1.0f), array(1.0f)).second; CHECK(array_equal(out, zeros({2})).item()); } @@ -195,7 +195,7 @@ TEST_CASE("test vjp") { // No dependence between input and output { - auto fun = [](array in) { return array(1.); }; + auto fun = [](array /* in */) { return array(1.); }; auto out = vjp(fun, zeros({2}), array(1.)).second; CHECK(array_equal(out, zeros({2})).item()); } diff --git a/tests/export_import_tests.cpp b/tests/export_import_tests.cpp index 7ad2c640d8..060378f82f 100644 --- a/tests/export_import_tests.cpp +++ b/tests/export_import_tests.cpp @@ -44,7 +44,7 @@ TEST_CASE("test export basic functions") { } TEST_CASE("test export function with no inputs") { - auto fun = [](std::vector x) -> std::vector { + auto fun = [](std::vector /* x */) -> std::vector { return {zeros({2, 2})}; }; diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp index 1531ce060c..1eb01c204c 100644 --- a/tests/load_tests.cpp +++ b/tests/load_tests.cpp @@ -168,7 +168,7 @@ TEST_CASE("test gguf metadata") { CHECK_EQ(loaded_metadata.count("meta"), 1); auto& strs = std::get>(loaded_metadata["meta"]); CHECK_EQ(strs.size(), 3); - for (int i = 0; i < strs.size(); ++i) { + for (int i = 0; i < std::ssize(strs); ++i) { CHECK_EQ(strs[i], data[i]); } } @@ -187,7 +187,7 @@ TEST_CASE("test gguf metadata") { CHECK_EQ(loaded_metadata.size(), 4); auto& strs = std::get>(loaded_metadata["meta1"]); CHECK_EQ(strs.size(), 3); - for (int i = 0; i < strs.size(); ++i) { + for (int i = 0; i < std::ssize(strs); ++i) { CHECK_EQ(strs[i], data[i]); } auto& arr = std::get(loaded_metadata["meta2"]); diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 2e8bbd692f..ac3c5e5ffa 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1668,7 +1668,7 @@ TEST_CASE("test error functions") { -0.1124629160182849, -0.5204998778130465, -0.7969082124228322}; - for (int i = 0; i < vals.size(); ++i) { + for (int i = 0; i < std::ssize(vals); ++i) { x = array(vals.begin()[i]); CHECK_EQ(erf(x).item(), doctest::Approx(expected.begin()[i])); } @@ -1686,7 +1686,7 @@ TEST_CASE("test error functions") { -0.08885599049425769, -0.4769362762044699, -1.1630871536766743}; - for (int i = 0; i < vals.size(); ++i) { + for (int i = 0; i < std::ssize(vals); ++i) { x = array(vals.begin()[i]); CHECK_EQ(erfinv(x).item(), doctest::Approx(expected.begin()[i])); } From 5a306d3495b71510c4dd66abea6f71578a27afd3 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Fri, 31 Oct 2025 09:40:13 -0700 Subject: [PATCH 14/30] WIP (common) --- mlx/backend/common/slicing.cpp | 6 +++--- mlx/backend/common/utils.h | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mlx/backend/common/slicing.cpp b/mlx/backend/common/slicing.cpp index 6f5736d637..14dad31c60 100644 --- a/mlx/backend/common/slicing.cpp +++ b/mlx/backend/common/slicing.cpp @@ -24,8 +24,8 @@ std::tuple prepare_slice( void shared_buffer_slice( const array& in, const Strides& out_strides, - size_t data_offset, - size_t data_size, + int64_t data_offset, + int64_t data_size, array& out) { // Compute row/col contiguity auto [no_bsx_size, is_row_contiguous, is_col_contiguous] = @@ -61,7 +61,7 @@ void slice( if (data_end < 0) { data_end += in.data_size(); } - size_t data_size = (data_end - data_offset); + int64_t data_size = (data_end - data_offset); shared_buffer_slice(in, inp_strides, data_offset, data_size, out); } diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index 0a4760fdce..c68924f813 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -162,7 +162,7 @@ struct ContiguousIterator { }; inline auto check_contiguity(const Shape& shape, const Strides& strides) { - size_t no_broadcast_data_size = 1; + int64_t no_broadcast_data_size = 1; int64_t f_stride = 1; int64_t b_stride = 1; bool is_row_contiguous = true; From 981d2fdaf013d3a1ba13df28f455eda6e7637357 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Fri, 31 Oct 2025 09:40:50 -0700 Subject: [PATCH 15/30] WIP (cpu) --- mlx/backend/cpu/compiled.cpp | 6 +++--- mlx/backend/cpu/inverse.cpp | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mlx/backend/cpu/compiled.cpp b/mlx/backend/cpu/compiled.cpp index 53d6c47e3e..296390a32e 100644 --- a/mlx/backend/cpu/compiled.cpp +++ b/mlx/backend/cpu/compiled.cpp @@ -49,7 +49,7 @@ static CompilerCache& cache() { // GPU compile is always available if the GPU is available and since we are in // this file CPU compile is also available. namespace detail { -bool compile_available_for_device(const Device& device) { +bool compile_available_for_device(const Device& /* device */) { return true; } @@ -168,7 +168,7 @@ inline void build_kernel( // Add the input arguments int cnt = 0; int strides_index = 1; - for (size_t i = 0; i < inputs.size(); ++i) { + for (int i = 0; i < std::ssize(inputs); ++i) { // Skip constants from the input list if (is_constant(i)) { continue; @@ -238,7 +238,7 @@ inline void build_kernel( } else { os << x.primitive().name(); os << "()("; - for (int i = 0; i < x.inputs().size() - 1; i++) { + for (int i = 0; i < std::ssize(x.inputs()) - 1; i++) { os << "tmp_" << namer.get_name(x.inputs()[i]) << ", "; } os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl; diff --git a/mlx/backend/cpu/inverse.cpp b/mlx/backend/cpu/inverse.cpp index ddc979daa6..3da657cbef 100644 --- a/mlx/backend/cpu/inverse.cpp +++ b/mlx/backend/cpu/inverse.cpp @@ -122,7 +122,7 @@ void inverse_impl( stream); const int N = a.shape(-1); - const size_t num_matrices = a.size() / (N * N); + const int64_t num_matrices = a.size() / (N * N); auto& encoder = cpu::get_command_encoder(stream); encoder.set_output_array(inv); @@ -130,13 +130,13 @@ void inverse_impl( auto inv_ptr = inv.data(); if (tri) { encoder.dispatch([inv_ptr, N, num_matrices, upper]() { - for (int i = 0; i < num_matrices; i++) { + for (int64_t i = 0; i < num_matrices; i++) { tri_inv(inv_ptr + N * N * i, N, upper); } }); } else { encoder.dispatch([inv_ptr, N, num_matrices]() { - for (int i = 0; i < num_matrices; i++) { + for (int64_t i = 0; i < num_matrices; i++) { general_inv(inv_ptr + N * N * i, N); } }); From 979abf462bbc5ac3ce6f0e52161de9a12f8782ff Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Fri, 31 Oct 2025 09:43:29 -0700 Subject: [PATCH 16/30] WIP (metal) --- mlx/backend/metal/compiled.cpp | 14 +++++------ mlx/backend/metal/conv.cpp | 2 +- mlx/backend/metal/custom_kernel.cpp | 14 +++++------ mlx/backend/metal/device.cpp | 5 +++- mlx/backend/metal/device.h | 2 +- mlx/backend/metal/fence.cpp | 8 +++--- mlx/backend/metal/fft.cpp | 22 ++++++++--------- mlx/backend/metal/hadamard.cpp | 2 +- mlx/backend/metal/indexing.cpp | 38 ++++++++++++++--------------- mlx/backend/metal/matmul.cpp | 2 +- mlx/backend/metal/nojit_kernels.cpp | 4 +-- mlx/backend/metal/normalization.cpp | 6 ++--- mlx/backend/metal/primitives.cpp | 30 +++++++++++++---------- mlx/backend/metal/reduce.cpp | 14 +++++------ mlx/backend/metal/resident.cpp | 2 +- mlx/backend/metal/slicing.cpp | 2 +- mlx/backend/metal/utils.h | 7 ++++++ 17 files changed, 94 insertions(+), 80 deletions(-) diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index eb51ab750e..e2173dc87f 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -109,7 +109,7 @@ inline void build_kernel( // Read constant / contiguous inputs in tmps std::vector nc_inputs; - for (int i = 0; i < inputs.size(); ++i) { + for (int i = 0; i < std::ssize(inputs); ++i) { auto& x = inputs[i]; auto& xname = namer.get_name(x); @@ -134,7 +134,7 @@ inline void build_kernel( } // Initialize the indices for non-contiguous inputs - for (int i = 0; i < nc_inputs.size(); ++i) { + for (int i = 0; i < std::ssize(nc_inputs); ++i) { auto& xname = namer.get_name(nc_inputs[i]); os += fmt::format(" {0} index_{1} = ", idx_type, xname); if (ndim == 1) { @@ -174,7 +174,7 @@ inline void build_kernel( os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3); } os += " uint l = zpos % output_shape[d];\n"; - for (int i = 0; i < nc_inputs.size(); ++i) { + for (int i = 0; i < std::ssize(nc_inputs); ++i) { auto& xname = namer.get_name(nc_inputs[i]); os += fmt::format(" index_{0} += ", xname); if (dynamic_dims) { @@ -195,7 +195,7 @@ inline void build_kernel( } // Read non-contiguous inputs into tmps - for (int i = 0; i < nc_inputs.size(); ++i) { + for (int i = 0; i < std::ssize(nc_inputs); ++i) { auto& x = nc_inputs[i]; auto& xname = namer.get_name(x); os += fmt::format( @@ -214,7 +214,7 @@ inline void build_kernel( } else { os += x.primitive().name(); os += "()("; - for (int i = 0; i < x.inputs().size() - 1; i++) { + for (int i = 0; i < std::ssize(x.inputs()) - 1; i++) { os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i])); } os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back())); @@ -227,7 +227,7 @@ inline void build_kernel( } // Increment indices and close per thread loop if (work_per_thread > 1) { - for (int i = 0; i < nc_inputs.size(); ++i) { + for (int i = 0; i < std::ssize(nc_inputs); ++i) { auto& x = nc_inputs[i]; auto& xname = namer.get_name(x); if (!dynamic_dims) { @@ -396,7 +396,7 @@ void Compiled::eval_gpu( int cnt = 0; int stride_idx = 1; // idx 0 is the output strides Strides in_strides; - for (int i = 0; i < inputs.size(); i++) { + for (int i = 0; i < std::ssize(inputs); i++) { if (is_constant_(i)) { continue; } diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index b4a674ff0e..e09a3175cc 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -990,7 +990,7 @@ void conv_3D_gpu( const std::vector& wt_dilation, const std::vector& in_dilation, bool flip, - std::vector& copies) { + std::vector& /* copies */) { // Make conv params MLXConvParams<3> conv_params{ /* const int N = */ static_cast(in.shape(0)), diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index c48b93c913..deaf1f0f67 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -68,7 +68,7 @@ std::string write_signature( int index = 0; constexpr int max_constant_array_size = 8; // Add inputs - for (int i = 0; i < inputs.size(); ++i) { + for (int i = 0; i < std::ssize(inputs); ++i) { const auto& name = input_names[i]; const auto& arr = inputs[i]; auto dtype = get_type_string(arr.dtype()); @@ -109,7 +109,7 @@ std::string write_signature( } } // Add outputs - for (int i = 0; i < output_names.size(); ++i) { + for (int i = 0; i < std::ssize(output_names); ++i) { const auto& name = output_names[i]; const auto& dtype = output_dtypes[i]; kernel_source += " device "; @@ -126,8 +126,8 @@ std::string write_signature( kernel_source += " [[buffer("; kernel_source += std::to_string(index); kernel_source += ")]]"; - if (index < inputs.size() + output_names.size() - 1 || - attributes.size() > 0) { + if (index < std::ssize(inputs) + std::ssize(output_names) - 1 || + std::ssize(attributes) > 0) { kernel_source += ",\n"; } else { kernel_source += ") {\n"; @@ -138,7 +138,7 @@ std::string write_signature( index = 0; for (const auto& attr : attributes) { kernel_source += attr; - if (index < attributes.size() - 1) { + if (index < std::ssize(attributes) - 1) { kernel_source += ",\n"; } else { kernel_source += ") {\n"; @@ -381,7 +381,7 @@ void CustomKernel::eval_gpu( auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); int index = 0; - for (int i = 0; i < checked_inputs.size(); i++) { + for (int i = 0; i < std::ssize(checked_inputs); i++) { const array& in = checked_inputs[i]; auto& shape_info = shape_infos_[i]; compute_encoder.set_input_array(in, index); @@ -408,7 +408,7 @@ void CustomKernel::eval_gpu( } const auto [tx, ty, tz] = threadgroup_; - auto tg_size = tx * ty * tz; + unsigned long tg_size = tx * ty * tz; auto max_tg_size = kernel->maxTotalThreadsPerThreadgroup(); if (tg_size > max_tg_size) { std::ostringstream msg; diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index e82d734a2e..5465603df6 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -127,6 +127,9 @@ std::pair load_swiftpm_library( } } } +#else + (void)device; + (void)lib_name; #endif return {nullptr, nullptr}; } @@ -713,7 +716,7 @@ MTL::LinkedFunctions* Device::get_linked_functions_( auto lfuncs = MTL::LinkedFunctions::linkedFunctions(); std::vector objs(funcs.size()); - for (int i = 0; i < funcs.size(); i++) { + for (int i = 0; i < std::ssize(funcs); i++) { objs[i] = funcs[i]; } diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index fefb7cdc0c..663e04cd8e 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -137,7 +137,7 @@ struct DeviceStream { // Data updated between command buffers MTL::CommandBuffer* buffer{nullptr}; int buffer_ops{0}; - size_t buffer_sizes{0}; + int64_t buffer_sizes{0}; // The command encoder, fence, and temporaries are updated between command // encoders diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp index 5abdf7309e..8068e84a48 100644 --- a/mlx/backend/metal/fence.cpp +++ b/mlx/backend/metal/fence.cpp @@ -76,7 +76,7 @@ void Fence::wait(Stream stream, const array& x) { auto command_buffer = d.get_command_buffer(idx); command_buffer->encodeWait(static_cast(f.fence), f.count); command_buffer->addCompletedHandler( - [fence_ = fence_](MTL::CommandBuffer* cbuf) {}); + [fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {}); return; } @@ -96,7 +96,7 @@ void Fence::wait(Stream stream, const array& x) { compute_encoder.dispatch_threads(kernel_dims, kernel_dims); d.get_command_buffer(idx)->addCompletedHandler( - [fence_ = fence_](MTL::CommandBuffer* cbuf) {}); + [fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {}); } void Fence::update(Stream stream, const array& x) { @@ -124,7 +124,7 @@ void Fence::update(Stream stream, const array& x) { command_buffer->encodeSignalEvent( static_cast(f.fence), f.count); command_buffer->addCompletedHandler( - [fence_ = fence_](MTL::CommandBuffer* cbuf) {}); + [fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {}); return; } @@ -154,7 +154,7 @@ void Fence::update(Stream stream, const array& x) { compute_encoder.dispatch_threads(kernel_dims, kernel_dims); d.get_command_buffer(idx)->addCompletedHandler( - [fence_ = fence_](MTL::CommandBuffer* cbuf) {}); + [fence_ = fence_](MTL::CommandBuffer* /* cbuf */) {}); } } // namespace mlx::core diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index d329a4685e..74165910cd 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -60,7 +60,7 @@ struct FourStepParams { void fft_op( const array& in, array& out, - size_t axis, + int64_t axis, bool inverse, bool real, const FourStepParams four_step_params, @@ -93,7 +93,7 @@ std::vector plan_stockham_fft(int n) { if (n == 1) { return plan; } - for (int i = 0; i < radices.size(); i++) { + for (int i = 0; i < std::ssize(radices); i++) { int radix = radices[i]; // Manually tuned radices for powers of 2 if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) { @@ -181,7 +181,7 @@ int compute_elems_per_thread(FFTPlan plan) { steps.insert(steps.end(), plan.stockham.begin(), plan.stockham.end()); steps.insert(steps.end(), plan.rader.begin(), plan.rader.end()); std::set used_radices; - for (int i = 0; i < steps.size(); i++) { + for (int i = 0; i < std::ssize(steps); i++) { int radix = radices[i % radices.size()]; if (steps[i] > 0) { used_radices.insert(radix); @@ -260,7 +260,7 @@ int primitive_root(int n) { std::tuple compute_raders_constants( int rader_n, - const Stream& s) { + const Stream& /* s */) { int proot = primitive_root(rader_n); // Fermat's little theorem int inv = mod_exp(proot, rader_n - 2, rader_n); @@ -508,7 +508,7 @@ void four_step_fft( void fft_op( const array& in, array& out, - size_t axis, + int64_t axis, bool inverse, bool real, const FourStepParams four_step_params, @@ -612,11 +612,11 @@ void fft_op( // Start of radix/rader step constants int index = 4; - for (int i = 0; i < plan.stockham.size(); i++) { + for (int i = 0; i < std::ssize(plan.stockham); i++) { func_consts.push_back(make_int(&plan.stockham[i], index)); index += 1; } - for (int i = 0; i < plan.rader.size(); i++) { + for (int i = 0; i < std::ssize(plan.rader); i++) { func_consts.push_back(make_int(&plan.rader[i], index)); index += 1; } @@ -771,8 +771,8 @@ void nd_fft_op( array temp1(temp_shape, complex64, nullptr, {}); array temp2(temp_shape, complex64, nullptr, {}); std::vector temp_arrs = {temp1, temp2}; - for (int i = axes.size() - 1; i >= 0; i--) { - int reverse_index = axes.size() - i - 1; + for (int i = std::ssize(axes) - 1; i >= 0; i--) { + int reverse_index = std::ssize(axes) - i - 1; // For 5D and above, we don't want to reallocate our two temporary arrays bool inplace = reverse_index >= 3 && i != 0; // Opposite order for fft vs ifft @@ -780,8 +780,8 @@ void nd_fft_op( size_t axis = axes[index]; // Mirror np.fft.(i)rfftn and perform a real transform // only on the final axis. - bool step_real = (real && index == axes.size() - 1); - const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2]; + bool step_real = (real && index == std::ssize(axes) - 1); + const array& in_arr = i == std::ssize(axes) - 1 ? in : temp_arrs[1 - i % 2]; array& out_arr = i == 0 ? out : temp_arrs[i % 2]; fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s); } diff --git a/mlx/backend/metal/hadamard.cpp b/mlx/backend/metal/hadamard.cpp index 65a8771513..bf115c630d 100644 --- a/mlx/backend/metal/hadamard.cpp +++ b/mlx/backend/metal/hadamard.cpp @@ -43,7 +43,7 @@ std::string gen_hadamard_codelet(int m) { while (end != std::string_view::npos) { source << " tmp[" << index << "] = "; auto row = matrix.substr(start, end - start); - for (int i = 0; i < row.length(); i++) { + for (int i = 0; i < std::ssize(row); i++) { source << " " << row[i] << " x[" << i << "]"; } source << ";" << std::endl; diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index 0250987573..8a215267ba 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -52,7 +52,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { auto& s = stream(); auto& d = metal::device(s.device); - size_t slice_size = 1; + int64_t slice_size = 1; for (auto s : slice_sizes_) { slice_size *= s; } @@ -94,8 +94,8 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { auto kernel = d.get_kernel(kernel_name, lib); compute_encoder.set_compute_pipeline_state(kernel); - size_t dim_x = (slice_size + work_per_thread - 1) / work_per_thread; - size_t dim_y = indices.size(); + int64_t dim_x = (slice_size + work_per_thread - 1) / work_per_thread; + int64_t dim_y = indices.size(); auto group_dims = get_block_dims(dim_x, dim_y, 1); MTL::Size grid_dims = MTL::Size(dim_x, dim_y, 1); @@ -110,7 +110,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { } int idx_ndim = nidx ? inputs[1].ndim() : 0; - size_t ndim = src.ndim(); + int64_t ndim = src.ndim(); std::string kernel_name = fmt::format( "gather{0}{1}_{2}_{3}_{4}", @@ -149,8 +149,8 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { // Launch 3D grid of threads // First two dimensions for the indices, the last one for the slice - size_t dim0 = 1; - size_t dim1 = 1; + int64_t dim0 = 1; + int64_t dim1 = 1; if (nidx) { if (inputs[1].ndim() >= 1) { dim0 = inputs[1].shape(0); @@ -159,13 +159,13 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { dim1 = inputs[1].size() / dim0; } } - size_t dim2 = slice_size; + int64_t dim2 = slice_size; auto group_dims = get_block_dims(dim0, dim1, dim2); MTL::Size grid_dims = MTL::Size(dim0, dim1, dim2); // Collect all idx shapes and strides into one place std::vector idx_shapes; - std::vector idx_strides; + std::vector idx_strides; std::vector idx_contigs; for (int i = 0; i < nidx; ++i) { idx_shapes.insert( @@ -246,7 +246,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { auto& d = metal::device(s.device); int idx_ndim = nidx ? inputs[1].ndim() : 0; - size_t idx_size = nidx ? inputs[1].size() : 1; + int64_t idx_size = nidx ? inputs[1].size() : 1; auto idx_to_out = idx_size / out.size(); int nwork; @@ -345,7 +345,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kernel_name, lib); - size_t nthreads = upd.size(); + int64_t nthreads = upd.size(); compute_encoder.set_compute_pipeline_state(kernel); @@ -354,8 +354,8 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_output_array(out, 2); // Set update info - size_t upd_ndim = upd.ndim(); - size_t upd_size = 1; + int64_t upd_ndim = upd.ndim(); + int64_t upd_size = 1; for (int i = idx_ndim; i < upd.ndim(); ++i) { upd_size *= upd.shape(i); } @@ -391,7 +391,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_bytes(upd_size, 6); // Set output info - size_t out_ndim = out.ndim(); + int64_t out_ndim = out.ndim(); if (out_ndim == 0) { // Need placeholders so Metal doesn't complain int shape_ = 0; @@ -448,7 +448,7 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { auto& s = stream(); auto& d = metal::device(s.device); - size_t ndim = src.ndim(); + int64_t ndim = src.ndim(); bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX; @@ -486,8 +486,8 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_compute_pipeline_state(kernel); // Grid [size post, index size, size pre] - size_t size_pre = 1; - size_t size_post = 1; + int64_t size_pre = 1; + int64_t size_post = 1; for (int i = 0; i < axis_; ++i) { size_pre *= idx.shape(i); } @@ -541,7 +541,7 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { auto& s = stream(); auto& d = metal::device(s.device); - size_t ndim = src.ndim(); + int64_t ndim = src.ndim(); bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX; @@ -602,8 +602,8 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_compute_pipeline_state(kernel); // Grid [size post, index size, size pre] - size_t size_pre = 1; - size_t size_post = 1; + int64_t size_pre = 1; + int64_t size_post = 1; for (int i = 0; i < axis_; ++i) { size_pre *= idx.shape(i); } diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index d6bee651d2..4a8f3d77dd 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -344,7 +344,7 @@ void steel_gemm_splitk_axpby( int M, int N, int K, - int batch_size_out, + int /* batch_size_out */, int lda, int ldb, bool transpose_a, diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index 109dd8df78..46d2cb5e05 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -179,8 +179,8 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel( metal::Device& d, const std::string& kernel_name, const array&, - const std::optional& mask_out, - const std::optional& mask_op, + const std::optional& /* mask_out */, + const std::optional& /* mask_op */, bool, bool, int, diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index da0160b249..4c277c52ea 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -134,7 +134,7 @@ void RMSNormVJP::eval_gpu( d.add_temporary(g, s.index); } - auto axis_size = static_cast(x.shape().back()); + auto axis_size = x.shape().back(); int n_rows = x.data_size() / axis_size; // Allocate the gradient accumulator gw and a temporary to store the @@ -246,7 +246,7 @@ void LayerNorm::eval_gpu( const array& w = inputs[1]; const array& b = inputs[2]; - auto axis_size = static_cast(x.shape().back()); + auto axis_size = x.shape().back(); int n_rows = x.data_size() / axis_size; int simd_size = 32; @@ -344,7 +344,7 @@ void LayerNormVJP::eval_gpu( d.add_temporary(g, s.index); } - auto axis_size = static_cast(x.shape().back()); + auto axis_size = x.shape().back(); int n_rows = x.data_size() / axis_size; // Allocate a temporary to store the gradients for w and allocate the output diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 5f6376c5e8..930114ecc5 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -152,7 +152,7 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { } } -void Load::eval_gpu(const std::vector& inputs, array& out) { +void Load::eval_gpu(const std::vector& /* inputs */, array& /* out */) { throw std::runtime_error("[Load::eval_gpu] Not implemented."); } @@ -201,41 +201,45 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { } void QRF::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { + const std::vector& /* inputs */, + std::vector& /* outputs */) { throw std::runtime_error("[QRF::eval_gpu] Metal QR factorization NYI."); } void SVD::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { + const std::vector& /* inputs */, + std::vector& /* outputs */) { throw std::runtime_error("[SVD::eval_gpu] Metal SVD NYI."); } -void Inverse::eval_gpu(const std::vector& inputs, array& output) { +void Inverse::eval_gpu( + const std::vector& /* inputs */, + array& /* output */) { throw std::runtime_error("[Inverse::eval_gpu] Metal inversion NYI."); } -void Cholesky::eval_gpu(const std::vector& inputs, array& out) { +void Cholesky::eval_gpu( + const std::vector& /* inputs */, + array& /* out */) { throw std::runtime_error( "[Cholesky::eval_gpu] Metal Cholesky decomposition NYI."); } void Eig::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { + const std::vector& /* inputs */, + std::vector& /* outputs */) { throw std::runtime_error("[Eig::eval_gpu] Metal Eig NYI."); } void Eigh::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { + const std::vector& /* inputs */, + std::vector& /* outputs */) { throw std::runtime_error("[Eigh::eval_gpu] Metal Eigh NYI."); } void LUF::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { + const std::vector& /* inputs */, + std::vector& /* outputs */) { throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI."); } diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index 504943d823..2443bf96c0 100644 --- a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -291,7 +291,7 @@ void init_reduce( const std::string& op_name, CommandEncoder& compute_encoder, metal::Device& d, - const Stream& s) { + const Stream& /* s */) { auto [_, out_type] = remap_reduce_types(out, op_name); const std::string func_name = "init_reduce"; std::string kname = func_name; @@ -397,7 +397,7 @@ void row_reduce_small( RowReduceArgs& args, CommandEncoder& compute_encoder, metal::Device& d, - const Stream& s) { + const Stream& /* s */) { // Set the kernel int n = get_kernel_reduce_ndim(args.reduce_ndim); auto [in_type, out_type] = remap_reduce_types(in, op_name); @@ -453,7 +453,7 @@ void row_reduce_simple( RowReduceArgs& args, CommandEncoder& compute_encoder, metal::Device& d, - const Stream& s) { + const Stream& /* s */) { // Set the kernel auto [in_type, out_type] = remap_reduce_types(in, op_name); const std::string func_name = "row_reduce_simple"; @@ -493,7 +493,7 @@ void row_reduce_looped( RowReduceArgs& args, CommandEncoder& compute_encoder, metal::Device& d, - const Stream& s) { + const Stream& /* s */) { auto [in_type, out_type] = remap_reduce_types(in, op_name); // Set the kernel @@ -570,7 +570,7 @@ void strided_reduce_small( ColReduceArgs& args, CommandEncoder& compute_encoder, metal::Device& d, - const Stream& s) { + const Stream& /* s */) { auto [in_type, out_type] = remap_reduce_types(in, op_name); // Figure out the grid dims @@ -747,7 +747,7 @@ void strided_reduce_looped( ColReduceArgs& args, CommandEncoder& compute_encoder, metal::Device& d, - const Stream& s) { + const Stream& /* s */) { auto [in_type, out_type] = remap_reduce_types(in, op_name); // Prepare the arguments for the kernel @@ -959,7 +959,7 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { // Continue with reduction operation // Minimum of 4 bytes since we use size 4 structs for all reduce // and metal will complain o/w - size_t min_bytes = std::max(out.nbytes(), 4ul); + size_t min_bytes = std::max(out.nbytes(), 4); out.set_data(allocator::malloc(min_bytes)); std::string op_name; switch (reduce_type_) { diff --git a/mlx/backend/metal/resident.cpp b/mlx/backend/metal/resident.cpp index 798824c2fb..cc7e6af082 100644 --- a/mlx/backend/metal/resident.cpp +++ b/mlx/backend/metal/resident.cpp @@ -80,7 +80,7 @@ void ResidencySet::resize(size_t size) { // Remove wired allocations until under capacity auto allocations = wired_set_->allAllocations(); auto num_allocations = wired_set_->allocationCount(); - for (int i = 0; i < num_allocations && current_size > size; ++i) { + for (size_t i = 0; i < num_allocations && current_size > size; ++i) { auto buf = static_cast(allocations->object(i)); wired_set_->removeAllocation(buf); current_size -= buf->allocatedSize(); diff --git a/mlx/backend/metal/slicing.cpp b/mlx/backend/metal/slicing.cpp index 1e14c35c8a..97087c2562 100644 --- a/mlx/backend/metal/slicing.cpp +++ b/mlx/backend/metal/slicing.cpp @@ -33,7 +33,7 @@ void concatenate_gpu( auto& d = metal::device(s.device); auto& compute_encoder = d.get_command_encoder(s.index); auto concurrent_ctx = compute_encoder.start_concurrent(); - for (int i = 0; i < inputs.size(); i++) { + for (int i = 0; i < std::ssize(inputs); i++) { array out_slice(inputs[i].shape(), out.dtype(), nullptr, {}); size_t data_offset = strides[axis] * sizes[i]; out_slice.copy_shared_buffer( diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index e7784e5997..74d2fc2442 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -29,6 +29,10 @@ inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) { std::ostringstream label; label << "Stream " << index; queue->setLabel(make_string(label)); +#else + // appease warnings + (void)queue; + (void)index; #endif } @@ -42,6 +46,9 @@ inline void debug_set_primitive_buffer_label( } label << primitive.name(); command_buffer->setLabel(make_string(label)); +#else + (void)command_buffer; + (void)primitive; #endif } From 6343622c6783a6167744941f6a3d49b3afdf93ff Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Fri, 31 Oct 2025 11:46:36 -0700 Subject: [PATCH 17/30] fix small vector indexing checks --- mlx/small_vector.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlx/small_vector.h b/mlx/small_vector.h index cf3467cbc9..e0a0a76b40 100644 --- a/mlx/small_vector.h +++ b/mlx/small_vector.h @@ -302,7 +302,7 @@ class SmallVector { } T& at(int index) { - if (index >= size()) { + if (index < 0 || index >= size()) { throw std::out_of_range("SmallVector out of range."); } return begin_[index]; @@ -312,7 +312,7 @@ class SmallVector { } T& operator[](int index) { - assert(size() > index); + assert(index >= 0 && size() > index); return begin_[index]; } const T& operator[](int index) const { From 8d10f3ec75109f4dcbc2b6f432929ac6001e6d75 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Fri, 31 Oct 2025 11:47:03 -0700 Subject: [PATCH 18/30] WIP (metal) --- mlx/backend/metal/scan.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp index 0b636c1624..0907be2a8b 100644 --- a/mlx/backend/metal/scan.cpp +++ b/mlx/backend/metal/scan.cpp @@ -76,7 +76,7 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(out, 1); - size_t size = in.shape(axis_); + int64_t size = in.shape(axis_); compute_encoder.set_bytes(size, 2); // Compute the thread grid From b0d985416aed3de70bc9fa63eaf08b6f9ee029eb Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Fri, 31 Oct 2025 13:13:15 -0700 Subject: [PATCH 19/30] fix arg_reduce --- mlx/backend/cpu/arg_reduce.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlx/backend/cpu/arg_reduce.cpp b/mlx/backend/cpu/arg_reduce.cpp index 41ab9fb609..3f42e71833 100644 --- a/mlx/backend/cpu/arg_reduce.cpp +++ b/mlx/backend/cpu/arg_reduce.cpp @@ -17,12 +17,12 @@ void arg_reduce(const array& in, array& out, const OpT& op, int axis) { Strides strides = remove_index(in.strides(), axis); Shape shape = remove_index(in.shape(), axis); auto in_ptr = in.data(); - auto out_ptr = out.data(); + auto out_ptr = out.data(); for (int64_t i = 0; i < out.size(); ++i) { auto loc = elem_to_loc(i, shape, strides); auto local_in_ptr = in_ptr + loc; - int64_t ind_v = 0; + uint32_t ind_v = 0; InT v = (*local_in_ptr); for (int64_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) { op(j, (*local_in_ptr), &ind_v, &v); From 8277e71ea98a1b4429ac988091179e20141a3177 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Fri, 31 Oct 2025 13:19:54 -0700 Subject: [PATCH 20/30] WIP (gpu) --- mlx/backend/gpu/primitives.cpp | 2 +- mlx/backend/gpu/slicing.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp index ee40799df6..837c898fc0 100644 --- a/mlx/backend/gpu/primitives.cpp +++ b/mlx/backend/gpu/primitives.cpp @@ -51,7 +51,7 @@ void Contiguous::eval_gpu(const std::vector& inputs, array& out) { MLX_PROFILER_RANGE("Contiguous::eval_gpu"); assert(inputs.size() == 1); auto& in = inputs[0]; - constexpr size_t extra_bytes = 16384; + constexpr int64_t extra_bytes = 16384; if (in.buffer_size() <= out.nbytes() + extra_bytes && (in.flags().row_contiguous || (allow_col_major_ && in.flags().col_contiguous))) { diff --git a/mlx/backend/gpu/slicing.cpp b/mlx/backend/gpu/slicing.cpp index fde2a01cd6..76c8b843c1 100644 --- a/mlx/backend/gpu/slicing.cpp +++ b/mlx/backend/gpu/slicing.cpp @@ -11,7 +11,7 @@ void slice_gpu( array& out, const Shape& start_indices, const Shape& strides, - const Stream& s) { + const Stream& /* s */) { slice(in, out, start_indices, strides); } @@ -27,7 +27,7 @@ void pad_gpu( // Find offset for start of input values size_t data_offset = 0; - for (int i = 0; i < axes.size(); i++) { + for (int i = 0; i < std::ssize(axes); i++) { auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i]; data_offset += out.strides()[ax] * low_pad_size[i]; } From b48d29820513bd247ab615c9d20cfde232264210 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Fri, 31 Oct 2025 13:20:09 -0700 Subject: [PATCH 21/30] WIP (distributed) --- mlx/distributed/distributed.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlx/distributed/distributed.cpp b/mlx/distributed/distributed.cpp index d71ebb9b12..9c33837d01 100644 --- a/mlx/distributed/distributed.cpp +++ b/mlx/distributed/distributed.cpp @@ -55,7 +55,8 @@ class EmptyGroup : public GroupImpl { return 1; } - std::shared_ptr split(int color, int key = -1) override { + std::shared_ptr split(int /* color */, int /* key */ = -1) + override { throw std::runtime_error("Cannot split the distributed group further."); } From 4a1b1796b74811bc84a44809556416405ae53857 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Fri, 31 Oct 2025 13:20:47 -0700 Subject: [PATCH 22/30] WIP (io) --- mlx/io/gguf.cpp | 4 ++-- mlx/io/gguf_quants.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mlx/io/gguf.cpp b/mlx/io/gguf.cpp index 206f6fb31f..096be49b05 100644 --- a/mlx/io/gguf.cpp +++ b/mlx/io/gguf.cpp @@ -238,7 +238,7 @@ std::unordered_map load_arrays(gguf_ctx* ctx) { return array_map; } -GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) { +GGUFLoad load_gguf(const std::string& file, StreamOrDevice /* s */) { bool exists; { std::ifstream f(file.c_str()); @@ -440,7 +440,7 @@ void save_gguf( } const char* tensorname = key.c_str(); const uint64_t namelen = key.length(); - const uint32_t num_dim = arr.ndim(); + const int num_dim = arr.ndim(); std::vector dim(num_dim); for (int i = 0; i < num_dim; i++) { dim[i] = arr.shape()[num_dim - 1 - i]; diff --git a/mlx/io/gguf_quants.cpp b/mlx/io/gguf_quants.cpp index 148ed6c479..a05c7447ea 100644 --- a/mlx/io/gguf_quants.cpp +++ b/mlx/io/gguf_quants.cpp @@ -77,8 +77,8 @@ void extract_q8_0_data( array& weights_arr, array& scales_arr, array& biases_arr) { - const uint64_t weights_per_block = 32; - const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights + const int64_t weights_per_block = 32; + const int64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights auto data = static_cast(tensor.weights_data); auto weights = weights_arr.data(); auto scales = scales_arr.data(); From 19ab7911f61dbe427f1273731cccad5dce6ad629 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Fri, 31 Oct 2025 13:32:43 -0700 Subject: [PATCH 23/30] WIP (cuda) --- mlx/backend/cuda/no_cuda.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlx/backend/cuda/no_cuda.cpp b/mlx/backend/cuda/no_cuda.cpp index 175a505b4b..d81b4ebaa6 100644 --- a/mlx/backend/cuda/no_cuda.cpp +++ b/mlx/backend/cuda/no_cuda.cpp @@ -35,9 +35,9 @@ std::vector precompiled_cuda_kernel( const std::vector&, std::tuple, std::tuple, - int shared_memory, - std::optional init_value, - bool ensure_row_contiguous, + int /* shared_memory */, + std::optional /* init_value */, + bool /* ensure_row_contiguous */, StreamOrDevice) { throw std::runtime_error("[cuda_kernel] No CUDA back-end."); } From c5913131cfe9dde6778e80ad31cb0c93479a60c8 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Fri, 31 Oct 2025 13:32:56 -0700 Subject: [PATCH 24/30] WIP (distributed) --- mlx/distributed/mpi/mpi.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp index 494fb02dcc..ecdf5c7935 100644 --- a/mlx/distributed/mpi/mpi.cpp +++ b/mlx/distributed/mpi/mpi.cpp @@ -36,7 +36,7 @@ void simple_sum( void* input, void* accumulator, int* len, - MPI_Datatype* datatype) { + MPI_Datatype* /* datatype */) { T* in = (T*)input; T* acc = (T*)accumulator; int N = *len; @@ -55,7 +55,7 @@ void simple_max( void* input, void* accumulator, int* len, - MPI_Datatype* datatype) { + MPI_Datatype* /* datatype */) { T* in = (T*)input; T* acc = (T*)accumulator; int N = *len; @@ -75,7 +75,7 @@ void simple_min( void* input, void* accumulator, int* len, - MPI_Datatype* datatype) { + MPI_Datatype* /* datatype */) { T* in = (T*)input; T* acc = (T*)accumulator; int N = *len; From 7107802e091745a39bad71a4e199f174a35bc1a6 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Fri, 31 Oct 2025 16:23:51 -0700 Subject: [PATCH 25/30] WIP (examples) --- examples/cpp/tutorial.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/cpp/tutorial.cpp b/examples/cpp/tutorial.cpp index ae2cd3cfbe..f10a4bb5d3 100644 --- a/examples/cpp/tutorial.cpp +++ b/examples/cpp/tutorial.cpp @@ -14,14 +14,17 @@ void array_basics() { // Get the value out of it: auto s = x.item(); assert(s == 1.0); + (void)s; // Scalars have a size of 1: - size_t size = x.size(); + int64_t size = x.size(); assert(size == 1); + (void)size; // Scalars have 0 dimensions: int ndim = x.ndim(); assert(ndim == 0); + (void)ndim; // The shape should be an empty vector: auto shape = x.shape(); @@ -30,6 +33,7 @@ void array_basics() { // The datatype should be float32: auto dtype = x.dtype(); assert(dtype == mx::float32); + (void)dtype; // Specify the dtype when constructing the array: x = mx::array(1, mx::int32); From ac75c87fd75839cb3520412ba6274a1d16b3f9df Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Fri, 31 Oct 2025 16:24:09 -0700 Subject: [PATCH 26/30] WIP (cpu) --- mlx/backend/cpu/distributed.cpp | 1 + mlx/backend/cpu/primitives.cpp | 1 + 2 files changed, 2 insertions(+) diff --git a/mlx/backend/cpu/distributed.cpp b/mlx/backend/cpu/distributed.cpp index d641d581ba..8792eb1550 100644 --- a/mlx/backend/cpu/distributed.cpp +++ b/mlx/backend/cpu/distributed.cpp @@ -90,6 +90,7 @@ void Recv::eval_cpu( std::vector& outputs) { assert(inputs.size() == 0); assert(outputs.size() == 1); + (void)inputs; outputs[0].set_data(allocator::malloc(outputs[0].nbytes())); distributed::detail::recv(group(), outputs[0], src_, stream()); diff --git a/mlx/backend/cpu/primitives.cpp b/mlx/backend/cpu/primitives.cpp index 18db2b3ddd..889a11ac3d 100644 --- a/mlx/backend/cpu/primitives.cpp +++ b/mlx/backend/cpu/primitives.cpp @@ -124,6 +124,7 @@ void Transpose::eval_cpu(const std::vector& inputs, array& out) { void Arange::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 0); + (void)inputs; out.set_data(allocator::malloc(out.nbytes())); switch (out.dtype()) { case bool_: From 8d13a0bc6ba0b4508a612aaab3b63469c62d5a08 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Fri, 31 Oct 2025 16:24:21 -0700 Subject: [PATCH 27/30] WIP (metal) --- mlx/backend/metal/primitives.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 930114ecc5..77058ee3ca 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -26,6 +26,7 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) { void Arange::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 0); + (void)inputs; out.set_data(allocator::malloc(out.nbytes())); if (out.size() == 0) { return; From 18aa9213886b0e4ed5d87a394fad0dbbe839d4e1 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Fri, 31 Oct 2025 16:24:35 -0700 Subject: [PATCH 28/30] WIP --- mlx/fast.cpp | 2 ++ mlx/primitives.cpp | 67 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/mlx/fast.cpp b/mlx/fast.cpp index e88527a8e9..9b6a70d386 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -127,6 +127,7 @@ std::vector RMSNorm::vjp( assert(primals.size() == 2); assert(outputs.size() == 1); assert(cotangents.size() == 1); + (void)outputs; auto s = stream(); auto fallback = [eps = eps_, s](const std::vector& inputs) { @@ -269,6 +270,7 @@ std::vector LayerNorm::vjp( assert(primals.size() == 3); assert(outputs.size() == 1); assert(cotangents.size() == 1); + (void)outputs; auto s = stream(); auto fallback = [eps = eps_, s](const std::vector& inputs) { diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 0faa9f4074..1c1ecdb48b 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -230,6 +230,7 @@ std::vector Abs::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {multiply(tangents[0], sign(primals[0], stream()), stream())}; } @@ -383,6 +384,7 @@ std::vector ArcCos::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; array one = array(1., primals[0].dtype()); array t = subtract(one, square(primals[0], stream()), stream()); array denom = negative(rsqrt(t, stream()), stream()); @@ -411,6 +413,7 @@ std::vector ArcCosh::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; array one = array(1., primals[0].dtype()); array t = subtract(square(primals[0], stream()), one, stream()); return {multiply(tangents[0], rsqrt(t, stream()), stream())}; @@ -438,6 +441,7 @@ std::vector ArcSin::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; array one = array(1., primals[0].dtype()); array t = subtract(one, square(primals[0], stream()), stream()); return {multiply(tangents[0], rsqrt(t, stream()), stream())}; @@ -465,6 +469,7 @@ std::vector ArcSinh::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; array one = array(1., primals[0].dtype()); array t = add(square(primals[0], stream()), one, stream()); return {multiply(tangents[0], rsqrt(t, stream()), stream())}; @@ -492,6 +497,7 @@ std::vector ArcTan::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; array one = array(1., primals[0].dtype()); array t = add(one, square(primals[0], stream()), stream()); return {divide(tangents[0], t, stream())}; @@ -539,6 +545,7 @@ std::vector ArcTan2::jvp( const std::vector& argnums) { assert(primals.size() == 2); assert(argnums.size() == 2); + (void)argnums; const auto& s = stream(); const array& x1 = primals[0]; @@ -575,6 +582,7 @@ std::vector ArcTanh::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; array one = array(1., primals[0].dtype()); array t = subtract(one, square(primals[0], stream()), stream()); return {divide(tangents[0], t, stream())}; @@ -725,6 +733,7 @@ std::vector AsStrided::vjp( const std::vector& argnums, const std::vector&) { assert(argnums.size() == 1); + (void)argnums; // Extract the sizes and cast them to ints int grad_size = primals[0].size(); @@ -754,6 +763,7 @@ std::vector AsStrided::jvp( const std::vector& tangents, const std::vector& /* argnums */) { assert(primals.size() == 1); + (void)primals; return {as_strided(tangents[0], shape_, strides_, offset_, stream())}; } @@ -787,6 +797,7 @@ std::vector BitwiseBinary::jvp( const std::vector& tangents, const std::vector& argnums) { assert(primals.size() == 2); + (void)primals; std::vector vjps = {zeros_like(tangents[0], stream())}; if (argnums.size() > 1) { vjps.push_back(vjps.back()); @@ -942,6 +953,7 @@ std::vector Ceil::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {zeros_like(primals[0], stream())}; } @@ -1581,6 +1593,8 @@ std::vector Copy::vjp( const std::vector&) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)primals; + (void)argnums; return cotangents; } @@ -1590,6 +1604,8 @@ std::vector Copy::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)primals; + (void)argnums; return tangents; } @@ -1615,6 +1631,7 @@ std::vector Cos::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {multiply( tangents[0], negative(sin(primals[0], stream()), stream()), stream())}; } @@ -1641,6 +1658,7 @@ std::vector Cosh::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {multiply(tangents[0], sinh(primals[0], stream()), stream())}; } @@ -1881,6 +1899,7 @@ std::vector Erf::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; auto dtype = primals[0].dtype(); auto scale = multiply(array(M_2_SQRTPI, dtype), tangents[0], stream()); return {multiply( @@ -1915,6 +1934,7 @@ std::vector ErfInv::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; auto dtype = primals[0].dtype(); auto scale = multiply(array(1.0 / M_2_SQRTPI, dtype), tangents[0], stream()); return {multiply( @@ -1945,6 +1965,7 @@ std::vector Exp::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {multiply(tangents[0], exp(primals[0], stream()), stream())}; } @@ -1973,6 +1994,7 @@ std::vector Expm1::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {multiply(tangents[0], exp(primals[0], stream()), stream())}; } @@ -2181,6 +2203,7 @@ std::vector FFT::vjp( const std::vector&) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; auto& in = primals[0]; std::vector axes(axes_.begin(), axes_.end()); @@ -2260,6 +2283,8 @@ std::vector FFT::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)primals; + (void)argnums; auto& tan = tangents[0]; if (real_ & inverse_) { return {fft::irfftn(tan, stream())}; @@ -2286,6 +2311,7 @@ std::vector Floor::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {zeros_like(primals[0], stream())}; } @@ -2304,6 +2330,7 @@ std::vector Full::vjp( const std::vector&) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {multiply(cotangents[0], primals[0], stream())}; } @@ -2313,6 +2340,8 @@ std::vector Full::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)primals; + (void)argnums; return tangents; } @@ -2568,6 +2597,7 @@ std::vector Imag::vjp( const std::vector&) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {multiply( array(complex64_t{0.0f, 1.0f}, primals[0].dtype()), cotangents[0], @@ -2580,6 +2610,8 @@ std::vector Imag::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)primals; + (void)argnums; return {imag(tangents[0], stream())}; } @@ -2659,6 +2691,7 @@ std::vector Log::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; auto out = divide(tangents[0], primals[0], stream()); if (base_ != Base::e) { auto scale = 1 / std::log(base_ == Base::ten ? 10.0f : 2.0f); @@ -2696,6 +2729,7 @@ std::vector Log1p::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; auto dtype = primals[0].dtype(); return {divide( tangents[0], add(array(1.0f, dtype), primals[0], stream()), stream())}; @@ -2723,6 +2757,8 @@ std::vector LogicalNot::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)primals; + (void)argnums; return {zeros_like(tangents[0], stream())}; } @@ -2740,6 +2776,7 @@ std::vector LogicalAnd::vjp( const std::vector& argnums, const std::vector&) { assert(primals.size() == 2); + (void)primals; std::vector vjps = {zeros_like(cotangents[0], stream())}; if (argnums.size() > 1) { vjps.push_back(vjps.back()); @@ -2753,6 +2790,7 @@ std::vector LogicalAnd::jvp( const std::vector& argnums) { assert(primals.size() == 2); assert(argnums.size() <= 2); + (void)argnums; return {zeros_like(primals[0], stream())}; } @@ -2772,6 +2810,7 @@ std::vector LogicalOr::vjp( const std::vector& argnums, const std::vector&) { assert(primals.size() == 2); + (void)primals; std::vector vjps = {zeros_like(cotangents[0], stream())}; if (argnums.size() > 1) { vjps.push_back(vjps.back()); @@ -2785,6 +2824,7 @@ std::vector LogicalOr::jvp( const std::vector& argnums) { assert(primals.size() == 2); assert(argnums.size() <= 2); + (void)argnums; return {zeros_like(primals[0], stream())}; } @@ -3154,6 +3194,8 @@ std::vector Negative::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)primals; + (void)argnums; return {negative(tangents[0], stream())}; } @@ -3198,6 +3240,7 @@ std::vector Pad::vjp( const std::vector& argnums, const std::vector&) { assert(argnums.size() == 1 && argnums[0] == 0); + (void)argnums; auto& cotan = cotangents[0]; Shape start(cotan.ndim(), 0); @@ -3218,6 +3261,7 @@ std::vector Pad::jvp( const std::vector& tangents, const std::vector& argnums) { assert(argnums.size() == 1 && argnums[0] == 0); + (void)argnums; return { pad(tangents[0], @@ -3639,6 +3683,7 @@ std::vector Real::vjp( const std::vector&) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {astype(cotangents[0], primals[0].dtype(), stream())}; } @@ -3648,6 +3693,8 @@ std::vector Real::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)primals; + (void)argnums; return {real(tangents[0], stream())}; } @@ -3688,6 +3735,7 @@ std::vector Reshape::vjp( assert(primals.size() == 1); assert(argnums.size() == 1); assert(argnums[0] == 0); + (void)argnums; return {reshape(cotangents[0], primals[0].shape(), stream())}; } @@ -3698,6 +3746,8 @@ std::vector Reshape::jvp( assert(primals.size() == 1); assert(argnums.size() == 1); assert(argnums[0] == 0); + (void)primals; + (void)argnums; return {reshape(tangents[0], shape_, stream())}; } @@ -3891,6 +3941,7 @@ std::vector Round::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {zeros_like(primals[0], stream())}; } @@ -3926,6 +3977,7 @@ std::vector Scan::vjp( const std::vector& outputs) { assert(primals.size() == 1); assert(argnums[0] == 0); + (void)argnums; if (reduce_type_ == Scan::Sum) { return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())}; @@ -4027,6 +4079,7 @@ std::vector Scan::jvp( const std::vector& argnums) { assert(tangents.size() == 1); assert(argnums[0] == 0); + (void)argnums; if (reduce_type_ == Scan::Sum) { return {cumsum(tangents[0], axis_, reverse_, inclusive_, stream())}; @@ -4346,6 +4399,7 @@ std::vector Sigmoid::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; auto s = sigmoid(primals[0], stream()); auto sprime = multiply(s, subtract(array(1.0f, s.dtype()), s, stream()), stream()); @@ -4374,6 +4428,7 @@ std::vector Sign::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {zeros(primals[0].shape(), primals[0].dtype(), stream())}; } @@ -4399,6 +4454,7 @@ std::vector Sin::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {multiply(tangents[0], cos(primals[0], stream()), stream())}; } @@ -4424,6 +4480,7 @@ std::vector Sinh::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {multiply(tangents[0], cosh(primals[0], stream()), stream())}; } @@ -4469,6 +4526,7 @@ std::vector Slice::jvp( const std::vector& /* argnums */) { // Check inputs assert(primals.size() == 1); + (void)primals; return {slice(tangents[0], start_indices_, end_indices_, strides_, stream())}; } @@ -4566,6 +4624,7 @@ std::vector SliceUpdate::jvp( const std::vector& /* argnums */) { // Check inputs assert(primals.size() == 2); + (void)primals; return {slice_update( tangents[0], tangents[1], @@ -4748,6 +4807,7 @@ std::vector Softmax::vjp( const std::vector& outputs) { assert(primals.size() == 1); assert(cotangents.size() == 1); + (void)primals; auto& s = outputs[0]; auto sv = multiply(s, cotangents[0], stream()); return {subtract( @@ -5022,6 +5082,7 @@ std::vector Tan::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; array cos_sq = square(cos(primals[0], stream()), stream()); return {divide(tangents[0], cos_sq, stream())}; } @@ -5048,6 +5109,7 @@ std::vector Tanh::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; array cosh_sq = square(cosh(primals[0], stream()), stream()); return {divide(tangents[0], cosh_sq, stream())}; } @@ -5409,6 +5471,8 @@ std::vector Transpose::vjp( const std::vector&) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)primals; + (void)argnums; std::vector iaxes(axes_.size()); for (int i = 0; i < std::ssize(axes_); ++i) { iaxes[axes_[i]] = i; @@ -5422,6 +5486,7 @@ std::vector Transpose::jvp( const std::vector& /* argnums */) { assert(primals.size() == 1); assert(tangents.size() == 1); + (void)primals; return {transpose(tangents[0], axes_, stream())}; } @@ -5556,6 +5621,8 @@ std::vector Hadamard::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)primals; + (void)argnums; return {hadamard_transform(tangents[0], scale_, stream())}; } From 9f649b565851a407f739ff506fcacf860adfd340 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Fri, 31 Oct 2025 16:24:51 -0700 Subject: [PATCH 29/30] WIP (python) --- python/src/array.cpp | 6 +++--- python/src/convert.cpp | 4 ++-- python/src/fast.cpp | 1 - python/src/indexing.cpp | 25 +++++++++++++------------ python/src/load.cpp | 2 +- python/src/metal.cpp | 4 +++- python/src/mlx_func.cpp | 2 +- python/src/stream.cpp | 6 +++--- python/src/transforms.cpp | 25 ++++++++++++------------- python/src/trees.cpp | 27 ++++++++++++++------------- python/src/utils.cpp | 4 ++-- 11 files changed, 54 insertions(+), 52 deletions(-) diff --git a/python/src/array.cpp b/python/src/array.cpp index 231474c2d9..53287fae0a 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -83,7 +83,7 @@ class ArrayPythonIterator { throw nb::stop_iteration(); } - if (idx_ >= 0 && idx_ < splits_.size()) { + if (idx_ >= 0 && idx_ < std::ssize(splits_)) { return mx::squeeze(splits_[idx_++], 0); } @@ -390,7 +390,7 @@ void init_array(nb::module_& m) { )pbdoc") .def( "__array_namespace__", - [](const mx::array& a, + [](const mx::array& /* a */, const std::optional& api_version) { if (api_version) { throw std::invalid_argument( @@ -501,7 +501,7 @@ void init_array(nb::module_& m) { .def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); }) .def( "__dlpack_device__", - [](const mx::array& a) { + [](const mx::array& /* a */) { // See // https://github.com/dmlc/dlpack/blob/5c210da409e7f1e51ddf445134a4376fdbd70d7d/include/dlpack/dlpack.h#L74 if (mx::metal::is_available()) { diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 88da37103e..c36f42115a 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -50,7 +50,7 @@ mx::array nd_array_to_mlx( // Compute the shape and size mx::Shape shape; shape.reserve(nd_array.ndim()); - for (int i = 0; i < nd_array.ndim(); i++) { + for (int i = 0; i < static_cast(nd_array.ndim()); i++) { shape.push_back(check_shape_dim(nd_array.shape(i))); } auto type = nd_array.dtype(); @@ -289,7 +289,7 @@ PyScalarT validate_shape( throw std::invalid_argument("Initialization encountered extra dimension."); } auto s = shape[idx]; - if (nb::len(list) != s) { + if (nb::len(list) != static_cast(s)) { throw std::invalid_argument( "Initialization encountered non-uniform length."); } diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 0ed1aa6984..c270b72464 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -201,7 +201,6 @@ void init_fast(nb::module_& parent_module) { bool has_mask = !std::holds_alternative(mask); bool has_str_mask = has_mask && std::holds_alternative(mask); - bool has_arr_mask = has_mask && std::holds_alternative(mask); if (has_mask) { if (has_str_mask) { diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index 95c6e2a185..d425b9e1e0 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -115,7 +115,7 @@ mx::array mlx_gather_nd( std::vector is_slice(indices.size(), false); int num_slices = 0; // gather all the arrays - for (int i = 0; i < indices.size(); i++) { + for (int i = 0; i < std::ssize(indices); i++) { auto& idx = indices[i]; if (nb::isinstance(idx)) { @@ -142,7 +142,7 @@ mx::array mlx_gather_nd( // reshape them so that the int/array indices are first if (gather_first) { int slice_index = 0; - for (int i = 0; i < gather_indices.size(); i++) { + for (int i = 0; i < std::ssize(gather_indices); i++) { if (is_slice[i]) { mx::Shape index_shape(max_dims + num_slices, 1); index_shape[max_dims + slice_index] = gather_indices[i].shape(0); @@ -156,7 +156,7 @@ mx::array mlx_gather_nd( } } else { // reshape them so that the int/array indices are last - for (int i = 0; i < gather_indices.size(); i++) { + for (int i = 0; i < std::ssize(gather_indices); i++) { if (i < num_slices) { mx::Shape index_shape(max_dims + num_slices, 1); index_shape[i] = gather_indices[i].shape(0); @@ -190,7 +190,7 @@ auto mlx_expand_ellipsis(const mx::Shape& shape, const nb::tuple& entries) { bool has_ellipsis = false; // Start from dimension 0 till we hit an ellipsis - for (; i < entries.size(); i++) { + for (; i < std::ssize(entries); i++) { auto idx = entries[i]; if (!is_valid_index_type(idx)) { throw std::invalid_argument( @@ -301,7 +301,8 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) { if (have_array) { int last_array; // Then find the last array - for (last_array = indices.size() - 1; last_array >= 0; last_array--) { + for (last_array = std::ssize(indices) - 1; last_array >= 0; + last_array--) { auto& idx = indices[last_array]; if (nb::isinstance(idx) || nb::isinstance(idx)) { break; @@ -333,11 +334,11 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) { nb::slice(nb::none(), nb::none(), nb::none())); } } - for (int i = last_array + 1; i < indices.size(); i++) { + for (int i = last_array + 1; i < std::ssize(indices); i++) { remaining_indices.push_back(indices[i]); } } else { - for (int i = 0; i < indices.size(); i++) { + for (int i = 0; i < std::ssize(indices); i++) { auto& idx = indices[i]; if (nb::isinstance(idx) || nb::isinstance(idx)) { break; @@ -352,7 +353,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) { remaining_indices.push_back( nb::slice(nb::none(), nb::none(), nb::none())); } - for (int i = last_array + 1; i < indices.size(); i++) { + for (int i = last_array + 1; i < std::ssize(indices); i++) { remaining_indices.push_back(indices[i]); } } @@ -406,7 +407,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) { if (unsqueeze_needed || squeeze_needed) { std::vector squeeze_axes; std::vector unsqueeze_axes; - for (int axis = 0; axis < remaining_indices.size(); ++axis) { + for (int axis = 0; axis < std::ssize(remaining_indices); ++axis) { auto& idx = remaining_indices[axis]; if (unsqueeze_needed && idx.is_none()) { unsqueeze_axes.push_back(axis - squeeze_axes.size()); @@ -583,7 +584,7 @@ mlx_scatter_args_nd( } // Analyse the types of the indices - size_t max_dim = 0; + int max_dim = 0; bool arrays_first = false; int num_none = 0; int num_slices = 0; @@ -640,7 +641,7 @@ mlx_scatter_args_nd( std::vector update_shape(non_none_indices, 1); std::vector slice_shapes; - for (int i = 0; i < indices.size(); ++i) { + for (int i = 0; i < std::ssize(indices); ++i) { auto& pyidx = indices[i]; if (nb::isinstance(pyidx)) { mx::ShapeElem start, end, stride; @@ -848,7 +849,7 @@ auto mlx_slice_update( int unspecified = src.ndim() - non_none_indices; std::vector squeeze_dims; std::vector expand_dims; - for (int i = indices.size() - 1, + for (int i = std::ssize(indices) - 1, ax = non_none_indices - 1, upd_ax = upd.ndim() - unspecified - 1; i >= 0; diff --git a/python/src/load.cpp b/python/src/load.cpp index f98307f192..37a674403a 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -436,7 +436,7 @@ void mlx_savez_helper( nb::cast>(kwargs); auto arrays_list = nb::cast>(args); - for (int i = 0; i < arrays_list.size(); i++) { + for (int i = 0; i < std::ssize(arrays_list); i++) { std::string arr_name = "arr_" + std::to_string(i); if (arrays_dict.count(arr_name) > 0) { diff --git a/python/src/metal.cpp b/python/src/metal.cpp index a56674428b..4086c3e28f 100644 --- a/python/src/metal.cpp +++ b/python/src/metal.cpp @@ -22,7 +22,9 @@ bool DEPRECATE(const char* old_fn, const char* new_fn) { return true; } -#define DEPRECATE(oldfn, newfn) static bool dep = DEPRECATE(oldfn, newfn) +#define DEPRECATE(oldfn, newfn) \ + static bool dep = DEPRECATE(oldfn, newfn); \ + (void)dep; void init_metal(nb::module_& m) { nb::module_ metal = m.def_submodule("metal", "mlx.metal"); diff --git a/python/src/mlx_func.cpp b/python/src/mlx_func.cpp index 87912fdda3..e127a89df1 100644 --- a/python/src/mlx_func.cpp +++ b/python/src/mlx_func.cpp @@ -107,7 +107,7 @@ nb::callable mlx_func( return nb::steal((PyObject*)r); } -void init_mlx_func(nb::module_& m) { +void init_mlx_func(nb::module_& /* m */) { gc_func_tp = (PyTypeObject*)PyType_FromSpec(&gc_func_spec); if (!gc_func_tp) { nb::raise("Could not register MLX function type."); diff --git a/python/src/stream.cpp b/python/src/stream.cpp index e10f4751c5..033e466604 100644 --- a/python/src/stream.cpp +++ b/python/src/stream.cpp @@ -100,9 +100,9 @@ void init_stream(nb::module_& m) { .def( "__exit__", [](PyStreamContext& scm, - const std::optional& exc_type, - const std::optional& exc_value, - const std::optional& traceback) { scm.exit(); }, + const std::optional& /* exc_type */, + const std::optional& /* exc_value */, + const std::optional& /* traceback */) { scm.exit(); }, "exc_type"_a = nb::none(), "exc_value"_a = nb::none(), "traceback"_a = nb::none()); diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 12aa641f89..4711ef0582 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -86,7 +86,7 @@ auto py_value_and_grad( << argnums[0]; throw std::invalid_argument(msg.str()); } - for (int i = 1; i < argnums.size(); ++i) { + for (int i = 1; i < std::ssize(argnums); ++i) { if (argnums[i] == argnums[i - 1]) { std::ostringstream msg; msg << error_msg_tag << " Duplicate argument index " << argnums[0] @@ -99,7 +99,7 @@ auto py_value_and_grad( return [fun, argnums, argnames, error_msg_tag, scalar_func_only]( nb::args& args, nb::kwargs& kwargs) { // Sanitize the input - if (argnums.size() > 0 && argnums.back() >= args.size()) { + if (argnums.size() > 0 && argnums.back() >= std::ssize(args)) { std::ostringstream msg; msg << error_msg_tag << " Can't compute the gradient of argument index " << argnums.back() << " because the function is called with only " @@ -126,8 +126,8 @@ auto py_value_and_grad( std::vector arrays; std::vector counts(1, 0); std::vector gradient_indices; - for (int i = 0, j = 0; i < args.size(); ++i) { - bool needs_grad = (j < argnums.size() && argnums[j] == i); + for (int i = 0, j = 0; i < std::ssize(args); ++i) { + bool needs_grad = (j < std::ssize(argnums) && argnums[j] == i); auto argsi = tree_flatten(args[i], /* strict = */ needs_grad); if (needs_grad) { auto old_size = gradient_indices.size(); @@ -257,7 +257,7 @@ auto py_value_and_grad( positional_grads = tree_unflatten(args[argnums[0]], gradients, counts[0]); } else if (argnums.size() > 1) { nb::list grads_; - for (int i = 0; i < argnums.size(); i++) { + for (int i = 0; i < std::ssize(argnums); i++) { grads_.append(tree_unflatten(args[argnums[i]], gradients, counts[i])); } positional_grads = nb::tuple(grads_); @@ -366,14 +366,13 @@ auto py_vmap( // able to reconstruct the python tree of extra return values nb::object py_outputs; - auto vmap_fn = - [&fun, &args, &inputs, &py_outputs](const std::vector& a) { - // Call the python function - py_outputs = fun(*tree_unflatten(args, a)); + auto vmap_fn = [&fun, &args, &py_outputs](const std::vector& a) { + // Call the python function + py_outputs = fun(*tree_unflatten(args, a)); - // Flatten the outputs - return tree_flatten(py_outputs, true); - }; + // Flatten the outputs + return tree_flatten(py_outputs, true); + }; auto [trace_inputs, trace_outputs] = mx::detail::vmap_trace(vmap_fn, inputs, flat_in_axes); @@ -451,7 +450,7 @@ struct PyCompiledFun { if (nb::isinstance(obj)) { auto l = nb::cast(obj); constants.push_back(list_identifier); - for (int i = 0; i < l.size(); ++i) { + for (int i = 0; i < std::ssize(l); ++i) { recurse(l[i]); } } else if (nb::isinstance(obj)) { diff --git a/python/src/trees.cpp b/python/src/trees.cpp index b75d1187cc..d25540e525 100644 --- a/python/src/trees.cpp +++ b/python/src/trees.cpp @@ -6,7 +6,8 @@ template void validate_subtrees(const std::vector& subtrees) { int len = nb::cast(subtrees[0]).size(); for (auto& subtree : subtrees) { - if ((nb::isinstance(subtree) && nb::cast(subtree).size() != len) || + if ((nb::isinstance(subtree) && + std::ssize(nb::cast(subtree)) != len) || nb::isinstance(subtree) || nb::isinstance(subtree)) { throw std::invalid_argument( "[tree_map] Additional input tree is not a valid prefix of the first tree."); @@ -24,8 +25,8 @@ nb::object tree_map( nb::list l; std::vector items(subtrees.size()); validate_subtrees(subtrees); - for (int i = 0; i < nb::cast(subtrees[0]).size(); ++i) { - for (int j = 0; j < subtrees.size(); ++j) { + for (int i = 0; i < std::ssize(nb::cast(subtrees[0])); ++i) { + for (int j = 0; j < std::ssize(subtrees); ++j) { if (nb::isinstance(subtrees[j])) { items[j] = nb::cast(subtrees[j])[i]; } else { @@ -42,7 +43,7 @@ nb::object tree_map( nb::list l; validate_subtrees(subtrees); for (int i = 0; i < len; ++i) { - for (int j = 0; j < subtrees.size(); ++j) { + for (int j = 0; j < std::ssize(subtrees); ++j) { if (nb::isinstance(subtrees[j])) { items[j] = nb::cast(subtrees[j])[i]; } else { @@ -57,7 +58,7 @@ nb::object tree_map( validate_subtrees(subtrees); nb::dict d; for (auto item : nb::cast(subtrees[0])) { - for (int j = 0; j < subtrees.size(); ++j) { + for (int j = 0; j < std::ssize(subtrees); ++j) { if (nb::isinstance(subtrees[j])) { auto subdict = nb::cast(subtrees[j]); if (!subdict.contains(item.first)) { @@ -96,8 +97,8 @@ void tree_visit( if (nb::isinstance(subtrees[0])) { std::vector items(subtrees.size()); validate_subtrees(subtrees); - for (int i = 0; i < nb::cast(subtrees[0]).size(); ++i) { - for (int j = 0; j < subtrees.size(); ++j) { + for (int i = 0; i < std::ssize(nb::cast(subtrees[0])); ++i) { + for (int j = 0; j < std::ssize(subtrees); ++j) { if (nb::isinstance(subtrees[j])) { items[j] = nb::cast(subtrees[j])[i]; } else { @@ -112,7 +113,7 @@ void tree_visit( int len = nb::cast(subtrees[0]).size(); validate_subtrees(subtrees); for (int i = 0; i < len; ++i) { - for (int j = 0; j < subtrees.size(); ++j) { + for (int j = 0; j < std::ssize(subtrees); ++j) { if (nb::isinstance(subtrees[j])) { items[j] = nb::cast(subtrees[j])[i]; } else { @@ -125,7 +126,7 @@ void tree_visit( std::vector items(subtrees.size()); validate_subtrees(subtrees); for (auto item : nb::cast(subtrees[0])) { - for (int j = 0; j < subtrees.size(); ++j) { + for (int j = 0; j < std::ssize(subtrees); ++j) { if (nb::isinstance(subtrees[j])) { auto subdict = nb::cast(subtrees[j]); if (!subdict.contains(item.first)) { @@ -173,13 +174,13 @@ void tree_visit_update( recurse = [&](nb::handle subtree) { if (nb::isinstance(subtree)) { auto l = nb::cast(subtree); - for (int i = 0; i < l.size(); ++i) { + for (int i = 0; i < std::ssize(l); ++i) { l[i] = recurse(l[i]); } return nb::cast(l); } else if (nb::isinstance(subtree)) { nb::list l(subtree); - for (int i = 0; i < l.size(); ++i) { + for (int i = 0; i < std::ssize(l); ++i) { l[i] = recurse(l[i]); } return nb::cast(nb::tuple(l)); @@ -204,7 +205,7 @@ void tree_visit_update( void tree_fill(nb::object& tree, const std::vector& values) { size_t index = 0; tree_visit_update( - tree, [&](nb::handle node) { return nb::cast(values[index++]); }); + tree, [&](nb::handle /* node */) { return nb::cast(values[index++]); }); } // Replace all the arrays from the src values with the dst values in the tree @@ -213,7 +214,7 @@ void tree_replace( const std::vector& src, const std::vector& dst) { std::unordered_map src_to_dst; - for (int i = 0; i < src.size(); ++i) { + for (int i = 0; i < std::ssize(src); ++i) { src_to_dst.insert({src[i].id(), dst[i]}); } tree_visit_update(tree, [&](nb::handle node) { diff --git a/python/src/utils.cpp b/python/src/utils.cpp index 5366e501b2..9927a7d59b 100644 --- a/python/src/utils.cpp +++ b/python/src/utils.cpp @@ -57,8 +57,8 @@ std::pair to_arrays( // - If neither is an array convert to arrays but leave their types alone auto is_mlx_array = [](const ScalarOrArray& x) { return std::holds_alternative(x) || - std::holds_alternative(x) && - nb::hasattr(std::get(x).obj, "__mlx_array__"); + (std::holds_alternative(x) && + nb::hasattr(std::get(x).obj, "__mlx_array__")); }; auto get_mlx_array = [](const ScalarOrArray& x) { if (auto px = std::get_if(&x); px) { From 24828b1b2fe01662ac4e74955a753843668d1a87 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Fri, 31 Oct 2025 16:55:04 -0700 Subject: [PATCH 30/30] CMakeLists.txt update --- CMakeLists.txt | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 84d4198baf..4777091c72 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,9 +20,13 @@ project( LANGUAGES C CXX VERSION ${MLX_PROJECT_VERSION}) +if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang") + add_compile_options(-Wall -Wextra) +endif() + # ----------------------------- Setup ----------------------------- set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON) set(CMAKE_INSTALL_MESSAGE NEVER)