Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions include/axono/core/ops.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// axono/core/ops.h
#pragma once

#include <pybind11/pybind11.h>

#include <functional>
#include <stdexcept>
#include <string>
#include <unordered_map>

namespace axono {
namespace core {

using OpFunction = std::function<pybind11::object(const pybind11::args&)>;

class OpRegistry {
public:
static OpRegistry& instance() {
static OpRegistry registry;
return registry;
}

void register_op(const std::string& name, OpFunction func) {
ops_[name] = std::move(func);
}

const OpFunction& get_op(const std::string& name) const {
auto it = ops_.find(name);
if (it == ops_.end()) {
throw std::runtime_error("算子 " + name + " 不存在。");
}
return it->second;
}

void bind_all(pybind11::module& m) {
for (const auto& [name, func] : ops_) {
m.def(name.c_str(), [func](const pybind11::args& args) {
return func(args);
});
}
}

private:
OpRegistry() = default;
std::unordered_map<std::string, OpFunction> ops_;
};

// Defines and registers a Python-facing operator in one step.
//
// Usage (at namespace scope):
//   REGISTER_OP(my_op) { /* use `args` */ return pybind11::cast(...); }
//
// The expansion produces:
//   1. a forward declaration of op_impl_<name> — required because the
//      registration lambda below names it, and enclosing-namespace names
//      used inside an inline member-function body must already be
//      declared (previously every call site had to hand-write this
//      declaration; keeping one there remains a harmless redeclaration);
//   2. a static registrar whose constructor runs before main() and
//      inserts the operator into the global OpRegistry;
//   3. the signature of op_impl_<name>, whose body the user supplies.
//
// NOTE(review): the registrar object is `static`, so a header that uses
// REGISTER_OP and is included from several translation units registers
// (and overwrites) the same entry once per TU — redundant but benign
// with the current overwrite semantics; confirm this is intended.
#define REGISTER_OP(name)                                           \
  pybind11::object op_impl_##name(const pybind11::args& args);      \
  struct RegisterOp_##name {                                        \
    RegisterOp_##name() {                                           \
      axono::core::OpRegistry::instance().register_op(              \
          #name, [](const pybind11::args& args) {                   \
            return op_impl_##name(args);                            \
          });                                                       \
    }                                                               \
  };                                                                \
  static RegisterOp_##name register_op_##name;                      \
  pybind11::object op_impl_##name(const pybind11::args& args)
} // namespace core
} // namespace axono
117 changes: 60 additions & 57 deletions include/axono/pybind/compute/operators/add.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,79 +2,82 @@

namespace py = pybind11;

#include "axono/core/ops.h"

#ifdef COMPILED_WITH_CUDA
#include "axono/compute/cuda/operators/add.h"
#endif
#include "axono/compute/cpu/operators/add.h"

void init_add_operations(py::module &m) {
m.def(
"add",
[](const axono::core::Tensor &a, const axono::core::Tensor &b) {
axono::core::Context ctx;
axono::core::Tensor result =
axono::core::Tensor(a.dtype(), a.shape(), a.device());
namespace axono {
namespace compute {
namespace operators {

axono::core::Status status;
if (a.is_cuda()) {
py::object op_impl_add(const py::args& args);
py::object op_impl_add_scalar(const py::args& args);

REGISTER_OP(add) {
if (args.size() != 2) {
throw std::runtime_error("执行 add 需要传入 2 个 Tensor 喵~");
}
auto& a = pybind11::cast<core::Tensor&>(args[0]);
auto& b = pybind11::cast<core::Tensor&>(args[1]);
core::Context ctx;
core::Tensor result = core::Tensor(a.dtype(), a.shape(), a.device());
core::Status status;
if (a.is_cuda()) {
#ifdef COMPILED_WITH_CUDA
status = axono::compute::cuda::operators::Add(ctx, a, b, result);
status = cuda::operators::Add(ctx, a, b, result);
#endif
} else {
status = axono::compute::cpu::operators::Add(ctx, a, b, result);
}
if (status != axono::core::Status::OK) {
throw std::runtime_error(
"喵!计算矩阵加法的时候出现问题啦,错误代码:" +
std::to_string(static_cast<int>(status)));
}
} else {
status = cpu::operators::Add(ctx, a, b, result);
}
if (status != core::Status::OK)
throw std::runtime_error("执行 add 时出现问题,错误代码:" + std::to_string(static_cast<int>(status)));

return result;
},
"Element-wise addition of two tensors", py::arg("a"), py::arg("b"));

m.def(
"add_scalar",
[](const axono::core::Tensor &a, py::object scalar) {
axono::core::Context ctx;
axono::core::Tensor result;
axono::core::Status status;
return pybind11::cast(result);
}

// 将 Python 标量转换为 C++ 数据
if (a.dtype() == axono::core::DataType::FLOAT32) {
float value = scalar.cast<float>();
if (a.is_cuda()) {
REGISTER_OP(add_scalar) {
if (args.size() != 2) {
throw std::runtime_error("执行 add 需要传入 1 个 Tensor, 1 个 Scalar 喵~");
}
auto& a = pybind11::cast<core::Tensor&>(args[0]);
py::object scalar = pybind11::cast<py::object>(args[1]);
core::Context ctx;
core::Tensor result;
core::Status status;
if (a.dtype() == core::DataType::FLOAT32) {
float value = scalar.cast<float>();
if (a.is_cuda()){
#ifdef COMPILED_WITH_CUDA
status = axono::compute::cuda::operators::AddScalar(
ctx, a, &value, sizeof(float), result);
status = cuda::operators::AddScalar(ctx, a, &value, sizeof(float), result);
#endif
} else {
status = axono::compute::cpu::operators::AddScalar(
ctx, a, &value, sizeof(float), result);
}
} else {
status = cpu::operators::AddScalar(ctx, a, &value, sizeof(float), result);
}
if (status != axono::core::Status::OK) {
throw std::runtime_error("Add scalar operation failed");
} else if (a.dtype() == axono::core::DataType::INT32) {
int32_t value = scalar.cast<int32_t>();
if (a.is_cuda()) {
}
if (status != core::Status::OK) {
throw std::runtime_error("执行 add_scalar 的时候出现问题,错误代码:" + std::to_string(static_cast<int>(status)));
} else if (a.dtype() == core::DataType::INT32) {
int32_t value = scalar.cast<int32_t>();
if (a.is_cuda()) {
#ifdef COMPILED_WITH_CUDA
status = axono::compute::cuda::operators::AddScalar(
ctx, a, &value, sizeof(int32_t), result);
status = cuda::operators::AddScalar(ctx, a, &value, sizeof(int32_t), result);
#endif
} else {
status = axono::compute::cpu::operators::AddScalar(
ctx, a, &value, sizeof(int32_t), result);
}

if (status != axono::core::Status::OK) {
throw std::runtime_error("喵!Add 操作出现了一些问题~");
}
} else {
throw std::runtime_error("喵!当前类型不支持执行Add操作喵~");
status = cpu::operators::AddScalar(ctx, a, &value, sizeof(int32_t), result);
}

return result;
},
"Add scalar to tensor", py::arg("a"), py::arg("scalar"));
if (status != core::Status::OK)
throw std::runtime_error("执行 add_scalar 的时候出现问题,错误代码:" + std::to_string(static_cast<int>(status)));
} else {
throw std::runtime_error("当前类型不支持执行 add_scalar 操作喵~");
}

return pybind11::cast(result);
}

}
}
}
57 changes: 32 additions & 25 deletions include/axono/pybind/compute/operators/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,38 @@ namespace py = pybind11;
#endif
#include "axono/compute/cpu/operators/matmul.h"

void init_matmul_operations(py::module &m) {
m.def(
"matmul",
[](const axono::core::Tensor &a, const axono::core::Tensor &b) {
axono::core::Context ctx;
axono::core::Tensor result;
axono::core::Status status;

if (a.is_cuda()) {
namespace axono {
namespace compute {
namespace operators {

py::object op_impl_matmul(const py::args& args);

REGISTER_OP(matmul) {
if (args.size() != 2) {
throw std::runtime_error("执行 add 需要传入 2 个 Tensor 喵~");
}
auto& a = pybind11::cast<core::Tensor&>(args[0]);
auto& b = pybind11::cast<core::Tensor&>(args[1]);
core::Context ctx;
core::Tensor result;
core::Status status;

if (a.is_cuda()) {
#ifdef COMPILED_WITH_CUDA
size_t m = a.shape()[0];
size_t n = b.shape()[1];
auto result = axono::core::Tensor(
a.dtype(), std::vector<size_t>{m, n}, a.device());
status = axono::compute::cuda::operators::MatMul(ctx, a, b, result);
return result;
size_t m = a.shape()[0];
size_t n = b.shape()[1];
auto result = core::Tensor(a.dtype(), std::vector<size_t>{m, n}, a.device());
status = cuda::operators::MatMul(ctx, a, b, result);
#endif
} else {
status = axono::compute::cpu::operators::MatMul(ctx, a, b, result);
}
if (status != axono::core::Status::OK) {
throw std::runtime_error("喵!Matmul 操作 出现错误!");
}

return result;
},
"Matrix multiplication of two tensors", py::arg("a"), py::arg("b"));
} else {
status = compute::cpu::operators::MatMul(ctx, a, b, result);
}
if (status != core::Status::OK)
throw std::runtime_error("执行 Matmul 时出现问题,错误代码:" + std::to_string(static_cast<int>(status)));

return pybind11::cast(result);
}

}
}
}
90 changes: 50 additions & 40 deletions include/axono/pybind/compute/ops/relu.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,55 +3,65 @@
namespace py = pybind11;

#include "axono/core/tensor.h"
#include "axono/core/ops.h"

#ifdef COMPILED_WITH_CUDA
#include "axono/compute/cuda/ops/relu.h"
#endif
#include "axono/compute/cpu/ops/relu.h"

void init_relu_operations(py::module &m) {
m.def(
"relu",
[](const axono::core::Tensor &input) {
axono::core::Context ctx;
axono::core::Tensor output =
axono::core::Tensor(input.dtype(), input.shape(), input.device());
namespace axono {
namespace compute {
namespace ops {

axono::core::Status status;
if (input.is_cuda()) {
py::object op_impl_relu(const py::args& args);
py::object op_impl_relu_(const py::args& args);

REGISTER_OP(relu) {
core::Context ctx;
core::Tensor result;
core::Status status;
if (args.size() != 1) {
throw std::runtime_error("执行 add 需要传入 1 个 Tensor 喵~");
}

auto& input = pybind11::cast<core::Tensor&>(args[0]);
core::Tensor output(input.dtype(), input.shape(), input.device());

if (input.is_cuda()) {
#ifdef COMPILED_WITH_CUDA
status = axono::compute::cuda::ops::Relu(ctx, input, output);
status = cuda::ops::Relu(ctx, input, output);
#endif
} else {
status = axono::compute::cpu::ops::Relu(ctx, input, output);
}

if (status != axono::core::Status::OK) {
throw std::runtime_error("喵!ReLU计算时发生错误,错误代码: " +
std::to_string(static_cast<int>(status)));
}
return output;
},
"ReLU activation function", py::arg("input"),
py::return_value_policy::move),

m.def(
"relu_",
[](axono::core::Tensor &tensor) {
axono::core::Context ctx;
axono::core::Status status;
if (tensor.is_cuda()) {
} else {
status = cpu::ops::Relu(ctx, input, output);
}
if (status != core::Status::OK)
throw std::runtime_error("执行 ReLU 时出现问题,错误代码:" + std::to_string(static_cast<int>(status)));

return pybind11::cast(output);
}
REGISTER_OP(relu_) {
core::Context ctx;
core::Tensor result;
core::Status status;
if (args.size() != 1) {
throw std::runtime_error("执行 add 需要传入 1 个 Tensor 喵~");
}

auto& tensor = pybind11::cast<core::Tensor&>(args[0]);

if (tensor.is_cuda()) {
#ifdef COMPILED_WITH_CUDA
status = axono::compute::cuda::ops::ReluInplace(ctx, tensor);
status = cuda::ops::ReluInplace(ctx, tensor);
#endif
} else {
status = axono::compute::cpu::ops::ReluInplace(ctx, tensor);
}
if (status != axono::core::Status::OK) {
throw std::runtime_error("喵!InplaceReLU 出现错误!");
}

return tensor;
},
"Inplace ReLU activation function", py::arg("tensor"));
} else {
status = cpu::ops::ReluInplace(ctx, tensor);
}
if (status != core::Status::OK)
throw std::runtime_error("执行 ReLU 时出现问题,错误代码:" + std::to_string(static_cast<int>(status)));

return pybind11::cast(tensor);
}
}
}
}
5 changes: 2 additions & 3 deletions python/src/pybind11_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "axono/pybind/compute/ops/relu.h"
#include "axono/pybind/core/tensor.h"
#include "axono/pybind/core/module.h"
#include "axono/core/ops.h"

namespace py = pybind11;

Expand Down Expand Up @@ -37,7 +38,5 @@ PYBIND11_MODULE(libaxono, m) {
// 初始化 Tensor
init_tensor(m);
init_module(m);
init_matmul_operations(m);
init_add_operations(m);
init_relu_operations(m);
axono::core::OpRegistry::instance().bind_all(m);
}
Loading