Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,13 @@ project(
LANGUAGES C CXX
VERSION ${MLX_PROJECT_VERSION})

if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
add_compile_options(-Wall -Wextra)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be moved to the # ---- Lib ---- section below.

endif()

# ----------------------------- Setup -----------------------------
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER)
Expand Down
6 changes: 5 additions & 1 deletion examples/cpp/tutorial.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@ void array_basics() {
// Get the value out of it:
auto s = x.item<float>();
assert(s == 1.0);
(void)s;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

C++17 provides the [[maybe_unused]] attribute for this purpose.
https://en.cppreference.com/w/cpp/language/attributes/maybe_unused.html


// Scalars have a size of 1:
size_t size = x.size();
int64_t size = x.size();
assert(size == 1);
(void)size;

// Scalars have 0 dimensions:
int ndim = x.ndim();
assert(ndim == 0);
(void)ndim;

// The shape should be an empty vector:
auto shape = x.shape();
Expand All @@ -30,6 +33,7 @@ void array_basics() {
// The datatype should be float32:
auto dtype = x.dtype();
assert(dtype == mx::float32);
(void)dtype;

// Specify the dtype when constructing the array:
x = mx::array(1, mx::int32);
Expand Down
15 changes: 8 additions & 7 deletions mlx/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ std::vector<array> array::make_arrays(
const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs) {
std::vector<array> outputs;
for (size_t i = 0; i < shapes.size(); ++i) {
for (int i = 0; i < std::ssize(shapes); ++i) {
outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs);
}
// For each node in |outputs|, its siblings are the other nodes.
for (size_t i = 0; i < outputs.size(); ++i) {
for (int i = 0; i < std::ssize(outputs); ++i) {
auto siblings = outputs;
siblings.erase(siblings.begin() + i);
outputs[i].set_siblings(std::move(siblings), i);
Expand Down Expand Up @@ -145,8 +145,9 @@ void array::set_data(allocator::Buffer buffer, Deleter d) {
array_desc_->data_size = size();
array_desc_->flags.contiguous = true;
array_desc_->flags.row_contiguous = true;
auto max_dim = std::max_element(shape().begin(), shape().end());
array_desc_->flags.col_contiguous = size() <= 1 || size() == *max_dim;
auto max_dim =
static_cast<int64_t>(*std::max_element(shape().begin(), shape().end()));
array_desc_->flags.col_contiguous = size() <= 1 || size() == max_dim;
}

void array::set_data(
Expand Down Expand Up @@ -192,7 +193,7 @@ array::~array() {
}

// Break circular reference for non-detached arrays with siblings
if (auto n = siblings().size(); n > 0) {
if (auto n = std::ssize(siblings()); n > 0) {
bool do_detach = true;
// If all siblings have siblings.size() references except
// the one we are currently destroying (which has siblings.size() + 1)
Expand Down Expand Up @@ -274,7 +275,7 @@ array::ArrayDesc::~ArrayDesc() {
ad.inputs.clear();
for (auto& [_, a] : input_map) {
bool is_deletable =
(a.array_desc_.use_count() <= a.siblings().size() + 1);
(a.array_desc_.use_count() <= std::ssize(a.siblings()) + 1);
// An array with siblings is deletable only if all of its siblings
// are deletable
for (auto& s : a.siblings()) {
Expand All @@ -283,7 +284,7 @@ array::ArrayDesc::~ArrayDesc() {
}
int is_input = (input_map.find(s.id()) != input_map.end());
is_deletable &=
s.array_desc_.use_count() <= a.siblings().size() + is_input;
s.array_desc_.use_count() <= std::ssize(a.siblings()) + is_input;
}
if (is_deletable) {
for_deletion.push_back(std::move(a.array_desc_));
Expand Down
14 changes: 7 additions & 7 deletions mlx/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,22 +81,22 @@ class array {
}

/** The size of the array's datatype in bytes. */
size_t itemsize() const {
int itemsize() const {
return size_of(dtype());
}

/** The number of elements in the array. */
size_t size() const {
int64_t size() const {
return array_desc_->size;
}

/** The number of bytes in the array. */
size_t nbytes() const {
int64_t nbytes() const {
return size() * itemsize();
}

/** The number of dimensions of the array. */
size_t ndim() const {
int ndim() const {
return array_desc_->shape.size();
}

Expand Down Expand Up @@ -329,7 +329,7 @@ class array {
* corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``.
* Note, ``data_size`` is in units of ``item_size`` (not bytes).
**/
size_t data_size() const {
int64_t data_size() const {
return array_desc_->data_size;
}

Expand All @@ -340,7 +340,7 @@ class array {
return array_desc_->data->buffer;
}

size_t buffer_size() const {
int64_t buffer_size() const {
return allocator::allocator().size(buffer());
}

Expand Down Expand Up @@ -530,7 +530,7 @@ array::array(
Shape shape,
Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
if (data.size() != size()) {
if (std::ssize(data) != size()) {
throw std::invalid_argument(
"Data size and provided shape mismatch in array construction.");
}
Expand Down
23 changes: 12 additions & 11 deletions mlx/backend/common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {

// Compute the flags given the shape and strides
bool row_contiguous = true, col_contiguous = true;
size_t r = 1, c = 1;
for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
int64_t r = 1, c = 1;
for (int i = std::ssize(strides_) - 1, j = 0; i >= 0; i--, j++) {
row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
r *= shape_[i];
Expand Down Expand Up @@ -60,7 +60,8 @@ void CustomTransforms::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
for (int i = 0, j = std::ssize(inputs) - std::ssize(outputs);
i < std::ssize(outputs);
i++, j++) {
outputs[i].copy_shared_buffer(inputs[j]);
}
Expand All @@ -70,7 +71,7 @@ void Depends::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0; i < outputs.size(); i++) {
for (int i = 0; i < std::ssize(outputs); i++) {
outputs[i].copy_shared_buffer(inputs[i]);
}
}
Expand Down Expand Up @@ -206,11 +207,11 @@ void Split::eval(

auto compute_new_flags = [](const auto& shape,
const auto& strides,
size_t in_data_size,
int64_t in_data_size,
auto flags) {
size_t data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
int64_t data_size = 1;
int64_t f_stride = 1;
int64_t b_stride = 1;
flags.row_contiguous = true;
flags.col_contiguous = true;
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
Expand Down Expand Up @@ -240,7 +241,7 @@ void Split::eval(

std::vector<int> indices(1, 0);
indices.insert(indices.end(), indices_.begin(), indices_.end());
for (int i = 0; i < indices.size(); i++) {
for (int i = 0; i < std::ssize(indices); i++) {
size_t offset = indices[i] * in.strides()[axis_];
auto [new_flags, data_size] = compute_new_flags(
outputs[i].shape(), in.strides(), in.data_size(), in.flags());
Expand All @@ -254,7 +255,7 @@ void Squeeze::eval(const std::vector<array>& inputs, array& out) {
const auto& in = inputs[0];
Strides strides;
for (int i = 0, j = 0; i < in.ndim(); ++i) {
if (j < axes_.size() && i == axes_[j]) {
if (j < std::ssize(axes_) && i == axes_[j]) {
j++;
} else {
strides.push_back(in.strides(i));
Expand All @@ -272,7 +273,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
Strides out_strides(out.ndim());
auto& in = inputs[0];
for (int ax = 0; ax < axes_.size(); ++ax) {
for (int ax = 0; ax < std::ssize(axes_); ++ax) {
out_strides[ax] = in.strides()[axes_[ax]];
}

Expand Down
16 changes: 8 additions & 8 deletions mlx/backend/common/compiled.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ void compiled_allocate_outputs(
Strides strides;
size_t data_size;
array::Flags flags;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
for (int i = 0; i < std::ssize(inputs) && o < std::ssize(outputs); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Correct size
Expand All @@ -138,7 +138,7 @@ void compiled_allocate_outputs(
data_size = in.data_size();
}
}
for (; o < outputs.size(); ++o) {
for (; o < std::ssize(outputs); ++o) {
outputs[o].set_data(
allocator::malloc(data_size * outputs[o].itemsize()),
data_size,
Expand All @@ -147,7 +147,7 @@ void compiled_allocate_outputs(
}
} else {
int o = 0;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
for (int i = 0; i < std::ssize(inputs) && o < std::ssize(outputs); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Row contiguous
Expand All @@ -162,7 +162,7 @@ void compiled_allocate_outputs(
o++;
}
}
for (; o < outputs.size(); ++o) {
for (; o < std::ssize(outputs); ++o) {
outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
}
}
Expand Down Expand Up @@ -193,15 +193,15 @@ std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(

// Broadcast the inputs to the output shape.
Strides xstrides;
size_t j = 0;
int j = 0;
for (; j < shape.size() - x.ndim(); ++j) {
if (shape[j] == 1) {
xstrides.push_back(out.strides()[j]);
} else {
xstrides.push_back(0);
}
}
for (size_t i = 0; i < x.ndim(); ++i, ++j) {
for (int i = 0; i < x.ndim(); ++i, ++j) {
if (x.shape(i) == 1) {
if (shape[j] == 1) {
xstrides.push_back(out.strides()[j]);
Expand All @@ -224,13 +224,13 @@ bool compiled_use_large_index(
const std::vector<array>& outputs,
bool contiguous) {
if (contiguous) {
size_t max_size = 0;
int64_t max_size = 0;
for (const auto& in : inputs) {
max_size = std::max(max_size, in.data_size());
}
return max_size > UINT32_MAX;
} else {
size_t max_size = 0;
int64_t max_size = 0;
for (const auto& o : outputs) {
max_size = std::max(max_size, o.size());
}
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/common/load.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {

namespace mlx::core {

void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
void Load::eval_cpu(const std::vector<array>& /* inputs */, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
auto read_task = [out_ptr = out.data<char>(),
size = out.size(),
Expand Down
4 changes: 2 additions & 2 deletions mlx/backend/common/reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(

ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
if (x.size() == x.data_size() && std::ssize(axes) == x.ndim() &&
x.flags().contiguous) {
return ContiguousAllReduce;
}
Expand All @@ -38,7 +38,7 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// Merge consecutive axes
Shape shape = {x.shape(axes[0])};
Strides strides = {x.strides()[axes[0]]};
for (int i = 1; i < axes.size(); i++) {
for (int i = 1; i < std::ssize(axes); i++) {
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
shape.back() *= x.shape(axes[i]);
strides.back() = x.strides()[axes[i]];
Expand Down
6 changes: 3 additions & 3 deletions mlx/backend/common/slicing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ std::tuple<int64_t, Strides> prepare_slice(
void shared_buffer_slice(
const array& in,
const Strides& out_strides,
size_t data_offset,
size_t data_size,
int64_t data_offset,
int64_t data_size,
array& out) {
// Compute row/col contiguity
auto [no_bsx_size, is_row_contiguous, is_col_contiguous] =
Expand Down Expand Up @@ -61,7 +61,7 @@ void slice(
if (data_end < 0) {
data_end += in.data_size();
}
size_t data_size = (data_end - data_offset);
int64_t data_size = (data_end - data_offset);
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
}

Expand Down
4 changes: 2 additions & 2 deletions mlx/backend/common/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
if (shape[0] != 1) {
to_collapse.push_back(0);
}
size_t size = shape[0];
int64_t size = shape[0];
for (int i = 1; i < shape.size(); i++) {
bool contiguous = true;
size *= shape[i];
Expand Down Expand Up @@ -64,7 +64,7 @@ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
current_shape *= shape[to_collapse[k]];
}
out_shape.push_back(current_shape);
for (int j = 0; j < strides.size(); j++) {
for (int j = 0; j < std::ssize(strides); j++) {
const auto& st = strides[j];
out_strides[j].push_back(st[to_collapse[k - 1]]);
}
Expand Down
4 changes: 2 additions & 2 deletions mlx/backend/common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ struct ContiguousIterator {
};

inline auto check_contiguity(const Shape& shape, const Strides& strides) {
size_t no_broadcast_data_size = 1;
int64_t no_broadcast_data_size = 1;
int64_t f_stride = 1;
int64_t b_stride = 1;
bool is_row_contiguous = true;
Expand All @@ -183,7 +183,7 @@ inline auto check_contiguity(const Shape& shape, const Strides& strides) {
}

inline bool is_donatable(const array& in, const array& out) {
constexpr size_t donation_extra = 16384;
constexpr int64_t donation_extra = 16384;

return in.is_donatable() && in.itemsize() == out.itemsize() &&
in.buffer_size() <= out.nbytes() + donation_extra;
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/cpu/arange.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace mlx::core {
namespace {

template <typename T>
void arange(T start, T next, array& out, size_t size, Stream stream) {
void arange(T start, T next, array& out, int64_t size, Stream stream) {
auto ptr = out.data<T>();
auto step_size = next - start;
auto& encoder = cpu::get_command_encoder(stream);
Expand Down
4 changes: 2 additions & 2 deletions mlx/backend/cpu/arg_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
auto in_ptr = in.data<InT>();
auto out_ptr = out.data<uint32_t>();

for (uint32_t i = 0; i < out.size(); ++i) {
for (int64_t i = 0; i < out.size(); ++i) {
auto loc = elem_to_loc(i, shape, strides);
auto local_in_ptr = in_ptr + loc;
uint32_t ind_v = 0;
InT v = (*local_in_ptr);
for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
for (int64_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
op(j, (*local_in_ptr), &ind_v, &v);
}
out_ptr[i] = ind_v;
Expand Down
Loading
Loading