diff --git a/src/backend_scalar_c.cpp b/src/backend_scalar_c.cpp index 81a4dd5..06db7e6 100644 --- a/src/backend_scalar_c.cpp +++ b/src/backend_scalar_c.cpp @@ -17,89 +17,86 @@ struct LowerCtx int indentation; }; -static void Lower_ScalarC(LowerCtx &ctx, const LoadIntImmediateInsn &i, size_t iinsn) +static void Lower_ScalarC(LowerCtx &ctx, const LoadIntImmediateInsn &i, size_t iinsn, NumericDataType dtype) { std::fprintf(ctx.file, "%*sint64_t v%zu = %" PRIi64 ";\n", ctx.indentation, " ", iinsn, i.value); } -static void Lower_ScalarC(LowerCtx &ctx, const IntArithmeticInsn &i, size_t iinsn) +static void Lower_ScalarC(LowerCtx &ctx, const IntArithmeticInsn &i, size_t iinsn, NumericDataType dtype) { std::fprintf(ctx.file, "%*sint64_t v%zu = v%zu %c v%zu;\n", ctx.indentation, " ", iinsn, i.x, (char)i.op, i.y); } -static void Lower_ScalarC(LowerCtx &ctx, const BeginLoopInsn &i, size_t iinsn) +static void Lower_ScalarC(LowerCtx &ctx, const BeginLoopInsn &i, size_t iinsn, NumericDataType dtype) { std::fprintf(ctx.file, "%*sfor(int64_t v%zu = 0; v%zu < %zd; v%zu++)\n%*s{\n", ctx.indentation, " ", iinsn, iinsn, i.range, iinsn, ctx.indentation, " "); ctx.indentation += 4; } -static void Lower_ScalarC(LowerCtx &ctx, const EndLoopInsn &i, size_t iinsn) +static void Lower_ScalarC(LowerCtx &ctx, const EndLoopInsn &i, size_t iinsn, NumericDataType dtype) { ctx.indentation -= 4; std::fprintf(ctx.file, "%*s}\n", ctx.indentation, " "); } -static void Lower_ScalarC(LowerCtx &ctx, const LoadInsn &i, size_t iinsn) +static void Lower_ScalarC(LowerCtx &ctx, const LoadInsn &i, size_t iinsn, NumericDataType dtype) { - std::fprintf(ctx.file, "%*sfloat v%zu = i%zu[v%zu];\n", - ctx.indentation, " ", iinsn, i.input, i.idx); + std::fprintf(ctx.file, "%*s%s v%zu = i%zu[v%zu];\n", + ctx.indentation, " ", getCDatatype(dtype).c_str(), iinsn, i.input, i.idx); } -static void Lower_ScalarC(LowerCtx &ctx, const StoreInsn &i, size_t iinsn) +static void Lower_ScalarC(LowerCtx &ctx, const StoreInsn &i, size_t iinsn, NumericDataType dtype) { std::fprintf(ctx.file, "%*soutput[v%zu] = v%zu;\n", ctx.indentation, " ", i.offset, i.value); } -static void Lower_ScalarC(LowerCtx &ctx, const LoadImmediateInsn &i, size_t iinsn) +static void Lower_ScalarC(LowerCtx &ctx, const LoadImmediateInsn &i, size_t iinsn, NumericDataType dtype) { - std::fprintf(ctx.file, "%*sfloat v%zu = %f;\n", ctx.indentation, " ", iinsn, i.value); + std::fprintf(ctx.file, "%*s%s v%zu = %f;\n", ctx.indentation, " ", getCDatatype(dtype).c_str(), iinsn, i.value); } -static void Lower_ScalarC(LowerCtx &ctx, const UnaryInsn &i, size_t iinsn) +static void Lower_ScalarC(LowerCtx &ctx, const UnaryInsn &i, size_t iinsn, NumericDataType dtype) { - auto op_str = i.type == UnaryOpType::EXP ? "exp" - : i.type == UnaryOpType::LOG ? "log" - : i.type == UnaryOpType::SIN ? "sin" - : i.type == UnaryOpType::SQRT ? "sqrtf" - : "INVALID"; - std::fprintf(ctx.file, "%*sfloat v%zu = %s(v%zu);\n", - ctx.indentation, " ", iinsn, op_str, i.x); + auto op_str = i.type == UnaryOpType::EXP ? "exp" + : i.type == UnaryOpType::LOG ? "log" + : i.type == UnaryOpType::SIN ? "sin" + : i.type == UnaryOpType::SQRT ? "sqrtf" + : "INVALID"; + std::fprintf(ctx.file, "%*s%s v%zu = %s(v%zu);\n", + ctx.indentation, " ", getCDatatype(dtype).c_str(), iinsn, op_str, i.x); } -static void Lower_ScalarC(LowerCtx &ctx, const BinaryInsn &i, size_t iinsn) +static void Lower_ScalarC(LowerCtx &ctx, const BinaryInsn &i, size_t iinsn, NumericDataType dtype) { - if(i.type == BinaryOpType::ADD - || i.type == BinaryOpType::SUB - || i.type == BinaryOpType::MUL - || i.type == BinaryOpType::DIV - || i.type == BinaryOpType::CMP) + + if (i.type == BinaryOpType::ADD || i.type == BinaryOpType::SUB || i.type == BinaryOpType::MUL || i.type == BinaryOpType::DIV || i.type == BinaryOpType::CMP) { - auto op_str = i.type == BinaryOpType::ADD ? "+" - : i.type == BinaryOpType::SUB ? "-" - : i.type == BinaryOpType::MUL ? "*" - : i.type == BinaryOpType::DIV ? "/" - : "=="; - - std::fprintf(ctx.file, "%*sfloat v%zu = (float)(v%zu %s v%zu);\n", - ctx.indentation, " ", iinsn, i.x, op_str, i.y); + auto op_str = i.type == BinaryOpType::ADD ? "+" + : i.type == BinaryOpType::SUB ? "-" + : i.type == BinaryOpType::MUL ? "*" + : i.type == BinaryOpType::DIV ? "/" + : "=="; + + std::fprintf(ctx.file, "%*s%s v%zu = (%s)(v%zu %s v%zu);\n", + ctx.indentation, " ", getCDatatype(dtype).c_str(), iinsn, getCDatatype(dtype).c_str(), i.x, op_str, i.y); } - else if(i.type == BinaryOpType::MAX) + else if (i.type == BinaryOpType::MAX) { - std::fprintf(ctx.file, "%*sfloat v%zu = v%zu > v%zu ? v%zu : v%zu;\n", - ctx.indentation, " ", iinsn, i.x, i.y, i.x, i.y); + std::fprintf(ctx.file, "%*s%s v%zu = v%zu > v%zu ? v%zu : v%zu;\n", + ctx.indentation, " ", getCDatatype(dtype).c_str(), iinsn, i.x, i.y, i.x, i.y); } else { - std::fprintf(ctx.file, "%*sfloat v%zu = pow(v%zu, v%zu);\n", - ctx.indentation, " ", iinsn, i.x, i.y); + std::fprintf(ctx.file, "%*s%s v%zu = pow(v%zu, v%zu);\n", + ctx.indentation, " ", getCDatatype(dtype).c_str(), iinsn, i.x, i.y); } } -static void Lower_ScalarC(LowerCtx &ctx, const AccumulateInsn &i, size_t iinsn) +static void Lower_ScalarC(LowerCtx &ctx, const AccumulateInsn &i, size_t iinsn, NumericDataType dtype) { - if(i.type == ReduceOpType::MAX) + if (i.type == ReduceOpType::MAX) std::fprintf(ctx.file, "%*sv%zu = v%zu > v%zu ? v%zu : v%zu;\n", ctx.indentation, " ", i.accumulator, i.accumulator, i.x, i.accumulator, i.x); @@ -107,16 +104,17 @@ static void Lower_ScalarC(LowerCtx &ctx, const AccumulateInsn &i, size_t iinsn) std::fprintf(ctx.file, "%*sv%zu += v%zu;\n", ctx.indentation, " ", i.accumulator, i.x); } -static void Lower_ScalarC(LowerCtx &ctx, const FunctionBuilder &fn, size_t ifn) +static void Lower_ScalarC(LowerCtx &ctx, const FunctionBuilder &fn, size_t ifn, NumericDataType dtype) { std::fprintf(ctx.file, "static void %s_%zu(\n", ctx.prefix, ifn); - for(size_t i = 0; i < fn.inputs.size(); i++) - std::fprintf(ctx.file, " const float *i%zu,\n", i); - std::fprintf(ctx.file, " float *output)\n{\n"); + for (size_t i = 0; i < fn.inputs.size(); i++) + std::fprintf(ctx.file, " const %s *i%zu,\n", getCDatatype(dtype).c_str(), i); + std::fprintf(ctx.file, " %s *output)\n{\n", getCDatatype(dtype).c_str()); ctx.indentation = 4; - for(size_t i = 0; i < fn.insns.size(); i++) + for (size_t i = 0; i < fn.insns.size(); i++) { - std::visit([&](auto &&insn) { Lower_ScalarC(ctx, insn, i); }, fn.insns[i]); + std::visit([&](auto &&insn) + { Lower_ScalarC(ctx, insn, i, dtype); }, fn.insns[i]); } std::fprintf(ctx.file, "}\n\n"); } @@ -127,11 +125,11 @@ static void GenerateMain(const Program &program, LowerCtx &ctx) std::fprintf(ctx.file, "#if __linux__\n"); std::fprintf(ctx.file, " feenableexcept(FE_DIVBYZERO | FE_INVALID | FE_OVERFLOW);\n"); std::fprintf(ctx.file, "#endif\n"); - for(size_t ifn = 0; ifn < program.functions.size(); ifn++) + for (size_t ifn = 0; ifn < program.functions.size(); ifn++) { const FunctionBuilder &fn = program.functions[ifn]; std::fprintf(ctx.file, " %s_%zu(\n", ctx.prefix, ifn); - for(size_t iinput = 0; iinput < fn.inputs.size(); iinput++) + for (size_t iinput = 0; iinput < fn.inputs.size(); iinput++) std::fprintf(ctx.file, " buffers[%zu],\n", fn.inputs[iinput]); std::fprintf(ctx.file, " buffers[%zu]);\n\n", fn.output_buffer); } @@ -146,7 +144,7 @@ static std::pair CompileAndLoad(const std::filesystem::path obj_path.replace_extension(".so"); std::string command = "cc " + - source_path.string() + + source_path.string() + " -o " + obj_path.string() + " -Ofast -fPIC -shared -lm -march=native -mtune=native"; @@ -154,37 +152,37 @@ static std::pair CompileAndLoad(const std::filesystem::path // std::printf("Compiling with: %s\n", command.c_str()); void *handle = dlopen(obj_path.c_str(), RTLD_NOW | RTLD_LOCAL); - if(!handle) + if (!handle) throw std::runtime_error(dlerror()); dlerror(); // Clear error conditions auto main_fn = reinterpret_cast(dlsym(handle, "gigagrad_main")); - if(!main_fn) + if (!main_fn) { char *err = dlerror(); - if(!err) + if (!err) throw std::runtime_error("Symbol gigagrad_main is NULL, which is unexpected"); else throw std::runtime_error(err); } - return { main_fn, handle }; + return {main_fn, handle}; } -static std::pair Lower_ScalarC(const char *prefix, const Program &program) +static std::pair Lower_ScalarC(const char *prefix, const Program &program, NumericDataType dtype) { auto file_name = std::filesystem::temp_directory_path() / prefix; file_name += ".c"; std::printf("FILE: %s\n", file_name.c_str()); FILE *file = std::fopen(file_name.c_str(), "w+"); - if(!file) + if (!file) throw std::system_error(errno, std::generic_category()); - LowerCtx ctx = { prefix, file, 0 }; + LowerCtx ctx = {prefix, file, 0}; std::fprintf(file, "#define _GNU_SOURCE\n#include \n"); std::fprintf(file, "#include \n#include \n\n"); - for(size_t ifn = 0; ifn < program.functions.size(); ifn++) - ::Lower_ScalarC(ctx, program.functions[ifn], ifn); + for (size_t ifn = 0; ifn < program.functions.size(); ifn++) + ::Lower_ScalarC(ctx, program.functions[ifn], ifn, dtype); GenerateMain(program, ctx); std::fclose(file); @@ -194,12 +192,12 @@ static std::pair Lower_ScalarC(const char *prefix, const Pr BackendScalarC::~BackendScalarC() { dlclose(this->handle); - for(ssize_t ibuff = 0; ibuff < std::ssize(this->program.buffers); ibuff++) + for (ssize_t ibuff = 0; ibuff < std::ssize(this->program.buffers); ibuff++) { auto &desc = this->program.buffers[ibuff]; - if(!std::holds_alternative(desc.id)) + if (!std::holds_alternative(desc.id)) { - delete [] reinterpret_cast(this->buffers[ibuff]); + delete[] reinterpret_cast(this->buffers[ibuff]); } } } @@ -207,7 +205,7 @@ BackendScalarC::~BackendScalarC() void BackendScalarC::LowerProgram(Program &&program) { this->program = std::move(program); - auto [eval_fn, handle] = Lower_ScalarC("gg_scalar", this->program); + auto [eval_fn, handle] = Lower_ScalarC("gg_scalar", this->program, NumericDataType::FLOAT32); this->eval_fn = eval_fn; this->handle = handle; } @@ -215,10 +213,10 @@ void BackendScalarC::LowerProgram(Program &&program) void *BackendScalarC::InitBuffers() { this->buffers.reserve(this->program.buffers.size()); - for(ssize_t ibuff = 0; ibuff < std::ssize(this->program.buffers); ibuff++) + for (ssize_t ibuff = 0; ibuff < std::ssize(this->program.buffers); ibuff++) { auto &desc = this->program.buffers[ibuff]; - if(std::holds_alternative(desc.id)) + if (std::holds_alternative(desc.id)) { GraphNodeHandle tensor = std::get(desc.id); this->buffers.push_back(reinterpret_cast(tensor.data())); @@ -239,10 +237,10 @@ void *BackendScalarC::GetBuffer(size_t idx) void BackendScalarC::Execute() { - for(ssize_t ibuff = 0; ibuff < std::ssize(this->program.buffers); ibuff++) + for (ssize_t ibuff = 0; ibuff < std::ssize(this->program.buffers); ibuff++) { auto &desc = this->program.buffers[ibuff]; - if(std::holds_alternative(desc.id)) + if (std::holds_alternative(desc.id)) { GraphNodeHandle tensor = std::get(desc.id); this->buffers[ibuff] = (reinterpret_cast(tensor.data())); diff --git a/src/backend_scalar_c.h b/src/backend_scalar_c.h index 3c09897..429b117 100644 --- a/src/backend_scalar_c.h +++ b/src/backend_scalar_c.h @@ -4,23 +4,23 @@ namespace gigagrad { -namespace codegen -{ + namespace codegen + { -struct BackendScalarC : public Backend -{ - using GraphEvalFn = void (*)(void **); - virtual ~BackendScalarC(); - virtual void LowerProgram(Program &&program); - virtual void *InitBuffers(); - virtual void *GetBuffer(size_t idx); - virtual void Execute(); + struct BackendScalarC : public Backend + { + using GraphEvalFn = void (*)(void **); + virtual ~BackendScalarC(); + virtual void LowerProgram(Program &&program); + virtual void *InitBuffers(); + virtual void *GetBuffer(size_t idx); + virtual void Execute(); - void *handle; - Program program; - std::vector buffers; - GraphEvalFn eval_fn; -}; + void *handle; + Program program; + std::vector buffers; + GraphEvalFn eval_fn; + }; -} + } } diff --git a/src/codegen.cpp b/src/codegen.cpp index 3cce9a1..727bedd 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -7,275 +7,274 @@ namespace gigagrad { -namespace codegen -{ - -size_t CodegenNode(Program &prog, FunctionBuilder &f, GraphNodeHandle node, size_t load_idx, size_t max_seen_size_elts); + namespace codegen + { -size_t CodegenNode( - Program &prog, - FunctionBuilder &f, - GraphNodeHandle node, - const Tensor &t, - size_t load_idx, - size_t max_seen_size_elts) -{ - size_t size_elts = std::accumulate(node.shape().begin(), node.shape().end(), 1, std::multiplies{}); - size_elts = std::max(max_seen_size_elts, size_elts); - size_t buffer_id = prog.AddBuffer(node, size_elts); - auto input = f.Input(buffer_id); - return f.Load(input, load_idx); -} + size_t CodegenNode(Program &prog, FunctionBuilder &f, GraphNodeHandle node, size_t load_idx, size_t max_seen_size_elts); -size_t CodegenNode( - Program &prog, - FunctionBuilder &f, - GraphNodeHandle, - const Immediate &i, - size_t load_idx, - size_t max_seen_size_elts) -{ - return f.Immediate(i.value); -} + size_t CodegenNode( + Program &prog, + FunctionBuilder &f, + GraphNodeHandle node, + const Tensor &t, + size_t load_idx, + size_t max_seen_size_elts) + { + size_t size_elts = std::accumulate(node.shape().begin(), node.shape().end(), 1, std::multiplies{}); + size_elts = std::max(max_seen_size_elts, size_elts); + size_t buffer_id = prog.AddBuffer(node, size_elts); + auto input = f.Input(buffer_id); + return f.Load(input, load_idx); + } -size_t CodegenNode( - Program &prog, - FunctionBuilder &f, - GraphNodeHandle, - const UnaryOp &u, - size_t load_idx, - size_t max_seen_size_elts) -{ - auto x = CodegenNode(prog, f, u.x, load_idx, max_seen_size_elts); - return f.Unary(u.type, x); -} + size_t CodegenNode( + Program &prog, + FunctionBuilder &f, + GraphNodeHandle, + const Immediate &i, + size_t load_idx, + size_t max_seen_size_elts) + { + return f.Immediate(i.value); + } -size_t CodegenNode( - Program &prog, - FunctionBuilder &f, - GraphNodeHandle node, - const BinaryOp &b, - size_t load_idx, - size_t max_seen_size_elts) -{ - const Shape &xshape = b.x.shape(); - const Shape &yshape = b.y.shape(); - const Shape &broadcasted_shape = node.shape(); - const Shape &broadcasted_strides = node.strides(); + size_t CodegenNode( + Program &prog, + FunctionBuilder &f, + GraphNodeHandle node, + const UnaryOp &u, + size_t load_idx, + size_t max_seen_size_elts) + { + auto x = CodegenNode(prog, f, u.x, load_idx, max_seen_size_elts); + return f.Unary(u.type, x, node.dtype()); + } - auto generate_stride_adjustments = - [&f, &broadcasted_shape, &broadcasted_strides, load_idx](const Shape &shape) + size_t CodegenNode( + Program &prog, + FunctionBuilder &f, + GraphNodeHandle node, + const BinaryOp &b, + size_t load_idx, + size_t max_seen_size_elts) { - auto load = load_idx; - for(ssize_t i = 0; i < std::ssize(shape); i++) + const Shape &xshape = b.x.shape(); + const Shape &yshape = b.y.shape(); + const Shape &broadcasted_shape = node.shape(); + const Shape &broadcasted_strides = node.strides(); + + auto generate_stride_adjustments = + [&f, &broadcasted_shape, &broadcasted_strides, load_idx](const Shape &shape) { - if(broadcasted_shape[i] != 1 && shape[i] == 1) + auto load = load_idx; + for (ssize_t i = 0; i < std::ssize(shape); i++) { - auto divisor = f.IntImmediate(broadcasted_strides[i] * broadcasted_shape[i]); - auto modulus = f.IntImmediate(broadcasted_strides[i]); - auto div = f.Arithmetic(load, IntArithmeticInsn::Op::DIV, divisor); - auto mod = f.Arithmetic(load, IntArithmeticInsn::Op::MOD, modulus); - load = f.Arithmetic(div, IntArithmeticInsn::Op::ADD, mod); + if (broadcasted_shape[i] != 1 && shape[i] == 1) + { + auto divisor = f.IntImmediate(broadcasted_strides[i] * broadcasted_shape[i]); + auto modulus = f.IntImmediate(broadcasted_strides[i]); + auto div = f.Arithmetic(load, IntArithmeticInsn::Op::DIV, divisor); + auto mod = f.Arithmetic(load, IntArithmeticInsn::Op::MOD, modulus); + load = f.Arithmetic(div, IntArithmeticInsn::Op::ADD, mod); + } } - } - return load; - }; - - size_t xload = generate_stride_adjustments(xshape); - size_t yload = generate_stride_adjustments(yshape); - auto x = CodegenNode(prog, f, b.x, xload, max_seen_size_elts); - auto y = CodegenNode(prog, f, b.y, yload, max_seen_size_elts); - return f.Binary(b.type, x, y); -} + return load; + }; -size_t CodegenNode( - Program &prog, - FunctionBuilder &old_f, - GraphNodeHandle node, - const ReduceOp &r, - size_t output_load_idx, - size_t max_seen_size_elts) -{ - FunctionBuilder f(node, max_seen_size_elts); + size_t xload = generate_stride_adjustments(xshape); + size_t yload = generate_stride_adjustments(yshape); + auto x = CodegenNode(prog, f, b.x, xload, max_seen_size_elts); + auto y = CodegenNode(prog, f, b.y, yload, max_seen_size_elts); + return f.Binary(b.type, x, y, node.dtype()); + } - std::vector accumulators; - auto reduce_dim = r.dims.begin(); // dims is sorted - auto ioutput_strides = node.strides().begin(); + size_t CodegenNode( + Program &prog, + FunctionBuilder &old_f, + GraphNodeHandle node, + const ReduceOp &r, + size_t output_load_idx, + size_t max_seen_size_elts) + { + FunctionBuilder f(node, max_seen_size_elts); - auto store_idx = f.IntImmediate(0); - auto load_idx = store_idx; + std::vector accumulators; + auto reduce_dim = r.dims.begin(); // dims is sorted + auto ioutput_strides = node.strides().begin(); - const Shape &input_shape = r.x.shape(); - const Shape &input_strides = r.x.strides(); + auto store_idx = f.IntImmediate(0); + auto load_idx = store_idx; - // Generate loops for all of the non-reducing dimensions - for(ssize_t i = 0; i < std::ssize(input_shape); i++) - { - if(reduce_dim == r.dims.end() || i != *reduce_dim) - { - auto loop = f.Loop(input_shape[i], input_strides[i]); - auto input_stride = f.IntImmediate(input_strides[i]); - auto output_stride = f.IntImmediate(*ioutput_strides); - auto mul_input_stride = f.Arithmetic(loop, IntArithmeticInsn::Op::MUL, input_stride); - auto mul_output_stride = f.Arithmetic(loop, IntArithmeticInsn::Op::MUL, output_stride); - load_idx = f.Arithmetic(load_idx, IntArithmeticInsn::Op::ADD, mul_input_stride); - store_idx = f.Arithmetic(store_idx, IntArithmeticInsn::Op::ADD, mul_output_stride); + const Shape &input_shape = r.x.shape(); + const Shape &input_strides = r.x.strides(); - if(!r.keepdim) - ioutput_strides++; - } - else if(i == *reduce_dim) - { - reduce_dim++; - } + // Generate loops for all of the non-reducing dimensions + for (ssize_t i = 0; i < std::ssize(input_shape); i++) + { + if (reduce_dim == r.dims.end() || i != *reduce_dim) + { + auto loop = f.Loop(input_shape[i], input_strides[i]); + auto input_stride = f.IntImmediate(input_strides[i]); + auto output_stride = f.IntImmediate(*ioutput_strides); + auto mul_input_stride = f.Arithmetic(loop, IntArithmeticInsn::Op::MUL, input_stride); + auto mul_output_stride = f.Arithmetic(loop, IntArithmeticInsn::Op::MUL, output_stride); + load_idx = f.Arithmetic(load_idx, IntArithmeticInsn::Op::ADD, mul_input_stride); + store_idx = f.Arithmetic(store_idx, IntArithmeticInsn::Op::ADD, mul_output_stride); - // If keepdim, always advance output_strides because number of input/output - // dimensions matches - if(r.keepdim) - ioutput_strides++; - } - // Generate loops along reduction dimension - for(auto dim : r.dims) - { - accumulators.push_back(f.Immediate(0.0f)); - auto loop = f.Loop(input_shape[dim], input_strides[dim]); - auto stride = f.IntImmediate(input_strides[dim]); - auto mul = f.Arithmetic(loop, IntArithmeticInsn::Op::MUL, stride); - load_idx = f.Arithmetic(load_idx, IntArithmeticInsn::Op::ADD, mul); - } + if (!r.keepdim) + ioutput_strides++; + } + else if (i == *reduce_dim) + { + reduce_dim++; + } - auto to_accumulate = CodegenNode(prog, f, r.x, load_idx, 0); + // If keepdim, always advance output_strides because number of input/output + // dimensions matches + if (r.keepdim) + ioutput_strides++; + } + // Generate loops along reduction dimension + for (auto dim : r.dims) + { + accumulators.push_back(f.Immediate(0.0f)); + auto loop = f.Loop(input_shape[dim], input_strides[dim]); + auto stride = f.IntImmediate(input_strides[dim]); + auto mul = f.Arithmetic(loop, IntArithmeticInsn::Op::MUL, stride); + load_idx = f.Arithmetic(load_idx, IntArithmeticInsn::Op::ADD, mul); + } - ssize_t iaccum = std::ssize(r.dims) - 1; - do - { - f.Accumulate(r.type, accumulators[iaccum], to_accumulate); - to_accumulate = accumulators[iaccum]; - f.EndLoop(); - iaccum--; - } while(iaccum >= 0); - f.Store(store_idx, accumulators[0]); - for(ssize_t i = 0; i < std::ssize(input_shape) - std::ssize(r.dims); i++) - f.EndLoop(); + auto to_accumulate = CodegenNode(prog, f, r.x, load_idx, 0); - prog.PushFunction(std::move(f)); - auto input = old_f.Input(prog.functions.back().output_buffer); - return old_f.Load(input, output_load_idx); -} + ssize_t iaccum = std::ssize(r.dims) - 1; + do + { + f.Accumulate(r.type, accumulators[iaccum], to_accumulate); + to_accumulate = accumulators[iaccum]; + f.EndLoop(); + iaccum--; + } while (iaccum >= 0); + f.Store(store_idx, accumulators[0]); + for (ssize_t i = 0; i < std::ssize(input_shape) - std::ssize(r.dims); i++) + f.EndLoop(); -size_t CodegenNode( - Program &prog, - FunctionBuilder &f, - GraphNodeHandle node, - const ViewOp &v, - size_t load_idx, - size_t max_seen_size) -{ - // TODO: This generates a lot of unnecessary crap right now, ideally this arithmetic - // expression would get accumulated and simplified before being emitted. - const Shape &shape = v.shape; - const Shape &strides = v.strides; - const Shape &output_strides = node.strides(); + prog.PushFunction(std::move(f)); + auto input = old_f.Input(prog.functions.back().output_buffer); + return old_f.Load(input, output_load_idx); + } - auto new_load_idx = f.IntImmediate(0); - for(ssize_t i = std::ssize(shape) - 1; i >= 0; i--) - { - auto output_stride = f.IntImmediate(output_strides[i]); - auto output_shape = f.IntImmediate(shape[i]); - auto input_stride = f.IntImmediate(strides[i]); - auto div = f.Arithmetic(load_idx, IntArithmeticInsn::Op::DIV, output_stride); - auto mod = f.Arithmetic(div, IntArithmeticInsn::Op::MOD, output_shape); - auto mul = f.Arithmetic(mod, IntArithmeticInsn::Op::MUL, input_stride); - new_load_idx = f.Arithmetic(new_load_idx, IntArithmeticInsn::Op::ADD, mul); - } - auto offset = f.IntImmediate(v.offset); - new_load_idx = f.Arithmetic(new_load_idx, IntArithmeticInsn::Op::ADD, offset); - size_t view_size = std::accumulate( - v.shape.begin(), - v.shape.end(), - dim_t{1}, - std::multiplies{}); - view_size -= v.offset; + size_t CodegenNode( + Program &prog, + FunctionBuilder &f, + GraphNodeHandle node, + const ViewOp &v, + size_t load_idx, + size_t max_seen_size) + { + // TODO: This generates a lot of unnecessary crap right now, ideally this arithmetic + // expression would get accumulated and simplified before being emitted. + const Shape &shape = v.shape; + const Shape &strides = v.strides; + const Shape &output_strides = node.strides(); - max_seen_size = std::max(max_seen_size, view_size); - auto x = CodegenNode(prog, f, v.x, new_load_idx, max_seen_size); - return x; -} + auto new_load_idx = f.IntImmediate(0); + for (ssize_t i = std::ssize(shape) - 1; i >= 0; i--) + { + auto output_stride = f.IntImmediate(output_strides[i]); + auto output_shape = f.IntImmediate(shape[i]); + auto input_stride = f.IntImmediate(strides[i]); + auto div = f.Arithmetic(load_idx, IntArithmeticInsn::Op::DIV, output_stride); + auto mod = f.Arithmetic(div, IntArithmeticInsn::Op::MOD, output_shape); + auto mul = f.Arithmetic(mod, IntArithmeticInsn::Op::MUL, input_stride); + new_load_idx = f.Arithmetic(new_load_idx, IntArithmeticInsn::Op::ADD, mul); + } + auto offset = f.IntImmediate(v.offset); + new_load_idx = f.Arithmetic(new_load_idx, IntArithmeticInsn::Op::ADD, offset); + size_t view_size = std::accumulate( + v.shape.begin(), + v.shape.end(), + dim_t{1}, + std::multiplies{}); + view_size -= v.offset; -size_t CodegenNode(Program &prog, FunctionBuilder &f, GraphNodeHandle node, size_t load_idx, size_t max_seen_size_elts) -{ - if(prog.node_function_cache.contains(node.node_idx)) - { - size_t function_id = prog.node_function_cache[node.node_idx]; - size_t buffer_id = prog.functions[function_id].output_buffer; - prog.buffers[buffer_id].size_elts = std::max(prog.buffers[buffer_id].size_elts, max_seen_size_elts); - auto input = f.Input(buffer_id); - return f.Load(input, load_idx); - } - return node->Visit([&](auto &&x) - { - return CodegenNode(prog, f, node, x, load_idx, max_seen_size_elts); - }); -} + max_seen_size = std::max(max_seen_size, view_size); + auto x = CodegenNode(prog, f, v.x, new_load_idx, max_seen_size); + return x; + } -void CodegenNode(Program &prog, GraphNodeHandle node, std::optional output_buffer) -{ - // ReduceOp generates its own loops - if(node->Kind() == GraphNode::Kind::ReduceOp) - { - FunctionBuilder f(node); - CodegenNode(prog, f, node, 0, 0); - } - else - { - FunctionBuilder f(node); - const Shape &shape = node.shape(); - const Shape &strides = node.strides(); - auto load_idx = f.IntImmediate(0); - for(ssize_t i = 0; i < std::ssize(shape); i++) + size_t CodegenNode(Program &prog, FunctionBuilder &f, GraphNodeHandle node, size_t load_idx, size_t max_seen_size_elts) { - auto loop = f.Loop(shape[i], strides[i]); - auto stride = f.IntImmediate(strides[i]); - auto mul = f.Arithmetic(loop, IntArithmeticInsn::Op::MUL, stride); - load_idx = f.Arithmetic(load_idx, IntArithmeticInsn::Op::ADD, mul); + if (prog.node_function_cache.contains(node.node_idx)) + { + size_t function_id = prog.node_function_cache[node.node_idx]; + size_t buffer_id = prog.functions[function_id].output_buffer; + prog.buffers[buffer_id].size_elts = std::max(prog.buffers[buffer_id].size_elts, max_seen_size_elts); + auto input = f.Input(buffer_id); + return f.Load(input, load_idx); + } + return node->Visit([&](auto &&x) + { return CodegenNode(prog, f, node, x, load_idx, max_seen_size_elts); }); } - auto to_store = CodegenNode(prog, f, node, load_idx, 0); - f.Store(load_idx, to_store); - for(ssize_t i = 0; i < std::ssize(shape); i++) - f.EndLoop(); - prog.PushFunction(std::move(f)); - } - if(output_buffer) - { - // This can potentially be a little fragile, but it's simple and easy for now. - // We rely on the fact that PushFunction always adds a new buffer, so if we want - // to remap the last function's output to something else, we can just pop_back(). - prog.buffers.pop_back(); - prog.ChangeOutputBuffer(prog.functions.size() - 1, *output_buffer); - } -} + void CodegenNode(Program &prog, GraphNodeHandle node, std::optional output_buffer) + { + // ReduceOp generates its own loops + if (node->Kind() == GraphNode::Kind::ReduceOp) + { + FunctionBuilder f(node); + CodegenNode(prog, f, node, 0, 0); + } + else + { + FunctionBuilder f(node); + const Shape &shape = node.shape(); + const Shape &strides = node.strides(); + // const NumericDataType &data = node.dtype(); + auto load_idx = f.IntImmediate(0); + for (ssize_t i = 0; i < std::ssize(shape); i++) + { + auto loop = f.Loop(shape[i], strides[i]); + auto stride = f.IntImmediate(strides[i]); + auto mul = f.Arithmetic(loop, IntArithmeticInsn::Op::MUL, stride); + load_idx = f.Arithmetic(load_idx, IntArithmeticInsn::Op::ADD, mul); + } + auto to_store = CodegenNode(prog, f, node, load_idx, 0); + f.Store(load_idx, to_store); + for (ssize_t i = 0; i < std::ssize(shape); i++) + f.EndLoop(); -codegen::Program CodegenNode(GraphNodeHandle node) -{ - codegen::Program result; - codegen::CodegenNode(result, node); - return result; -} + prog.PushFunction(std::move(f)); + } + if (output_buffer) + { + // This can potentially be a little fragile, but it's simple and easy for now. + // We rely on the fact that PushFunction always adds a new buffer, so if we want + // to remap the last function's output to something else, we can just pop_back(). + prog.buffers.pop_back(); + prog.ChangeOutputBuffer(prog.functions.size() - 1, *output_buffer); + } + } -} + codegen::Program CodegenNode(GraphNodeHandle node) + { + codegen::Program result; + codegen::CodegenNode(result, node); + return result; + } -CompiledTensor GraphNodeHandle::Compile(std::unique_ptr backend) const -{ - codegen::Program prog = codegen::CodegenNode(*this); - backend->LowerProgram(std::move(prog)); + } - CompiledTensor result; - result.shape = this->shape(); - result.data = reinterpret_cast(backend->InitBuffers()); - result.backend = std::move(backend); + CompiledTensor GraphNodeHandle::Compile(std::unique_ptr backend) const + { + codegen::Program prog = codegen::CodegenNode(*this); + backend->LowerProgram(std::move(prog)); - return result; -} + CompiledTensor result; + result.shape = this->shape(); + result.data = reinterpret_cast(backend->InitBuffers()); + result.backend = std::move(backend); + + return result; + } } diff --git a/src/codegen.h b/src/codegen.h index 910d3c1..316a98b 100644 --- a/src/codegen.h +++ b/src/codegen.h @@ -9,340 +9,343 @@ namespace gigagrad { -namespace codegen -{ - -struct LoadIntImmediateInsn -{ - int64_t value; - - void Print(size_t iinsn) - { - std::printf("v%zu = %" PRIi64 "\n", iinsn, value); - } -}; - -struct IntArithmeticInsn -{ - enum class Op : char - { - ADD = '+', - SUB = '-', - MUL = '*', - DIV = '/', - MOD = '%', - }; - - Op op; - size_t x; - size_t y; - - void Print(size_t iinsn) - { - std::printf("v%zu = v%zu %c v%zu\n", iinsn, x, (char)op, y); - } -}; - -struct BeginLoopInsn -{ - dim_t range; - dim_t stride; - - void Print(size_t iinsn) - { - std::printf("v%zu = LOOP [0..%zd, %zd]\n", iinsn, range, stride); - } -}; - -struct EndLoopInsn -{ - void Print(size_t) - { - std::printf("END LOOP\n"); - } -}; - -struct LoadInsn -{ - size_t input; - size_t idx; - - void Print(size_t iinsn) - { - std::printf("v%zu = LOAD I%zu[v%zu]\n", iinsn, input, idx); - } -}; - -struct StoreInsn -{ - size_t offset; - size_t value; - - void Print(size_t iinsn) - { - std::printf("Output[v%zu] = v%zu\n", offset, value); - } -}; - -struct LoadImmediateInsn -{ - float value; - - void Print(size_t iinsn) - { - std::printf("v%zu = %f\n", iinsn, value); - } -}; - -struct UnaryInsn -{ - UnaryOpType type; - size_t x; - - void Print(size_t iinsn) - { - auto op_str = type == UnaryOpType::NOP ? "NOP" - : type == UnaryOpType::EXP ? "EXP" - : type == UnaryOpType::LOG ? "LOG" - : type == UnaryOpType::CAST ? "CAST" - : type == UnaryOpType::SIN ? "SIN" - : "INVALID"; - std::printf("v%zu = %s(v%zu)\n", iinsn, op_str, x); - } -}; - -struct BinaryInsn -{ - BinaryOpType type; - size_t x; - size_t y; - - void Print(size_t iinsn) - { - auto op_str = type == BinaryOpType::ADD ? "+" - : type == BinaryOpType::SUB ? "-" - : type == BinaryOpType::MUL ? "*" - : type == BinaryOpType::DIV ? "/" - : type == BinaryOpType::POW ? "^" - : type == BinaryOpType::CMP ? "==" - : type == BinaryOpType::MAX ? "max" - : "INVALID"; - std::printf("v%zu = v%zu %s v%zu\n", iinsn, x, op_str, y); - } -}; - -struct AccumulateInsn -{ - ReduceOpType type; - size_t accumulator; - size_t x; - - void Print(size_t iinsn) - { - auto op_str = type == ReduceOpType::MAX ? "MAX" - : type == ReduceOpType::SUM ? "SUM" - : "INVALID"; - std::printf("v%zu <- %s(v%zu, v%zu)\n", accumulator, op_str, accumulator, x); - } -}; - -using Instruction = std::variant< - LoadIntImmediateInsn, - IntArithmeticInsn, - BeginLoopInsn, - EndLoopInsn, - LoadInsn, - StoreInsn, - LoadImmediateInsn, - UnaryInsn, - BinaryInsn, - AccumulateInsn>; - -struct FunctionBuilder -{ - explicit FunctionBuilder(GraphNodeHandle node, size_t max_seen_size = 1) - : node(node) - { - const Shape &shape = node.shape(); - size_t output_size = std::accumulate( - shape.begin(), - shape.end(), - dim_t{1}, - std::multiplies{}); - - this->output_size = std::max( - max_seen_size, - output_size); - } - - size_t Loop(dim_t range, dim_t stride) - { - insns.emplace_back(BeginLoopInsn{range, stride}); - return insns.size() - 1; - } - - size_t EndLoop() + namespace codegen { - insns.emplace_back(EndLoopInsn{}); - return insns.size() - 1; - } - size_t Input(size_t program_input_idx) - { - // TODO: Make this not O(n), probably using unordered_map from buffer index - // to input index - auto input = std::find(inputs.begin(), inputs.end(), program_input_idx); - if(input == inputs.end()) + struct LoadIntImmediateInsn { - inputs.push_back(program_input_idx); - return inputs.size() - 1; - } - return input - inputs.begin(); - } - - size_t Load(size_t input_idx, size_t load_idx) - { - insns.emplace_back(LoadInsn{input_idx, load_idx}); - return insns.size() - 1; - } + int64_t value; - size_t Store(size_t offset, size_t value) - { - insns.emplace_back(StoreInsn{offset, value}); - return insns.size() - 1; - } + void Print(size_t iinsn) + { + std::printf("v%zu = %" PRIi64 "\n", iinsn, value); + } + }; - size_t Immediate(float value) - { - insns.emplace_back(LoadImmediateInsn{value}); - return insns.size() - 1; - } - - size_t IntImmediate(int64_t value) - { - insns.emplace_back(LoadIntImmediateInsn{value}); - return insns.size() - 1; - } - - size_t Arithmetic(size_t x, IntArithmeticInsn::Op op, size_t y) - { - insns.emplace_back(IntArithmeticInsn{op, x, y}); - return insns.size() - 1; - } - - size_t Unary(UnaryOpType type, size_t x) - { - insns.emplace_back(UnaryInsn{type, x}); - return insns.size() - 1; - } - - size_t Binary(BinaryOpType type, size_t x, size_t y) - { - insns.emplace_back(BinaryInsn{type, x, y}); - return insns.size() - 1; - } + struct IntArithmeticInsn + { + enum class Op : char + { + ADD = '+', + SUB = '-', + MUL = '*', + DIV = '/', + MOD = '%', + }; + + Op op; + size_t x; + size_t y; + + void Print(size_t iinsn) + { + std::printf("v%zu = v%zu %c v%zu\n", iinsn, x, (char)op, y); + } + }; + + struct BeginLoopInsn + { + dim_t range; + dim_t stride; - size_t Accumulate(ReduceOpType type, size_t accumulator, size_t x) - { - insns.emplace_back(AccumulateInsn{type, accumulator, x}); - return insns.size() - 1; - } + void Print(size_t iinsn) + { + std::printf("v%zu = LOOP [0..%zd, %zd]\n", iinsn, range, stride); + } + }; - void Print() - { - for(ssize_t i = 0; i < std::ssize(insns); i++) + struct EndLoopInsn { - std::visit([&](auto &&insn) { insn.Print(i); }, insns[i]); - } - } + void Print(size_t) + { + std::printf("END LOOP\n"); + } + }; - GraphNodeHandle node; // Node that represents the output of the function - size_t output_size; + struct LoadInsn + { + size_t input; + size_t idx; - std::vector insns; - std::vector inputs; // Indices into the program inputs - size_t output_buffer; -}; + void Print(size_t iinsn) + { + std::printf("v%zu = LOAD I%zu[v%zu]\n", iinsn, input, idx); + } + }; -struct BufferDescriptor -{ - std::variant id; // Either a tensor or a function index - size_t size_elts; -}; + struct StoreInsn + { + size_t offset; + size_t value; -struct Program -{ - void PushFunction(FunctionBuilder function) - { - functions.emplace_back(std::move(function)); - functions.back().output_buffer = AddBuffer(functions.size() - 1); - node_function_cache[functions.back().node.node_idx] = functions.size() - 1; - } + void Print(size_t iinsn) + { + std::printf("Output[v%zu] = v%zu\n", offset, value); + } + }; - size_t NumFunctions() - { - return functions.size(); - } + struct LoadImmediateInsn + { + float value; - size_t AddBuffer(GraphNodeHandle t, size_t size_elts) - { - if(t->Kind() != GraphNode::Kind::Tensor) - throw std::domain_error("Cannot AddBuffer on non-tensor"); + void Print(size_t iinsn) + { + std::printf("v%zu = %f\n", iinsn, value); + } + }; - for(size_t iinput = 0; iinput < buffers.size(); iinput++) + struct UnaryInsn { - const auto &buff_id = buffers[iinput].id; - if(std::holds_alternative(buff_id)) - if(std::get(buff_id).node_idx == t.node_idx) - return iinput; - } - buffers.push_back({ t, size_elts }); - return buffers.size() - 1; - } - - size_t AddBuffer(const size_t fn_idx) - { - for(size_t iinput = 0; iinput < buffers.size(); iinput++) + UnaryOpType type; + size_t x; + NumericDataType dtype; + + void Print(size_t iinsn) + { + auto op_str = type == UnaryOpType::NOP ? "NOP" + : type == UnaryOpType::EXP ? "EXP" + : type == UnaryOpType::LOG ? "LOG" + : type == UnaryOpType::CAST ? "CAST" + : type == UnaryOpType::SIN ? "SIN" + : "INVALID"; + std::printf("v%zu = %s(v%zu)\n", iinsn, op_str, x); + } + }; + + struct BinaryInsn { - const auto &buff_id = buffers[iinput].id; - if(std::holds_alternative(buff_id)) - if(std::get(buff_id) == fn_idx) - return iinput; - } - buffers.push_back({ fn_idx, functions[fn_idx].output_size }); - return buffers.size() - 1; - } - - size_t GetOutputBufferForNodeIdx(size_t node_id) - { - size_t function_id = node_function_cache[node_id]; - return functions[function_id].output_buffer; - } - - void ChangeOutputBuffer(size_t fn_idx, size_t new_output_buffer) - { - if(new_output_buffer >= buffers.size()) - throw std::domain_error("Invalid output buffer"); - functions[fn_idx].output_buffer = new_output_buffer; - } - - void Print() - { - for(size_t i = 0; i < functions.size(); i++) + BinaryOpType type; + size_t x; + size_t y; + NumericDataType dtype; + + void Print(size_t iinsn) + { + auto op_str = type == BinaryOpType::ADD ? "+" + : type == BinaryOpType::SUB ? "-" + : type == BinaryOpType::MUL ? "*" + : type == BinaryOpType::DIV ? "/" + : type == BinaryOpType::POW ? "^" + : type == BinaryOpType::CMP ? "==" + : type == BinaryOpType::MAX ? "max" + : "INVALID"; + std::printf("v%zu = v%zu %s v%zu\n", iinsn, x, op_str, y); + } + }; + + struct AccumulateInsn { - std::printf("BEGIN FUNCTION %zu\n", i); - functions[i].Print(); - std::printf("END FUNCTION %zu\n", i); - } - } - - std::unordered_map node_function_cache; - std::vector functions; - std::vector buffers; -}; + ReduceOpType type; + size_t accumulator; + size_t x; + + void Print(size_t iinsn) + { + auto op_str = type == ReduceOpType::MAX ? "MAX" + : type == ReduceOpType::SUM ? "SUM" + : "INVALID"; + std::printf("v%zu <- %s(v%zu, v%zu)\n", accumulator, op_str, accumulator, x); + } + }; + + using Instruction = std::variant< + LoadIntImmediateInsn, + IntArithmeticInsn, + BeginLoopInsn, + EndLoopInsn, + LoadInsn, + StoreInsn, + LoadImmediateInsn, + UnaryInsn, + BinaryInsn, + AccumulateInsn>; + + struct FunctionBuilder + { + explicit FunctionBuilder(GraphNodeHandle node, size_t max_seen_size = 1) + : node(node) + { + const Shape &shape = node.shape(); + size_t output_size = std::accumulate( + shape.begin(), + shape.end(), + dim_t{1}, + std::multiplies{}); + + this->output_size = std::max( + max_seen_size, + output_size); + } + + size_t Loop(dim_t range, dim_t stride) + { + insns.emplace_back(BeginLoopInsn{range, stride}); + return insns.size() - 1; + } + + size_t EndLoop() + { + insns.emplace_back(EndLoopInsn{}); + return insns.size() - 1; + } + + size_t Input(size_t program_input_idx) + { + // TODO: Make this not O(n), probably using unordered_map from buffer index + // to input index + auto input = std::find(inputs.begin(), inputs.end(), program_input_idx); + if (input == inputs.end()) + { + inputs.push_back(program_input_idx); + return inputs.size() - 1; + } + return input - inputs.begin(); + } + + size_t Load(size_t input_idx, size_t load_idx) + { + insns.emplace_back(LoadInsn{input_idx, load_idx}); + return insns.size() - 1; + } + + size_t Store(size_t offset, size_t value) + { + insns.emplace_back(StoreInsn{offset, value}); + return insns.size() - 1; + } + + size_t Immediate(float value) + { + insns.emplace_back(LoadImmediateInsn{value}); + return insns.size() - 1; + } + + size_t IntImmediate(int64_t value) + { + insns.emplace_back(LoadIntImmediateInsn{value}); + return insns.size() - 1; + } + + size_t Arithmetic(size_t x, IntArithmeticInsn::Op op, size_t y) + { + insns.emplace_back(IntArithmeticInsn{op, x, y}); + return insns.size() - 1; + } + + size_t Unary(UnaryOpType type, size_t x, NumericDataType dtype) + { + insns.emplace_back(UnaryInsn{type, x, dtype}); + return insns.size() - 1; + } + + size_t Binary(BinaryOpType type, size_t x, size_t y, NumericDataType dtype) + { + insns.emplace_back(BinaryInsn{type, x, y, dtype}); + return insns.size() - 1; + } + + size_t Accumulate(ReduceOpType type, size_t accumulator, size_t x) + { + insns.emplace_back(AccumulateInsn{type, accumulator, x}); + return insns.size() - 1; + } + + void Print() + { + for (ssize_t i = 0; i < std::ssize(insns); i++) + { + std::visit([&](auto &&insn) + { insn.Print(i); }, insns[i]); + } + } + + GraphNodeHandle node; // Node that represents the output of the function + size_t output_size; + + std::vector insns; + std::vector inputs; // Indices into the program inputs + size_t output_buffer; + }; + + struct BufferDescriptor + { + std::variant id; // Either a tensor or a function index + size_t size_elts; + }; -void CodegenNode(codegen::Program &prog, GraphNodeHandle node, std::optional output_buffer = std::nullopt); -codegen::Program CodegenNode(GraphNodeHandle node); + struct Program + { + void PushFunction(FunctionBuilder function) + { + functions.emplace_back(std::move(function)); + functions.back().output_buffer = AddBuffer(functions.size() - 1); + node_function_cache[functions.back().node.node_idx] = functions.size() - 1; + } + + size_t NumFunctions() + { + return functions.size(); + } + + size_t AddBuffer(GraphNodeHandle t, size_t size_elts) + { + if (t->Kind() != GraphNode::Kind::Tensor) + throw std::domain_error("Cannot AddBuffer on non-tensor"); + + for (size_t iinput = 0; iinput < buffers.size(); iinput++) + { + const auto &buff_id = buffers[iinput].id; + if (std::holds_alternative(buff_id)) + if (std::get(buff_id).node_idx == t.node_idx) + return iinput; + } + buffers.push_back({t, size_elts}); + return buffers.size() - 1; + } + + size_t AddBuffer(const size_t fn_idx) + { + for (size_t iinput = 0; iinput < buffers.size(); iinput++) + { + const auto &buff_id = buffers[iinput].id; + if (std::holds_alternative(buff_id)) + if (std::get(buff_id) == fn_idx) + return iinput; + } + buffers.push_back({fn_idx, functions[fn_idx].output_size}); + return buffers.size() - 1; + } + + size_t GetOutputBufferForNodeIdx(size_t node_id) + { + size_t function_id = node_function_cache[node_id]; + return functions[function_id].output_buffer; + } + + void ChangeOutputBuffer(size_t fn_idx, size_t new_output_buffer) + { + if (new_output_buffer >= buffers.size()) + throw std::domain_error("Invalid output buffer"); + functions[fn_idx].output_buffer = new_output_buffer; + } + + void Print() + { + for (size_t i = 0; i < functions.size(); i++) + { + std::printf("BEGIN FUNCTION %zu\n", i); + functions[i].Print(); + std::printf("END FUNCTION %zu\n", i); + } + } + + std::unordered_map node_function_cache; + std::vector functions; + std::vector buffers; + }; + + void CodegenNode(codegen::Program &prog, GraphNodeHandle node, std::optional output_buffer = std::nullopt); + codegen::Program CodegenNode(GraphNodeHandle node); -} + } } diff --git a/src/graph.cpp b/src/graph.cpp index e80faf8..72e347e 100644 --- a/src/graph.cpp +++ b/src/graph.cpp @@ -8,880 +8,943 @@ namespace gigagrad { -static dim_t FixDim(dim_t dim, dim_t mod) -{ - auto fixed_dim = ((dim % mod) + mod) % mod; - return fixed_dim; -} + static dim_t FixDim(dim_t dim, dim_t mod) + { + auto fixed_dim = ((dim % mod) + mod) % mod; + return fixed_dim; + } -static Shape ComputeStrides(Shape shape) -{ - dim_t cur = 1; - for(ssize_t i = std::ssize(shape) - 1; i >= 0; i--) + static Shape ComputeStrides(Shape shape) { - auto tmp = shape[i]; - shape[i] = cur; - cur *= tmp; + dim_t cur = 1; + for (ssize_t i = std::ssize(shape) - 1; i >= 0; i--) + { + auto tmp = shape[i]; + shape[i] = cur; + cur *= tmp; + } + return shape; } - return shape; -} -static Shape ComputeBroadcastedShape(const Shape &x, const Shape &y) -{ - // Ensure x.size() >= y.size() - Shape larger = x.size() > y.size() ? x : y; - const Shape &smaller = x.size() > y.size() ? y : x; - - for(ssize_t i = 0; i < std::ssize(smaller); i++) - { - // Store the proper dimension in dim_x - auto &dim_x = larger[larger.size() - i - 1]; - const auto &dim_y = smaller[smaller.size() - i - 1]; - if(dim_x == 1 && dim_y != 1) - dim_x = dim_y; - else if(dim_x != 1 && dim_y == 1) - continue; - else if(dim_x == dim_y) - continue; - else - throw std::domain_error("Cannot broadcast incompatible shapes"); + std::string getCDatatype(NumericDataType dtype) + { + switch (dtype) + { + case NumericDataType::FLOAT64: + return "double"; + case NumericDataType::FLOAT32: + return "float"; + case NumericDataType::FLOAT16: + return "__fp16"; // Check compiler support for this type. + case NumericDataType::INT8: + return "int8_t"; + case NumericDataType::INT16: + return "int16_t"; + case NumericDataType::INT32: + return "int32_t"; + case NumericDataType::INT64: + return "int64_t"; + default: + return "unknown"; + } } - return larger; -} -static Shape ComputeReducedShape(const ReduceOp &op) -{ - Shape shape = op.x.shape(); - if(op.dims.empty()) + static Shape ComputeBroadcastedShape(const Shape &x, const Shape &y) { - if(op.keepdim) + // Ensure x.size() >= y.size() + Shape larger = x.size() > y.size() ? x : y; + const Shape &smaller = x.size() > y.size() ? y : x; + + for (ssize_t i = 0; i < std::ssize(smaller); i++) { - std::fill(shape.begin(), shape.end(), 1); - return shape; + // Store the proper dimension in dim_x + auto &dim_x = larger[larger.size() - i - 1]; + const auto &dim_y = smaller[smaller.size() - i - 1]; + if (dim_x == 1 && dim_y != 1) + dim_x = dim_y; + else if (dim_x != 1 && dim_y == 1) + continue; + else if (dim_x == dim_y) + continue; + else + throw std::domain_error("Cannot broadcast incompatible shapes"); } - return {}; + return larger; } - if(op.dims.size() > shape.size()) - throw std::domain_error("Specified more dims to reduce on than there are dimensions in tensor"); - - for(auto dim : op.dims) - shape[dim] = -1; // Mark it as -1 for now. We'll either remove it or change it to 1 later + static Shape ComputeReducedShape(const ReduceOp &op) + { + Shape shape = op.x.shape(); + if (op.dims.empty()) + { + if (op.keepdim) + { + std::fill(shape.begin(), shape.end(), 1); + return shape; + } + return {}; + } + + if (op.dims.size() > shape.size()) + throw std::domain_error("Specified more dims to reduce on than there are dimensions in tensor"); - if(!op.keepdim) + for (auto dim : op.dims) + shape[dim] = -1; // Mark it as -1 for now. We'll either remove it or change it to 1 later + + if (!op.keepdim) + { + shape.erase(std::remove(shape.begin(), shape.end(), -1), shape.end()); + } + else + { + std::replace(shape.begin(), shape.end(), -1, 1); + } + return shape; + } + + static NumericDataType CompareDtypes(const NumericDataType &dtype_x, const NumericDataType &dtype_y) { - shape.erase(std::remove(shape.begin(), shape.end(), -1), shape.end()); + auto precedence = [](const NumericDataType &dtype) + { + switch (dtype) + { + case NumericDataType::FLOAT64: + return 8; // Highest precedence + case NumericDataType::FLOAT32: + return 7; + case NumericDataType::FLOAT16: + return 6; // Higher than INT64 despite its smaller size + case NumericDataType::INT64: + return 5; + case NumericDataType::INT32: + return 4; + case NumericDataType::INT16: + return 3; + case NumericDataType::INT8: + return 2; + } + return 0; // Default case, should not be reached + }; + + // Compare precedences + if (precedence(dtype_x) >= precedence(dtype_y)) + { + return dtype_x; + } + else + { + return dtype_y; + } } - else + + static GraphNodeHandle WrapInUnary(GraphNodeHandle x, UnaryOpType type) { - std::replace(shape.begin(), shape.end(), -1, 1); + Graph *graph = x.graph; + return graph->AddNode(UnaryOp{type, x}); } - return shape; -} -static GraphNodeHandle WrapInUnary(GraphNodeHandle x, UnaryOpType type) -{ - Graph *graph = x.graph; - return graph->AddNode(UnaryOp{type, x}); -} + static GraphNodeHandle WrapInReduction(GraphNodeHandle x, ReduceOpType type, Dims dims, bool keepdim) + { + Graph *graph = x.graph; + for (dim_t &d : dims) + d = FixDim(d, static_cast(x.shape().size())); + std::sort(dims.begin(), dims.end()); + return graph->AddNode(ReduceOp{type, x, std::move(dims), keepdim}); + } -static GraphNodeHandle WrapInReduction(GraphNodeHandle x, ReduceOpType type, Dims dims, bool keepdim) -{ - Graph *graph = x.graph; - for(dim_t &d : dims) - d = FixDim(d, static_cast(x.shape().size())); - std::sort(dims.begin(), dims.end()); - return graph->AddNode(ReduceOp{type, x, std::move(dims), keepdim}); -} + GraphNodeHandle GraphNodeHandle::sum(bool keepdim) const + { + Dims dims(this->shape().size()); + std::iota(dims.begin(), dims.end(), 0); + return this->sum(std::move(dims), keepdim); + } -GraphNodeHandle GraphNodeHandle::sum(bool keepdim) const -{ - Dims dims(this->shape().size()); - std::iota(dims.begin(), dims.end(), 0); - return this->sum(std::move(dims), keepdim); -} + GraphNodeHandle GraphNodeHandle::sum(dim_t dim, bool keepdim) const + { + return this->sum(Dims{dim}, keepdim); + } -GraphNodeHandle GraphNodeHandle::sum(dim_t dim, bool keepdim) const -{ - return this->sum(Dims{dim}, keepdim); -} + GraphNodeHandle GraphNodeHandle::sum(Dims dims, bool keepdim) const + { + return WrapInReduction(*this, ReduceOpType::SUM, std::move(dims), keepdim); + } -GraphNodeHandle GraphNodeHandle::sum(Dims dims, bool keepdim) const -{ - return WrapInReduction(*this, ReduceOpType::SUM, std::move(dims), keepdim); -} + GraphNodeHandle GraphNodeHandle::max(bool keepdim) const + { + Dims dims(this->shape().size()); + std::iota(dims.begin(), dims.end(), 0); + return this->max(std::move(dims), keepdim); + } -GraphNodeHandle GraphNodeHandle::max(bool keepdim) const -{ - Dims dims(this->shape().size()); - std::iota(dims.begin(), dims.end(), 0); - return this->max(std::move(dims), keepdim); -} + GraphNodeHandle GraphNodeHandle::max(dim_t dim, bool keepdim) const + { + return this->max(Dims{dim}, keepdim); + } -GraphNodeHandle GraphNodeHandle::max(dim_t dim, bool keepdim) const -{ - return this->max(Dims{dim}, keepdim); -} + GraphNodeHandle GraphNodeHandle::max(Dims dims, bool keepdim) const + { + return WrapInReduction(*this, ReduceOpType::MAX, std::move(dims), keepdim); + } -GraphNodeHandle GraphNodeHandle::max(Dims dims, bool keepdim) const -{ - return WrapInReduction(*this, ReduceOpType::MAX, std::move(dims), keepdim); -} + GraphNodeHandle GraphNodeHandle::reshape(Shape new_shape) const + { + Shape input_shape = this->shape(); + auto num_elements = std::accumulate(input_shape.begin(), input_shape.end(), dim_t{1}, std::multiplies{}); + auto num_implicit_dims = std::count(new_shape.begin(), new_shape.end(), -1); + if (num_implicit_dims == 0) + { + auto new_num_elements = std::accumulate(new_shape.begin(), new_shape.end(), dim_t{1}, std::multiplies{}); + if (new_num_elements != num_elements) + throw std::domain_error("Reshape number of elements doesn't match that of input tensor"); + Shape strides = ComputeStrides(new_shape); + return this->as_strided(std::move(new_shape), std::move(strides), 0); + } + + if (num_implicit_dims > 1) + throw std::domain_error("Reshape can have at most one implicit dimension"); + + auto num_elems_not_including_implicit_dim = std::accumulate( + new_shape.begin(), + new_shape.end(), + dim_t{1}, + [](auto x, auto y) + { + if (y == -1) + return x; + return x * y; + }); + auto remaining_dim = num_elements / num_elems_not_including_implicit_dim; + for (auto &x : new_shape) + if (x == -1) + x = remaining_dim; -GraphNodeHandle GraphNodeHandle::reshape(Shape new_shape) const -{ - Shape input_shape = this->shape(); - auto num_elements = std::accumulate(input_shape.begin(), input_shape.end(), dim_t{1}, std::multiplies{}); - auto num_implicit_dims = std::count(new_shape.begin(), new_shape.end(), -1); - if(num_implicit_dims == 0) - { - auto new_num_elements = std::accumulate(new_shape.begin(), new_shape.end(), dim_t{1}, std::multiplies{}); - if(new_num_elements != num_elements) - throw std::domain_error("Reshape number of elements doesn't match that of input tensor"); Shape strides = ComputeStrides(new_shape); return this->as_strided(std::move(new_shape), std::move(strides), 0); } - if(num_implicit_dims > 1) - throw std::domain_error("Reshape can have at most one implicit dimension"); - - auto num_elems_not_including_implicit_dim = std::accumulate( - new_shape.begin(), - new_shape.end(), - dim_t{1}, - [](auto x, auto y) - { - if(y == -1) - return x; - return x * y; - }); - auto remaining_dim = num_elements / num_elems_not_including_implicit_dim; - for(auto &x : new_shape) - if(x == -1) - x = remaining_dim; - - Shape strides = ComputeStrides(new_shape); - return this->as_strided(std::move(new_shape), std::move(strides), 0); -} - -GraphNodeHandle GraphNodeHandle::reshape(dim_t length) const -{ - return this->reshape(Shape{length}); -} + GraphNodeHandle GraphNodeHandle::reshape(dim_t length) const + { + return this->reshape(Shape{length}); + } -GraphNodeHandle GraphNodeHandle::swapaxes(dim_t axis1, dim_t axis2) const -{ - Shape shape = this->shape(); - axis1 = FixDim(axis1, shape.size()); - axis2 = FixDim(axis2, shape.size()); - std::swap(shape[axis1], shape[axis2]); - Shape strides = ComputeStrides(shape); - return this->as_strided(std::move(shape), std::move(strides), 0); -} + GraphNodeHandle GraphNodeHandle::swapaxes(dim_t axis1, dim_t axis2) const + { + Shape shape = this->shape(); + axis1 = FixDim(axis1, shape.size()); + axis2 = FixDim(axis2, shape.size()); + std::swap(shape[axis1], shape[axis2]); + Shape strides = ComputeStrides(shape); + return this->as_strided(std::move(shape), std::move(strides), 0); + } -GraphNodeHandle GraphNodeHandle::permute(Dims dims) const -{ - Shape shape = this->shape(); - if(dims.size() != shape.size()) - throw std::domain_error("Permute not given proper number of dimensions"); - std::vector uniqueness(shape.size(), false); - Shape new_shape(shape.size()); - for(size_t i = 0; i < shape.size(); i++) - { - // If dim is negative, we need to fix it to be between 0 and shape.size() - auto dim = dims[i]; - auto fixed_dim = FixDim(dim, static_cast(shape.size())); - if(uniqueness[fixed_dim]) - throw std::domain_error("Found repeated dim in permute"); - uniqueness[fixed_dim] = true; - new_shape[fixed_dim] = shape[i]; - } - Shape strides = ComputeStrides(new_shape); - return this->as_strided(std::move(new_shape), std::move(strides), 0); -} + GraphNodeHandle GraphNodeHandle::permute(Dims dims) const + { + Shape shape = this->shape(); + if (dims.size() != shape.size()) + throw std::domain_error("Permute not given proper number of dimensions"); + std::vector uniqueness(shape.size(), false); + Shape new_shape(shape.size()); + for (size_t i = 0; i < shape.size(); i++) + { + // If dim is negative, we need to fix it to be between 0 and shape.size() + auto dim = dims[i]; + auto fixed_dim = FixDim(dim, static_cast(shape.size())); + if (uniqueness[fixed_dim]) + throw std::domain_error("Found repeated dim in permute"); + uniqueness[fixed_dim] = true; + new_shape[fixed_dim] = shape[i]; + } + Shape strides = ComputeStrides(new_shape); + return this->as_strided(std::move(new_shape), std::move(strides), 0); + } -GraphNodeHandle GraphNodeHandle::transpose() const -{ - Shape shape = this->shape(); - Dims dims(shape.size()); - std::iota(std::rbegin(dims), std::rend(dims), 0); - return this->permute(std::move(dims)); -} + GraphNodeHandle GraphNodeHandle::transpose() const + { + Shape shape = this->shape(); + Dims dims(shape.size()); + std::iota(std::rbegin(dims), std::rend(dims), 0); + return this->permute(std::move(dims)); + } -GraphNodeHandle GraphNodeHandle::as_strided(Shape shape, Shape strides, dim_t offset) const -{ - return graph->AddNode(ViewOp{*this, std::move(shape), std::move(strides), offset}); -} + GraphNodeHandle GraphNodeHandle::as_strided(Shape shape, Shape strides, dim_t offset) const + { + return graph->AddNode(ViewOp{*this, std::move(shape), std::move(strides), offset}); + } -GraphNodeHandle GraphNodeHandle::relu() const -{ - return gigagrad::relu(GraphNodeHandle{*this}); -} + GraphNodeHandle GraphNodeHandle::relu() const + { + return gigagrad::relu(GraphNodeHandle{*this}); + } -GraphNodeHandle GraphNodeHandle::softmax(dim_t axis) const -{ - return gigagrad::softmax(GraphNodeHandle{*this}, axis); -} + GraphNodeHandle GraphNodeHandle::softmax(dim_t axis) const + { + return gigagrad::softmax(GraphNodeHandle{*this}, axis); + } -GraphNodeHandle GraphNodeHandle::mean(dim_t axis, bool keepdim) const -{ - return gigagrad::mean(*this, axis, keepdim); -} + GraphNodeHandle GraphNodeHandle::mean(dim_t axis, bool keepdim) const + { + return gigagrad::mean(*this, axis, keepdim); + } -GraphNodeHandle GraphNodeHandle::variance(dim_t axis, bool keepdim) const -{ - return gigagrad::variance(*this, axis, keepdim); -} + GraphNodeHandle GraphNodeHandle::variance(dim_t axis, bool keepdim) const + { + return gigagrad::variance(*this, axis, keepdim); + } -GraphNodeHandle GraphNodeHandle::batchnorm() const -{ - return gigagrad::batchnorm(*this); -} + GraphNodeHandle GraphNodeHandle::batchnorm() const + { + return gigagrad::batchnorm(*this); + } -// Matmul is a little tricky. We abuse the broadcasting semantics as follows: -// If we have matrices X, Y of shape AxB and BxC, then we reshape X into a -// AxBx1 tensor, and reshape Y into a 1xBxC tensor. Broadcasting then turns this -// into a cube of multiplications, and then we reduce along the middle axis -// and cut out the middle axis (since it has dim 1 anyway) -GraphNodeHandle GraphNodeHandle::matmul(GraphNodeHandle y) const -{ - Shape x_shape = this->shape(); - Shape y_shape = y.shape(); - - // Special case for 1-D vectors by padding them up to 2D - if(x_shape.size() == 1) - x_shape.insert(x_shape.begin(), 1); - if(y_shape.size() == 1) - y_shape.push_back(1); - - if(x_shape.size() < 2 || y_shape.size() < 2) - throw std::domain_error("Shapes must be at least of size 2 for matmul"); - - x_shape.push_back(1); - y_shape.insert(y_shape.end() - 2, 1); - if(*(x_shape.end() - 2) != *(y_shape.end() - 2)) - throw std::domain_error("Incompatible shapes in matmul"); - - GraphNodeHandle x_reshaped = this->reshape(std::move(x_shape)); - GraphNodeHandle y_reshaped = y.reshape(std::move(y_shape)); - GraphNodeHandle elementwise_mul = x_reshaped * y_reshaped; - return elementwise_mul.sum(-2, false /* keepdim */); // Sum along the middle axis -} + // Matmul is a little tricky. We abuse the broadcasting semantics as follows: + // If we have matrices X, Y of shape AxB and BxC, then we reshape X into a + // AxBx1 tensor, and reshape Y into a 1xBxC tensor. Broadcasting then turns this + // into a cube of multiplications, and then we reduce along the middle axis + // and cut out the middle axis (since it has dim 1 anyway) + GraphNodeHandle GraphNodeHandle::matmul(GraphNodeHandle y) const + { + Shape x_shape = this->shape(); + Shape y_shape = y.shape(); + + // Special case for 1-D vectors by padding them up to 2D + if (x_shape.size() == 1) + x_shape.insert(x_shape.begin(), 1); + if (y_shape.size() == 1) + y_shape.push_back(1); + + if (x_shape.size() < 2 || y_shape.size() < 2) + throw std::domain_error("Shapes must be at least of size 2 for matmul"); + + x_shape.push_back(1); + y_shape.insert(y_shape.end() - 2, 1); + if (*(x_shape.end() - 2) != *(y_shape.end() - 2)) + throw std::domain_error("Incompatible shapes in matmul"); + + GraphNodeHandle x_reshaped = this->reshape(std::move(x_shape)); + GraphNodeHandle y_reshaped = y.reshape(std::move(y_shape)); + GraphNodeHandle elementwise_mul = x_reshaped * y_reshaped; + return elementwise_mul.sum(-2, false /* keepdim */); // Sum along the middle axis + } -GraphNodeHandle sqrt(GraphNodeHandle x) -{ - return WrapInUnary(x, UnaryOpType::SQRT); -} + GraphNodeHandle sqrt(GraphNodeHandle x) + { + return WrapInUnary(x, UnaryOpType::SQRT); + } -GraphNodeHandle exp(GraphNodeHandle x) -{ - return WrapInUnary(x, UnaryOpType::EXP); -} + GraphNodeHandle exp(GraphNodeHandle x) + { + return WrapInUnary(x, UnaryOpType::EXP); + } -GraphNodeHandle log(GraphNodeHandle x) -{ - return WrapInUnary(x, UnaryOpType::LOG); -} + GraphNodeHandle log(GraphNodeHandle x) + { + return WrapInUnary(x, UnaryOpType::LOG); + } -GraphNodeHandle sin(GraphNodeHandle x) -{ - return WrapInUnary(x, UnaryOpType::SIN); -} + GraphNodeHandle sin(GraphNodeHandle x) + { + return WrapInUnary(x, UnaryOpType::SIN); + } -GraphNodeHandle cos(GraphNodeHandle x) -{ - return WrapInUnary((x + 3.14159265f/2.0f), UnaryOpType::SIN); -} + GraphNodeHandle cos(GraphNodeHandle x) + { + return WrapInUnary((x + 3.14159265f / 2.0f), UnaryOpType::SIN); + } -GraphNodeHandle sigmoid(GraphNodeHandle x) -{ - GraphNodeHandle expx = exp(x); - return expx / (1 + expx); -} + GraphNodeHandle sigmoid(GraphNodeHandle x) + { + GraphNodeHandle expx = exp(x); + return expx / (1 + expx); + } -GraphNodeHandle operator-(GraphNodeHandle x) -{ - return 0 - x; -} + GraphNodeHandle operator-(GraphNodeHandle x) + { + return 0 - x; + } -GraphNodeHandle operator+(GraphNodeHandle x, GraphNodeHandle y) -{ - Graph *graph = x.graph; - return graph->AddNode(BinaryOp{BinaryOpType::ADD, x, y}); -} + GraphNodeHandle operator+(GraphNodeHandle x, GraphNodeHandle y) + { + Graph *graph = x.graph; + return graph->AddNode(BinaryOp{BinaryOpType::ADD, x, y}); + } -GraphNodeHandle operator+(float x, GraphNodeHandle y) -{ - Graph *graph = y.graph; - GraphNodeHandle xnode = graph->AddNode(Immediate{x}); - return xnode + y; -} + GraphNodeHandle operator+(double x, GraphNodeHandle y) + { + Graph *graph = y.graph; + GraphNodeHandle xnode = graph->AddNode(Immediate{x}); + return xnode + y; + } -GraphNodeHandle operator+(GraphNodeHandle x, float y) -{ - return y + x; -} + GraphNodeHandle operator+(GraphNodeHandle x, double y) + { + return y + x; + } -GraphNodeHandle operator-(GraphNodeHandle x, GraphNodeHandle y) -{ - Graph *graph = x.graph; - return graph->AddNode(BinaryOp{BinaryOpType::SUB, x, y}); -} + GraphNodeHandle operator-(GraphNodeHandle x, GraphNodeHandle y) + { + Graph *graph = x.graph; + return graph->AddNode(BinaryOp{BinaryOpType::SUB, x, y}); + } -GraphNodeHandle operator-(float x, GraphNodeHandle y) -{ - return (-x) + y; -} + GraphNodeHandle operator-(double x, GraphNodeHandle y) + { + return (-x) + y; + } -GraphNodeHandle operator-(GraphNodeHandle x, float y) -{ - return x + (-y); -} + GraphNodeHandle operator-(GraphNodeHandle x, double y) + { + return x + (-y); + } -GraphNodeHandle operator*(GraphNodeHandle x, GraphNodeHandle y) -{ - Graph *graph = x.graph; - return graph->AddNode(BinaryOp{BinaryOpType::MUL, x, y}); -} + GraphNodeHandle operator*(GraphNodeHandle x, GraphNodeHandle y) + { + Graph *graph = x.graph; + return graph->AddNode(BinaryOp{BinaryOpType::MUL, x, y}); + } -GraphNodeHandle operator*(float x, GraphNodeHandle y) -{ - Graph *graph = y.graph; - GraphNodeHandle xnode = graph->AddNode(Immediate{x}); - return xnode * y; -} + GraphNodeHandle operator*(double x, GraphNodeHandle y) + { + Graph *graph = y.graph; + GraphNodeHandle xnode = graph->AddNode(Immediate{x}); + return xnode * y; + } -GraphNodeHandle operator*(GraphNodeHandle x, float y) -{ - return y * x; -} + GraphNodeHandle operator*(GraphNodeHandle x, double y) + { + return y * x; + } -GraphNodeHandle operator/(GraphNodeHandle x, GraphNodeHandle y) -{ - Graph *graph = x.graph; - return graph->AddNode(BinaryOp{BinaryOpType::DIV, x, y}); -} + GraphNodeHandle operator/(GraphNodeHandle x, GraphNodeHandle y) + { + Graph *graph = x.graph; + return graph->AddNode(BinaryOp{BinaryOpType::DIV, x, y}); + } -GraphNodeHandle operator/(float x, GraphNodeHandle y) -{ - Graph *graph = y.graph; - GraphNodeHandle xnode = graph->AddNode(Immediate{x}); - return xnode / y; -} + GraphNodeHandle operator/(double x, GraphNodeHandle y) + { + Graph *graph = y.graph; + GraphNodeHandle xnode = graph->AddNode(Immediate{x}); + return xnode / y; + } -GraphNodeHandle operator/(GraphNodeHandle x, float y) -{ - Graph *graph = x.graph; - GraphNodeHandle ynode = graph->AddNode(Immediate{y}); - return x / ynode; -} + GraphNodeHandle operator/(GraphNodeHandle x, double y) + { + Graph *graph = x.graph; + GraphNodeHandle ynode = graph->AddNode(Immediate{y}); + return x / ynode; + } -GraphNodeHandle operator^(GraphNodeHandle x, float y) -{ - Graph *graph = x.graph; - GraphNodeHandle ynode = graph->AddNode(Immediate{y}); - return graph->AddNode(BinaryOp{BinaryOpType::POW, x, ynode}); -} + GraphNodeHandle operator^(GraphNodeHandle x, double y) + { + Graph *graph = x.graph; + GraphNodeHandle ynode = graph->AddNode(Immediate{y}); + return graph->AddNode(BinaryOp{BinaryOpType::POW, x, ynode}); + } -GraphNodeHandle operator==(GraphNodeHandle x, GraphNodeHandle y) -{ - Graph *graph = x.graph; - return graph->AddNode(BinaryOp{BinaryOpType::CMP, x, y}); -} + GraphNodeHandle operator==(GraphNodeHandle x, GraphNodeHandle y) + { + Graph *graph = x.graph; + return graph->AddNode(BinaryOp{BinaryOpType::CMP, x, y}); + } -GraphNodeHandle operator==(float x, GraphNodeHandle y) -{ - Graph *graph = y.graph; - GraphNodeHandle xnode = graph->AddNode(Immediate{x}); - return xnode == y; -} + GraphNodeHandle operator==(double x, GraphNodeHandle y) + { + Graph *graph = y.graph; + GraphNodeHandle xnode = graph->AddNode(Immediate{x}); + return xnode == y; + } -GraphNodeHandle operator==(const GraphNodeHandle x, float y) -{ - return y == x; -} + GraphNodeHandle operator==(const GraphNodeHandle x, double y) + { + return y == x; + } -GraphNodeHandle operator<(const GraphNodeHandle x, float y) -{ - return y > x; -} + GraphNodeHandle operator<(const GraphNodeHandle x, double y) + { + return y > x; + } -GraphNodeHandle operator<(float x, const GraphNodeHandle y) -{ - return y > x; -} + GraphNodeHandle operator<(double x, const GraphNodeHandle y) + { + return y > x; + } -GraphNodeHandle operator<(GraphNodeHandle x, GraphNodeHandle y) -{ - return y > x; -} + GraphNodeHandle operator<(GraphNodeHandle x, GraphNodeHandle y) + { + return y > x; + } -GraphNodeHandle operator<=(GraphNodeHandle x, float y) -{ - return max(x - y, 0.0f) == 0.0f; -} + GraphNodeHandle operator<=(GraphNodeHandle x, double y) + { + return max(x - y, 0.0f) == 0.0f; + } -GraphNodeHandle operator<=(float x, const GraphNodeHandle y) -{ - return max(x - y, 0.0f) == 0.0f; -} + GraphNodeHandle operator<=(double x, const GraphNodeHandle y) + { + return max(x - y, 0.0f) == 0.0f; + } -GraphNodeHandle operator<=(GraphNodeHandle x, GraphNodeHandle y) -{ - return max(x - y, 0.0f) == 0.0f; -} + GraphNodeHandle operator<=(GraphNodeHandle x, GraphNodeHandle y) + { + return max(x - y, 0.0f) == 0.0f; + } -GraphNodeHandle operator>(GraphNodeHandle x, float y) -{ - return max(x, y) == x; -} + GraphNodeHandle operator>(GraphNodeHandle x, double y) + { + return max(x, y) == x; + } -GraphNodeHandle operator>(float x, GraphNodeHandle y) -{ - return max(x, y) == x; -} + GraphNodeHandle operator>(double x, GraphNodeHandle y) + { + return max(x, y) == x; + } -GraphNodeHandle operator>(GraphNodeHandle x, GraphNodeHandle y) -{ - return max(x, y) == x; -} + GraphNodeHandle operator>(GraphNodeHandle x, GraphNodeHandle y) + { + return max(x, y) == x; + } -GraphNodeHandle operator>=(GraphNodeHandle x, float y) -{ - return min(x - y, 0.0f) == 0.0f; -} + GraphNodeHandle operator>=(GraphNodeHandle x, double y) + { + return min(x - y, 0.0f) == 0.0f; + } -GraphNodeHandle operator>=(float x, GraphNodeHandle y) -{ - return min(x - y, 0.0f) == 0.0f; -} + GraphNodeHandle operator>=(double x, GraphNodeHandle y) + { + return min(x - y, 0.0f) == 0.0f; + } -GraphNodeHandle operator>=(GraphNodeHandle x, GraphNodeHandle y) -{ - return min(x - y, 0.0f) == 0.0f; -} + GraphNodeHandle operator>=(GraphNodeHandle x, GraphNodeHandle y) + { + return min(x - y, 0.0f) == 0.0f; + } -GraphNodeHandle max(GraphNodeHandle x, GraphNodeHandle y) -{ - Graph *graph = x.graph; - return graph->AddNode(BinaryOp{BinaryOpType::MAX, x, y}); -} + GraphNodeHandle max(GraphNodeHandle x, GraphNodeHandle y) + { + Graph *graph = x.graph; + return graph->AddNode(BinaryOp{BinaryOpType::MAX, x, y}); + } -GraphNodeHandle max(float x, GraphNodeHandle y) -{ - Graph *graph = y.graph; - GraphNodeHandle xnode = graph->AddNode(Immediate{x}); - return max(xnode, y); -} + GraphNodeHandle max(double x, GraphNodeHandle y) + { + Graph *graph = y.graph; + GraphNodeHandle xnode = graph->AddNode(Immediate{x}); + return max(xnode, y); + } -GraphNodeHandle max(GraphNodeHandle x, float y) -{ - return max(y, x); -} + GraphNodeHandle max(GraphNodeHandle x, double y) + { + return max(y, x); + } -GraphNodeHandle sum(GraphNodeHandle x, bool keepdim) -{ - return x.sum(keepdim); -} + GraphNodeHandle sum(GraphNodeHandle x, bool keepdim) + { + return x.sum(keepdim); + } -GraphNodeHandle min(GraphNodeHandle x, GraphNodeHandle y) -{ - return -max(-x, -y); -} + GraphNodeHandle min(GraphNodeHandle x, GraphNodeHandle y) + { + return -max(-x, -y); + } -GraphNodeHandle min(float x, GraphNodeHandle y) -{ - return -max(-x, -y); -} + GraphNodeHandle min(double x, GraphNodeHandle y) + { + return -max(-x, -y); + } -GraphNodeHandle min(GraphNodeHandle x, float y) -{ - return -max(-x, -y); -} + GraphNodeHandle min(GraphNodeHandle x, double y) + { + return -max(-x, -y); + } -GraphNodeHandle pow(GraphNodeHandle x, float y) -{ - Graph *graph = x.graph; - GraphNodeHandle ynode = graph->AddNode(Immediate{y}); - return graph->AddNode(BinaryOp{BinaryOpType::POW, x, ynode}); -} + GraphNodeHandle pow(GraphNodeHandle x, double y) + { + Graph *graph = x.graph; + GraphNodeHandle ynode = graph->AddNode(Immediate{y}); + return graph->AddNode(BinaryOp{BinaryOpType::POW, x, ynode}); + } -GraphNodeHandle pow(float x, GraphNodeHandle y) -{ - Graph *graph = y.graph; - GraphNodeHandle xnode = graph->AddNode(Immediate{x}); - return graph->AddNode(BinaryOp{BinaryOpType::POW, xnode, y}); -} + GraphNodeHandle pow(double x, GraphNodeHandle y) + { + Graph *graph = y.graph; + GraphNodeHandle xnode = graph->AddNode(Immediate{x}); + // NumericDataType numeric_data_type = xnode->numeric_data_type + return graph->AddNode(BinaryOp{BinaryOpType::POW, xnode, y}); + } -GraphNodeHandle pow(GraphNodeHandle x, GraphNodeHandle y) -{ - Graph *graph = x.graph; - return graph->AddNode(BinaryOp{BinaryOpType::POW, x, y}); -} + GraphNodeHandle pow(GraphNodeHandle x, GraphNodeHandle y) + { + Graph *graph = x.graph; + return graph->AddNode(BinaryOp{BinaryOpType::POW, x, y}); + } -GraphNodeHandle relu(GraphNodeHandle x) -{ - return max(x, 0.0f); -} + GraphNodeHandle relu(GraphNodeHandle x) + { + return max(x, 0.0f); + } -GraphNodeHandle softmax(GraphNodeHandle x, dim_t axis) -{ - GraphNodeHandle m = x.max(axis, true); - GraphNodeHandle exp_shifted = exp(x - m); - GraphNodeHandle sum_exp_shifted = exp_shifted.sum(axis, true); - return exp_shifted / sum_exp_shifted; -} + GraphNodeHandle softmax(GraphNodeHandle x, dim_t axis) + { + GraphNodeHandle m = x.max(axis, true); + GraphNodeHandle exp_shifted = exp(x - m); + GraphNodeHandle sum_exp_shifted = exp_shifted.sum(axis, true); + return exp_shifted / sum_exp_shifted; + } -GraphNodeHandle sum(GraphNodeHandle x, dim_t axis, bool keepdim) -{ - return x.sum(axis, keepdim); -} + GraphNodeHandle sum(GraphNodeHandle x, dim_t axis, bool keepdim) + { + return x.sum(axis, keepdim); + } -GraphNodeHandle sum(GraphNodeHandle x, Dims dims, bool keepdim) -{ - return x.sum(std::move(dims), keepdim); -} + GraphNodeHandle sum(GraphNodeHandle x, Dims dims, bool keepdim) + { + return x.sum(std::move(dims), keepdim); + } -GraphNodeHandle max(GraphNodeHandle x, bool keepdim) -{ - return x.max(keepdim); -} + GraphNodeHandle max(GraphNodeHandle x, bool keepdim) + { + return x.max(keepdim); + } -GraphNodeHandle max(GraphNodeHandle x, dim_t axis, bool keepdim) -{ - return x.max(axis, keepdim); -} + GraphNodeHandle max(GraphNodeHandle x, dim_t axis, bool keepdim) + { + return x.max(axis, keepdim); + } -GraphNodeHandle max(GraphNodeHandle x, Dims dims, bool keepdim) -{ - return x.max(std::move(dims), keepdim); -} + GraphNodeHandle max(GraphNodeHandle x, Dims dims, bool keepdim) + { + return x.max(std::move(dims), keepdim); + } -GraphNodeHandle min(GraphNodeHandle x, bool keepdim) -{ - return -max(-x, keepdim); -} + GraphNodeHandle min(GraphNodeHandle x, bool keepdim) + { + return -max(-x, keepdim); + } -GraphNodeHandle min(GraphNodeHandle x, dim_t axis, bool keepdim) -{ - return -max(-x, axis, keepdim); -} + GraphNodeHandle min(GraphNodeHandle x, dim_t axis, bool keepdim) + { + return -max(-x, axis, keepdim); + } -GraphNodeHandle min(GraphNodeHandle x, Dims dims, bool keepdim) -{ - return -max(-x, std::move(dims), keepdim); -} + GraphNodeHandle min(GraphNodeHandle x, Dims dims, bool keepdim) + { + return -max(-x, std::move(dims), keepdim); + } -GraphNodeHandle mean(GraphNodeHandle x, dim_t axis, bool keepdim) -{ - axis = FixDim(axis, x.shape().size()); - float denom = x.shape()[axis]; - GraphNodeHandle div = x / denom; - GraphNodeHandle result = div.sum(axis, keepdim); - return result; -} + GraphNodeHandle mean(GraphNodeHandle x, dim_t axis, bool keepdim) + { + axis = FixDim(axis, x.shape().size()); + float denom = x.shape()[axis]; + GraphNodeHandle div = x / denom; + GraphNodeHandle result = div.sum(axis, keepdim); + return result; + } -GraphNodeHandle variance(GraphNodeHandle x, dim_t axis, bool keepdim) -{ - GraphNodeHandle mean = x.mean(axis, true); - GraphNodeHandle errors = x - mean; - GraphNodeHandle square_errors = errors * errors; - GraphNodeHandle sum_square_errors = square_errors.sum(axis, keepdim); - return sum_square_errors; -} + GraphNodeHandle variance(GraphNodeHandle x, dim_t axis, bool keepdim) + { + GraphNodeHandle mean = x.mean(axis, true); + GraphNodeHandle errors = x - mean; + GraphNodeHandle square_errors = errors * errors; + GraphNodeHandle sum_square_errors = square_errors.sum(axis, keepdim); + return sum_square_errors; + } -GraphNodeHandle batchnorm(GraphNodeHandle x) -{ - constexpr float epsilon = 0.001; - GraphNodeHandle mean = x.mean(0, true); - GraphNodeHandle errors = x - mean; - GraphNodeHandle square_errors = errors * errors; - GraphNodeHandle sum_square_errors = square_errors.sum(0, true) + epsilon; - GraphNodeHandle result = errors / sqrt(sum_square_errors); - result->needs_gradient = false; - return result; -} + GraphNodeHandle batchnorm(GraphNodeHandle x) + { + constexpr float epsilon = 0.001; + GraphNodeHandle mean = x.mean(0, true); + GraphNodeHandle errors = x - mean; + GraphNodeHandle square_errors = errors * errors; + GraphNodeHandle sum_square_errors = square_errors.sum(0, true) + epsilon; + GraphNodeHandle result = errors / sqrt(sum_square_errors); + result->needs_gradient = false; + return result; + } -GraphNodeHandle reshape(GraphNodeHandle x, Shape shape) -{ - return x.reshape(std::move(shape)); -} + GraphNodeHandle reshape(GraphNodeHandle x, Shape shape) + { + return x.reshape(std::move(shape)); + } -GraphNodeHandle reshape(GraphNodeHandle x, dim_t length) -{ - return x.reshape(length); -} + GraphNodeHandle reshape(GraphNodeHandle x, dim_t length) + { + return x.reshape(length); + } -GraphNodeHandle permute(GraphNodeHandle x, Dims permutation) -{ - return x.permute(std::move(permutation)); -} + GraphNodeHandle permute(GraphNodeHandle x, Dims permutation) + { + return x.permute(std::move(permutation)); + } -GraphNodeHandle transpose(GraphNodeHandle x) -{ - return x.transpose(); -} + GraphNodeHandle transpose(GraphNodeHandle x) + { + return x.transpose(); + } -GraphNodeHandle operator%(GraphNodeHandle x, GraphNodeHandle y) -{ - return x.matmul(y); -} + GraphNodeHandle operator%(GraphNodeHandle x, GraphNodeHandle y) + { + return x.matmul(y); + } -GraphNodeHandle matmul(GraphNodeHandle x, GraphNodeHandle y) -{ - return x.matmul(y); -} + GraphNodeHandle matmul(GraphNodeHandle x, GraphNodeHandle y) + { + return x.matmul(y); + } -GraphNodeHandle Graph::Immediate(float imm) -{ - return this->AddNode(gigagrad::Immediate{imm}); -} + GraphNodeHandle Graph::Immediate(double imm) + { + return this->AddNode(gigagrad::Immediate{imm}); + } -GraphNodeHandle Graph::AddInput(Shape shape) -{ - this->inputs.push_back(this->nodes.size()); - GraphNodeHandle result = this->AddNode(Tensor{}, std::move(shape)); - return result; -} + GraphNodeHandle Graph::AddInput(Shape shape, NumericDataType dtype) + { + this->inputs.push_back(this->nodes.size()); + GraphNodeHandle result = this->AddNode(Tensor{}, std::move(shape), std::move(dtype)); + return result; + } -GraphNodeHandle Graph::AddInput(dim_t dim) -{ - return this->AddInput(Shape{dim}); -} + GraphNodeHandle Graph::AddInput(dim_t dim, NumericDataType dtype) + { + return this->AddInput(Shape{dim}, NumericDataType::FLOAT32); + } -GraphNodeHandle Graph::AddNode(Tensor tensor, Shape shape) -{ - Shape strides = ComputeStrides(shape); - return this->AddNode( - GraphNode - { - .u = { std::move(tensor) }, - .shape = std::move(shape), - .strides = std::move(strides), - }); -} + GraphNodeHandle Graph::AddNode(Tensor tensor, Shape shape, NumericDataType dtype) + { + Shape strides = ComputeStrides(shape); + return this->AddNode( + GraphNode{ + .u = {std::move(tensor)}, + .shape = std::move(shape), + .strides = std::move(strides), + .dtype = std::move(dtype), + }); + } -GraphNodeHandle Graph::AddNode(struct Immediate imm) -{ - return this->AddNode( - GraphNode - { - .u = { std::move(imm) }, - .shape = {}, - .strides = {}, - }); -} + GraphNodeHandle Graph::AddNode(struct Immediate imm) + { + return this->AddNode( + GraphNode{ + .u = {std::move(imm)}, + .shape = {}, + .strides = {}, + .dtype = {}, + }); + } -GraphNodeHandle Graph::AddNode(UnaryOp op) -{ - return this->AddNode( - GraphNode - { - .u = { std::move(op) }, - .shape = op.x.shape(), - .strides = op.x.strides(), - }); -} + GraphNodeHandle Graph::AddNode(UnaryOp op) + { + return this->AddNode( + GraphNode{ + .u = {std::move(op)}, + .shape = op.x.shape(), + .strides = op.x.strides(), + .dtype = op.x.dtype(), + }); + } -GraphNodeHandle Graph::AddNode(BinaryOp op) -{ - Shape shape = ComputeBroadcastedShape(op.x.shape(), op.y.shape()); - Shape strides = ComputeStrides(shape); - return this->AddNode( - GraphNode - { - .u = { std::move(op) }, - .shape = std::move(shape), - .strides = std::move(strides), - }); -} + GraphNodeHandle Graph::AddNode(BinaryOp op) + { + Shape shape = ComputeBroadcastedShape(op.x.shape(), op.y.shape()); + Shape strides = ComputeStrides(shape); + NumericDataType dtype = CompareDtypes(op.x.dtype(), op.y.dtype()); + return this->AddNode( + GraphNode{ + .u = {std::move(op)}, + .shape = std::move(shape), + .strides = std::move(strides), + .dtype = std::move(dtype)}); + } -GraphNodeHandle Graph::AddNode(ReduceOp op) -{ - Shape shape = ComputeReducedShape(op); - Shape strides = ComputeStrides(shape); - return this->AddNode( - GraphNode - { - .u = { std::move(op) }, - .shape = std::move(shape), - .strides = std::move(strides), - }); -} + GraphNodeHandle Graph::AddNode(ReduceOp op) + { + Shape shape = ComputeReducedShape(op); + Shape strides = ComputeStrides(shape); + return this->AddNode( + GraphNode{ + .u = {std::move(op)}, + .shape = std::move(shape), + .strides = std::move(strides), + }); + } -GraphNodeHandle Graph::AddNode(ViewOp op) -{ - if(op.shape.empty()) - throw std::logic_error("ViewOp has empty shape"); + GraphNodeHandle Graph::AddNode(ViewOp op) + { + if (op.shape.empty()) + throw std::logic_error("ViewOp has empty shape"); + + Shape shape = op.shape; + Shape strides = ComputeStrides(op.shape); + return this->AddNode( + GraphNode{ + .u = {std::move(op)}, + .shape = std::move(shape), + .strides = std::move(strides), + }); + } - Shape shape = op.shape; - Shape strides = ComputeStrides(op.shape); - return this->AddNode( - GraphNode - { - .u = { std::move(op) }, - .shape = std::move(shape), - .strides = std::move(strides), - }); -} + GraphNodeHandle Graph::AddNode(GraphNode node) + { + GraphNodeHandle result = {this, this->nodes.size()}; + this->nodes.emplace_back(std::move(node)); + return result; + } -GraphNodeHandle Graph::AddNode(GraphNode node) -{ - GraphNodeHandle result = { this, this->nodes.size() }; - this->nodes.emplace_back(std::move(node)); - return result; -} + const Shape &GraphNodeHandle::shape() const + { + const GraphNode &node = this->GetNode(); + return node.shape; + } -const Shape &GraphNodeHandle::shape() const -{ - const GraphNode &node = this->GetNode(); - return node.shape; -} + const Shape &GraphNodeHandle::strides() const + { + const GraphNode &node = this->GetNode(); + return node.strides; + } -const Shape &GraphNodeHandle::strides() const -{ - const GraphNode &node = this->GetNode(); - return node.strides; -} + const NumericDataType &GraphNodeHandle::dtype() const + { + const GraphNode &node = this->GetNode(); + return node.dtype; + } -GraphNode &GraphNodeHandle::GetNode() -{ - return graph->nodes[node_idx]; -} + GraphNode &GraphNodeHandle::GetNode() + { + return graph->nodes[node_idx]; + } -const GraphNode &GraphNodeHandle::GetNode() const -{ - return graph->nodes[node_idx]; -} + const GraphNode &GraphNodeHandle::GetNode() const + { + return graph->nodes[node_idx]; + } -float *&GraphNodeHandle::data() -{ - GraphNode &node = this->GetNode(); - if(node.u.k.kind != GraphNode::Kind::Tensor) - throw std::logic_error("Cannot call data() on non-Tensor node"); - return GetNode().u.t.tensor.data; -} + double *&GraphNodeHandle::data() + { + GraphNode &node = this->GetNode(); + if (node.u.k.kind != GraphNode::Kind::Tensor) + throw std::logic_error("Cannot call data() on non-Tensor node"); + return GetNode().u.t.tensor.data; + } -GraphNodeHandle nn::Module::Immediate(float imm) -{ - return this->graph.Immediate(imm); -} + GraphNodeHandle nn::Module::Immediate(double imm) + { + return this->graph.Immediate(imm); + } -GraphNodeHandle nn::Module::AddInput(Shape shape) -{ - return this->graph.AddInput(std::move(shape)); -} + GraphNodeHandle nn::Module::AddInput(Shape shape, NumericDataType dtype) + { + return this->graph.AddInput(std::move(shape), std::move(dtype)); + } -GraphNodeHandle nn::Module::AddInput(dim_t dim) -{ - return this->graph.AddInput(dim); -} + GraphNodeHandle nn::Module::AddInput(dim_t dim, NumericDataType dtype) + { + return this->graph.AddInput(dim, std::move(dtype)); + } -GraphNodeHandle nn::Module::AddWeight(Shape shape) -{ - this->weights.push_back(this->graph.inputs.size()); - return this->graph.AddInput(std::move(shape)); -} + GraphNodeHandle nn::Module::AddWeight(Shape shape, NumericDataType dtype) + { + this->weights.push_back(this->graph.inputs.size()); + return this->graph.AddInput(std::move(shape), std::move(dtype)); + } -GraphNodeHandle nn::Module::AddWeight(dim_t dim) -{ - this->weights.push_back(this->graph.inputs.size()); - return this->graph.AddInput(dim); -} + GraphNodeHandle nn::Module::AddWeight(dim_t dim, NumericDataType dtype) + { + this->weights.push_back(this->graph.inputs.size()); + return this->graph.AddInput(dim, std::move(dtype)); + } -GraphNode::U::U(const U &that) : k({ that.k.kind }) -{ - switch(this->k.kind) - { - case Kind::Tensor: - new (&this->t.tensor) Tensor(that.t.tensor); - break; - case Kind::Immediate: - new (&this->i.immediate) Immediate(that.i.immediate); - break; - case Kind::UnaryOp: - new (&this->u.unary_op) UnaryOp(that.u.unary_op); - break; - case Kind::BinaryOp: - new (&this->b.binary_op) BinaryOp(that.b.binary_op); - break; - case Kind::ReduceOp: - new (&this->r.reduce_op) ReduceOp(that.r.reduce_op); - break; - case Kind::ViewOp: - new (&this->v.view_op) ViewOp(that.v.view_op); - break; - default: - throw std::logic_error("Invalid node type!"); + GraphNode::U::U(const U &that) : k({that.k.kind}) + { + switch (this->k.kind) + { + case Kind::Tensor: + new (&this->t.tensor) Tensor(that.t.tensor); + break; + case Kind::Immediate: + new (&this->i.immediate) Immediate(that.i.immediate); + break; + case Kind::UnaryOp: + new (&this->u.unary_op) UnaryOp(that.u.unary_op); + break; + case Kind::BinaryOp: + new (&this->b.binary_op) BinaryOp(that.b.binary_op); + break; + case Kind::ReduceOp: + new (&this->r.reduce_op) ReduceOp(that.r.reduce_op); + break; + case Kind::ViewOp: + new (&this->v.view_op) ViewOp(that.v.view_op); + break; + default: + throw std::logic_error("Invalid node type!"); + } } -} -GraphNode::U::U(U &&that) : k({ that.k.kind }) -{ - switch(this->k.kind) - { - case Kind::Tensor: - new (&this->t.tensor) Tensor(std::move(that.t.tensor)); - break; - case Kind::Immediate: - new (&this->i.immediate) Immediate(std::move(that.i.immediate)); - break; - case Kind::UnaryOp: - new (&this->u.unary_op) UnaryOp(std::move(that.u.unary_op)); - break; - case Kind::BinaryOp: - new (&this->b.binary_op) BinaryOp(std::move(that.b.binary_op)); - break; - case Kind::ReduceOp: - new (&this->r.reduce_op) ReduceOp(std::move(that.r.reduce_op)); - break; - case Kind::ViewOp: - new (&this->v.view_op) ViewOp(std::move(that.v.view_op)); - break; - default: - throw std::logic_error("Invalid node type!"); + GraphNode::U::U(U &&that) : k({that.k.kind}) + { + switch (this->k.kind) + { + case Kind::Tensor: + new (&this->t.tensor) Tensor(std::move(that.t.tensor)); + break; + case Kind::Immediate: + new (&this->i.immediate) Immediate(std::move(that.i.immediate)); + break; + case Kind::UnaryOp: + new (&this->u.unary_op) UnaryOp(std::move(that.u.unary_op)); + break; + case Kind::BinaryOp: + new (&this->b.binary_op) BinaryOp(std::move(that.b.binary_op)); + break; + case Kind::ReduceOp: + new (&this->r.reduce_op) ReduceOp(std::move(that.r.reduce_op)); + break; + case Kind::ViewOp: + new (&this->v.view_op) ViewOp(std::move(that.v.view_op)); + break; + default: + throw std::logic_error("Invalid node type!"); + } } -} -GraphNode::U &GraphNode::U::operator=(const U &that) -{ - this->~U(); - new (this) U(that); - return *this; -} + GraphNode::U &GraphNode::U::operator=(const U &that) + { + this->~U(); + new (this) U(that); + return *this; + } -GraphNode::U &GraphNode::U::operator=(U &&that) -{ - this->~U(); - new (this) U(std::move(that)); - return *this; -} + GraphNode::U &GraphNode::U::operator=(U &&that) + { + this->~U(); + new (this) U(std::move(that)); + return *this; + } -GraphNode::U::~U() -{ - switch(this->k.kind) - { - case Kind::Tensor: - this->t.tensor.~Tensor(); - break; - case Kind::Immediate: - this->i.immediate.~Immediate(); - break; - case Kind::UnaryOp: - this->u.unary_op.~UnaryOp(); - break; - case Kind::BinaryOp: - this->b.binary_op.~BinaryOp(); - break; - case Kind::ReduceOp: - this->r.reduce_op.~ReduceOp(); - break; - case Kind::ViewOp: - this->v.view_op.~ViewOp(); - break; - default: - break; + GraphNode::U::~U() + { + switch (this->k.kind) + { + case Kind::Tensor: + this->t.tensor.~Tensor(); + break; + case Kind::Immediate: + this->i.immediate.~Immediate(); + break; + case Kind::UnaryOp: + this->u.unary_op.~UnaryOp(); + break; + case Kind::BinaryOp: + this->b.binary_op.~BinaryOp(); + break; + case Kind::ReduceOp: + this->r.reduce_op.~ReduceOp(); + break; + case Kind::ViewOp: + this->v.view_op.~ViewOp(); + break; + default: + break; + } } -} } diff --git a/src/graph.h b/src/graph.h index a25f3a8..ac4588b 100644 --- a/src/graph.h +++ b/src/graph.h @@ -13,309 +13,352 @@ namespace gigagrad { -struct Tensor; -struct Immediate; -struct UnaryOp; -struct BinaryOp; -struct ReduceOp; - -struct Graph; -struct GraphNode; -using dim_t = ssize_t; -using Shape = std::vector; -using Dims = std::vector; - -struct CompiledTensor -{ - float *data; - Shape shape; - std::unique_ptr backend; + struct Tensor; + struct Immediate; + struct UnaryOp; + struct BinaryOp; + struct ReduceOp; + + struct Graph; + struct GraphNode; + using dim_t = ssize_t; + using Shape = std::vector; + using Dims = std::vector; + + enum class NumericDataType + { + FLOAT64, + FLOAT32, + FLOAT16, + INT8, + INT16, + INT32, + INT64, + }; - void Execute() { backend->Execute(); } -}; + std::string getCDatatype(NumericDataType dtype); -struct GraphNodeHandle -{ - Graph *graph; - size_t node_idx; - - GraphNodeHandle sum(bool keepdim = false) const; - GraphNodeHandle sum(dim_t axis, bool keepdim = false) const; - GraphNodeHandle sum(Dims dims, bool keepdim = false) const; - GraphNodeHandle max(bool keepdim = false) const; - GraphNodeHandle max(dim_t axis, bool keepdim = false) const; - GraphNodeHandle max(Dims dims, bool keepdim = false) const; - - GraphNodeHandle reshape(Shape shape) const; - GraphNodeHandle reshape(dim_t length) const; - GraphNodeHandle permute(Dims dims) const; - GraphNodeHandle swapaxes(dim_t axis1, dim_t axis2) const; - GraphNodeHandle transpose() const; - GraphNodeHandle as_strided(Shape shape, Shape strides, dim_t offset) const; - - GraphNodeHandle relu() const; - GraphNodeHandle softmax(dim_t axis = -1) const; - GraphNodeHandle mean(dim_t axis = 0, bool keepdim = false) const; - GraphNodeHandle variance(dim_t axis = 0, bool keepdim = false) const; - GraphNodeHandle batchnorm() const; - - GraphNodeHandle matmul(GraphNodeHandle y) const; - - const Shape &shape() const; // Empty shape means scalar - const Shape &strides() const; - - GraphNode &GetNode(); - const GraphNode &GetNode() const; - GraphNode &operator*() { return this->GetNode(); } - const GraphNode &operator*() const { return this->GetNode(); } - GraphNode *operator->() { return &this->GetNode(); } - const GraphNode *operator->() const { return &this->GetNode(); } - float *&data(); - - CompiledTensor Compile(std::unique_ptr backend) const; - template - CompiledTensor Compile() const { return Compile(std::make_unique()); } -}; - -enum class UnaryOpType -{ - NOP, - EXP, - LOG, - CAST, - SIN, - SQRT, -}; - -enum class BinaryOpType -{ - ADD, - SUB, - MUL, - DIV, - POW, - CMP, - MAX, -}; - -enum class ReduceOpType -{ - SUM, - MAX, -}; + struct CompiledTensor + { + double *data; + Shape shape; + std::unique_ptr backend; -struct Tensor -{ - float *data = nullptr; -}; + void Execute() { backend->Execute(); } + }; -struct Immediate -{ - float value; -}; + struct GraphNodeHandle + { + Graph *graph; + size_t node_idx; + + GraphNodeHandle sum(bool keepdim = false) const; + GraphNodeHandle sum(dim_t axis, bool keepdim = false) const; + GraphNodeHandle sum(Dims dims, bool keepdim = false) const; + GraphNodeHandle max(bool keepdim = false) const; + GraphNodeHandle max(dim_t axis, bool keepdim = false) const; + GraphNodeHandle max(Dims dims, bool keepdim = false) const; + + GraphNodeHandle reshape(Shape shape) const; + GraphNodeHandle reshape(dim_t length) const; + GraphNodeHandle permute(Dims dims) const; + GraphNodeHandle swapaxes(dim_t axis1, dim_t axis2) const; + GraphNodeHandle transpose() const; + GraphNodeHandle as_strided(Shape shape, Shape strides, dim_t offset) const; + + GraphNodeHandle relu() const; + GraphNodeHandle softmax(dim_t axis = -1) const; + GraphNodeHandle mean(dim_t axis = 0, bool keepdim = false) const; + GraphNodeHandle variance(dim_t axis = 0, bool keepdim = false) const; + GraphNodeHandle batchnorm() const; + + GraphNodeHandle matmul(GraphNodeHandle y) const; + + const Shape &shape() const; // Empty shape means scalar + const Shape &strides() const; + const NumericDataType &dtype() const; + + GraphNode &GetNode(); + const GraphNode &GetNode() const; + GraphNode &operator*() { return this->GetNode(); } + const GraphNode &operator*() const { return this->GetNode(); } + GraphNode *operator->() { return &this->GetNode(); } + const GraphNode *operator->() const { return &this->GetNode(); } + double *&data(); + + CompiledTensor Compile(std::unique_ptr backend) const; + template + CompiledTensor Compile() const { return Compile(std::make_unique()); } + }; -struct UnaryOp -{ - UnaryOpType type; - GraphNodeHandle x; -}; + enum class UnaryOpType + { + NOP, + EXP, + LOG, + CAST, + SIN, + SQRT, + }; -struct BinaryOp -{ - BinaryOpType type; - GraphNodeHandle x; - GraphNodeHandle y; -}; + enum class BinaryOpType + { + ADD, + SUB, + MUL, + DIV, + POW, + CMP, + MAX, + }; -struct ReduceOp -{ - ReduceOpType type; - GraphNodeHandle x; - Dims dims; - bool keepdim; -}; + enum class ReduceOpType + { + SUM, + MAX, + }; -struct ViewOp -{ - GraphNodeHandle x; - Shape shape; - Shape strides; - dim_t offset; -}; + struct Tensor + { + double *data = nullptr; + NumericDataType dtype = NumericDataType::FLOAT32; + }; -struct GraphNode -{ - enum class Kind + struct Immediate { - Tensor, - Immediate, - UnaryOp, - BinaryOp, - ReduceOp, - ViewOp, + double value; }; - union U + struct UnaryOp { - struct { Kind kind; } k; - struct { Kind kind; Tensor tensor; } t; - struct { Kind kind; Immediate immediate; } i; - struct { Kind kind; UnaryOp unary_op; } u; - struct { Kind kind; BinaryOp binary_op; } b; - struct { Kind kind; ReduceOp reduce_op; } r; - struct { Kind kind; ViewOp view_op; } v; - - U(Tensor tensor) : t({ .kind = Kind::Tensor, .tensor = std::move(tensor) }) {} - U(Immediate immediate) : i({ .kind = Kind::Immediate, .immediate = std::move(immediate) }) {} - U(UnaryOp unary_op) : u({ .kind = Kind::UnaryOp, .unary_op = std::move(unary_op) }) {} - U(BinaryOp binary_op) : b({ .kind = Kind::BinaryOp, .binary_op = std::move(binary_op) }) {} - U(ReduceOp reduce_op) : r({ .kind = Kind::ReduceOp, .reduce_op = std::move(reduce_op) }) {} - U(ViewOp view_op) : v({ .kind = Kind::ViewOp, .view_op = std::move(view_op) }) {} - - U(const U &that); - U(U &&that); - U &operator=(const U &that); - U &operator=(U &&that); - ~U(); + UnaryOpType type; + GraphNodeHandle x; }; - template - decltype(auto) Visit(T fn) + struct BinaryOp { - switch(this->u.k.kind) + BinaryOpType type; + GraphNodeHandle x; + GraphNodeHandle y; + }; + + struct ReduceOp + { + ReduceOpType type; + GraphNodeHandle x; + Dims dims; + bool keepdim; + }; + + struct ViewOp + { + GraphNodeHandle x; + Shape shape; + Shape strides; + dim_t offset; + }; + + struct GraphNode + { + enum class Kind + { + Tensor, + Immediate, + UnaryOp, + BinaryOp, + ReduceOp, + ViewOp, + }; + + union U { - case Kind::Tensor: - return fn(this->u.t.tensor); - case Kind::Immediate: - return fn(this->u.i.immediate); - case Kind::UnaryOp: - return fn(this->u.u.unary_op); - case Kind::BinaryOp: - return fn(this->u.b.binary_op); - case Kind::ReduceOp: - return fn(this->u.r.reduce_op); - case Kind::ViewOp: - return fn(this->u.v.view_op); - default: - throw std::logic_error("Invalid node type! This is a bug"); + struct + { + Kind kind; + } k; + struct + { + Kind kind; + Tensor tensor; + } t; + struct + { + Kind kind; + Immediate immediate; + } i; + struct + { + Kind kind; + UnaryOp unary_op; + } u; + struct + { + Kind kind; + BinaryOp binary_op; + } b; + struct + { + Kind kind; + ReduceOp reduce_op; + } r; + struct + { + Kind kind; + ViewOp view_op; + } v; + + U(Tensor tensor) : t({.kind = Kind::Tensor, .tensor = std::move(tensor)}) {} + U(Immediate immediate) : i({.kind = Kind::Immediate, .immediate = std::move(immediate)}) {} + U(UnaryOp unary_op) : u({.kind = Kind::UnaryOp, .unary_op = std::move(unary_op)}) {} + U(BinaryOp binary_op) : b({.kind = Kind::BinaryOp, .binary_op = std::move(binary_op)}) {} + U(ReduceOp reduce_op) : r({.kind = Kind::ReduceOp, .reduce_op = std::move(reduce_op)}) {} + U(ViewOp view_op) : v({.kind = Kind::ViewOp, .view_op = std::move(view_op)}) {} + + U(const U &that); + U(U &&that); + U &operator=(const U &that); + U &operator=(U &&that); + ~U(); + }; + + template + decltype(auto) Visit(T fn) + { + switch (this->u.k.kind) + { + case Kind::Tensor: + return fn(this->u.t.tensor); + case Kind::Immediate: + return fn(this->u.i.immediate); + case Kind::UnaryOp: + return fn(this->u.u.unary_op); + case Kind::BinaryOp: + return fn(this->u.b.binary_op); + case Kind::ReduceOp: + return fn(this->u.r.reduce_op); + case Kind::ViewOp: + return fn(this->u.v.view_op); + default: + throw std::logic_error("Invalid node type! This is a bug"); + } } - } - Kind Kind() { return this->u.k.kind; } - - U u; - Shape shape; - Shape strides; - bool needs_gradient = true; -}; - -GraphNodeHandle sqrt(GraphNodeHandle x); -GraphNodeHandle exp(GraphNodeHandle x); -GraphNodeHandle log(GraphNodeHandle x); -GraphNodeHandle sin(GraphNodeHandle x); -GraphNodeHandle cos(GraphNodeHandle x); -GraphNodeHandle sigmoid(GraphNodeHandle x); -GraphNodeHandle operator-(GraphNodeHandle x); - -GraphNodeHandle operator+(GraphNodeHandle x, GraphNodeHandle y); -GraphNodeHandle operator+(float x, GraphNodeHandle y); -GraphNodeHandle operator+(GraphNodeHandle x, float y); -GraphNodeHandle operator-(GraphNodeHandle x, GraphNodeHandle y); -GraphNodeHandle operator-(float x, GraphNodeHandle y); -GraphNodeHandle operator-(GraphNodeHandle x, float y); -GraphNodeHandle operator*(GraphNodeHandle x, GraphNodeHandle y); -GraphNodeHandle operator*(float x, GraphNodeHandle y); -GraphNodeHandle operator*(GraphNodeHandle x, float y); -GraphNodeHandle operator/(GraphNodeHandle x, GraphNodeHandle y); -GraphNodeHandle operator/(float x, GraphNodeHandle y); -GraphNodeHandle operator/(GraphNodeHandle x, float y); -GraphNodeHandle operator^(GraphNodeHandle x, float y); -GraphNodeHandle operator==(GraphNodeHandle x, GraphNodeHandle y); -GraphNodeHandle operator==(GraphNodeHandle x, float y); -GraphNodeHandle operator==(float x, GraphNodeHandle y); -GraphNodeHandle operator<(GraphNodeHandle x, float y); -GraphNodeHandle operator<(float x, GraphNodeHandle y); -GraphNodeHandle operator<(GraphNodeHandle x, GraphNodeHandle y); -GraphNodeHandle operator<=(GraphNodeHandle x, float y); -GraphNodeHandle operator<=(float x, GraphNodeHandle y); -GraphNodeHandle operator<=(GraphNodeHandle x, GraphNodeHandle y); -GraphNodeHandle operator>(GraphNodeHandle x, float y); -GraphNodeHandle operator>(float x, GraphNodeHandle y); -GraphNodeHandle operator>(GraphNodeHandle x, GraphNodeHandle y); -GraphNodeHandle operator>=(GraphNodeHandle x, float y); -GraphNodeHandle operator>=(float x, GraphNodeHandle y); -GraphNodeHandle operator>=(GraphNodeHandle x, GraphNodeHandle y); -GraphNodeHandle max(GraphNodeHandle x, GraphNodeHandle y); -GraphNodeHandle max(float x, GraphNodeHandle y); -GraphNodeHandle max(GraphNodeHandle x, float y); -GraphNodeHandle min(GraphNodeHandle x, GraphNodeHandle y); -GraphNodeHandle min(float x, GraphNodeHandle y); -GraphNodeHandle min(GraphNodeHandle x, float y); -GraphNodeHandle pow(GraphNodeHandle x, float y); -GraphNodeHandle pow(float x, GraphNodeHandle y); -GraphNodeHandle pow(GraphNodeHandle x, GraphNodeHandle y); - -GraphNodeHandle relu(GraphNodeHandle x); -GraphNodeHandle softmax(GraphNodeHandle x, dim_t axis = -1); - -GraphNodeHandle sum(GraphNodeHandle x, bool keepdim = false); -GraphNodeHandle sum(GraphNodeHandle x, dim_t axis, bool keepdim = false); -GraphNodeHandle sum(GraphNodeHandle x, Dims dims, bool keepdim = false); -GraphNodeHandle max(GraphNodeHandle x, bool keepdim = false); -GraphNodeHandle max(GraphNodeHandle x, dim_t axis, bool keepdim = false); -GraphNodeHandle max(GraphNodeHandle x, Dims dims, bool keepdim = false); -GraphNodeHandle min(GraphNodeHandle x, bool keepdim = false); -GraphNodeHandle min(GraphNodeHandle x, dim_t axis, bool keepdim = false); -GraphNodeHandle min(GraphNodeHandle x, Dims dims, bool keepdim = false); - -GraphNodeHandle mean(GraphNodeHandle x, dim_t axis, bool keepdim = false); -GraphNodeHandle variance(GraphNodeHandle x, dim_t axis, bool keepdim = false); -GraphNodeHandle batchnorm(GraphNodeHandle x); - -GraphNodeHandle reshape(GraphNodeHandle x, Shape shape); -GraphNodeHandle reshape(GraphNodeHandle x, dim_t length); -GraphNodeHandle permute(GraphNodeHandle x, Dims dims); - -GraphNodeHandle operator%(GraphNodeHandle x, GraphNodeHandle y); -GraphNodeHandle matmul(GraphNodeHandle x, GraphNodeHandle y); - -struct Graph -{ - GraphNodeHandle Immediate(float imm); - GraphNodeHandle AddInput(Shape shape); - GraphNodeHandle AddInput(dim_t dim); + Kind Kind() { return this->u.k.kind; } - GraphNodeHandle AddNode(struct Tensor, Shape shape); - GraphNodeHandle AddNode(struct Immediate); - GraphNodeHandle AddNode(struct UnaryOp); - GraphNodeHandle AddNode(struct BinaryOp); - GraphNodeHandle AddNode(struct ReduceOp); - GraphNodeHandle AddNode(struct ViewOp); + U u; + Shape shape; + Shape strides; + NumericDataType dtype; + bool needs_gradient = true; + }; - GraphNodeHandle AddNode(GraphNode node); + GraphNodeHandle sqrt(GraphNodeHandle x); + GraphNodeHandle exp(GraphNodeHandle x); + GraphNodeHandle log(GraphNodeHandle x); + GraphNodeHandle sin(GraphNodeHandle x); + GraphNodeHandle cos(GraphNodeHandle x); + GraphNodeHandle sigmoid(GraphNodeHandle x); + GraphNodeHandle operator-(GraphNodeHandle x); + + GraphNodeHandle operator+(GraphNodeHandle x, GraphNodeHandle y); + GraphNodeHandle operator+(double x, GraphNodeHandle y); + GraphNodeHandle operator+(GraphNodeHandle x, double y); + GraphNodeHandle operator-(GraphNodeHandle x, GraphNodeHandle y); + GraphNodeHandle operator-(double x, GraphNodeHandle y); + GraphNodeHandle operator-(GraphNodeHandle x, double y); + GraphNodeHandle operator*(GraphNodeHandle x, GraphNodeHandle y); + GraphNodeHandle operator*(double x, GraphNodeHandle y); + GraphNodeHandle operator*(GraphNodeHandle x, double y); + GraphNodeHandle operator/(GraphNodeHandle x, GraphNodeHandle y); + GraphNodeHandle operator/(double x, GraphNodeHandle y); + GraphNodeHandle operator/(GraphNodeHandle x, double y); + GraphNodeHandle operator^(GraphNodeHandle x, double y); + GraphNodeHandle operator==(GraphNodeHandle x, GraphNodeHandle y); + GraphNodeHandle operator==(GraphNodeHandle x, double y); + GraphNodeHandle operator==(double x, GraphNodeHandle y); + GraphNodeHandle operator<(GraphNodeHandle x, double y); + GraphNodeHandle operator<(double x, GraphNodeHandle y); + GraphNodeHandle operator<(GraphNodeHandle x, GraphNodeHandle y); + GraphNodeHandle operator<=(GraphNodeHandle x, double y); + GraphNodeHandle operator<=(double x, GraphNodeHandle y); + GraphNodeHandle operator<=(GraphNodeHandle x, GraphNodeHandle y); + GraphNodeHandle operator>(GraphNodeHandle x, double y); + GraphNodeHandle operator>(double x, GraphNodeHandle y); + GraphNodeHandle operator>(GraphNodeHandle x, GraphNodeHandle y); + GraphNodeHandle operator>=(GraphNodeHandle x, double y); + GraphNodeHandle operator>=(double x, GraphNodeHandle y); + GraphNodeHandle operator>=(GraphNodeHandle x, GraphNodeHandle y); + GraphNodeHandle max(GraphNodeHandle x, GraphNodeHandle y); + GraphNodeHandle max(double x, GraphNodeHandle y); + GraphNodeHandle max(GraphNodeHandle x, double y); + GraphNodeHandle min(GraphNodeHandle x, GraphNodeHandle y); + GraphNodeHandle min(double x, GraphNodeHandle y); + GraphNodeHandle min(GraphNodeHandle x, double y); + GraphNodeHandle pow(GraphNodeHandle x, double y); + GraphNodeHandle pow(double x, GraphNodeHandle y); + GraphNodeHandle pow(GraphNodeHandle x, GraphNodeHandle y); + + GraphNodeHandle relu(GraphNodeHandle x); + GraphNodeHandle softmax(GraphNodeHandle x, dim_t axis = -1); + + GraphNodeHandle sum(GraphNodeHandle x, bool keepdim = false); + GraphNodeHandle sum(GraphNodeHandle x, dim_t axis, bool keepdim = false); + GraphNodeHandle sum(GraphNodeHandle x, Dims dims, bool keepdim = false); + GraphNodeHandle max(GraphNodeHandle x, bool keepdim = false); + GraphNodeHandle max(GraphNodeHandle x, dim_t axis, bool keepdim = false); + GraphNodeHandle max(GraphNodeHandle x, Dims dims, bool keepdim = false); + GraphNodeHandle min(GraphNodeHandle x, bool keepdim = false); + GraphNodeHandle min(GraphNodeHandle x, dim_t axis, bool keepdim = false); + GraphNodeHandle min(GraphNodeHandle x, Dims dims, bool keepdim = false); + + GraphNodeHandle mean(GraphNodeHandle x, dim_t axis, bool keepdim = false); + GraphNodeHandle variance(GraphNodeHandle x, dim_t axis, bool keepdim = false); + GraphNodeHandle batchnorm(GraphNodeHandle x); + + GraphNodeHandle reshape(GraphNodeHandle x, Shape shape); + GraphNodeHandle reshape(GraphNodeHandle x, dim_t length); + GraphNodeHandle permute(GraphNodeHandle x, Dims dims); + + GraphNodeHandle operator%(GraphNodeHandle x, GraphNodeHandle y); + GraphNodeHandle matmul(GraphNodeHandle x, GraphNodeHandle y); + + struct Graph + { + GraphNodeHandle Immediate(double imm); + GraphNodeHandle AddInput(Shape shape, NumericDataType dtype = NumericDataType::FLOAT32); + GraphNodeHandle AddInput(dim_t dim, NumericDataType dtype = NumericDataType::FLOAT32); - std::vector inputs; - std::deque nodes; -}; + GraphNodeHandle AddNode(Tensor tensor, Shape shape, NumericDataType dtype); + GraphNodeHandle AddNode(struct Immediate); + GraphNodeHandle AddNode(struct UnaryOp); + GraphNodeHandle AddNode(struct BinaryOp); + GraphNodeHandle AddNode(struct ReduceOp); + GraphNodeHandle AddNode(struct ViewOp); -namespace nn -{ + GraphNodeHandle AddNode(GraphNode node); -struct Module -{ - GraphNodeHandle Immediate(float imm); + std::vector inputs; + std::deque nodes; + }; + + namespace nn + { - GraphNodeHandle AddInput(Shape shape); - GraphNodeHandle AddInput(dim_t dim); + struct Module + { + GraphNodeHandle Immediate(double imm); - GraphNodeHandle AddWeight(Shape shape); - GraphNodeHandle AddWeight(dim_t dim); - - Graph graph; - // TODO: Think about if this is a good idea.. - // kind of cumbersome doing a double lookup - std::vector weights; // Indices of forward.inputs that are weights -}; + GraphNodeHandle AddInput(Shape shape, NumericDataType dtype = NumericDataType::FLOAT32); + GraphNodeHandle AddInput(dim_t dim, NumericDataType dtype = NumericDataType::FLOAT32); -} + GraphNodeHandle AddWeight(Shape shape, NumericDataType dtype = NumericDataType::FLOAT32); + GraphNodeHandle AddWeight(dim_t dim, NumericDataType dtype = NumericDataType::FLOAT32); + + Graph graph; + // TODO: Think about if this is a good idea.. + // kind of cumbersome doing a double lookup + std::vector weights; // Indices of forward.inputs that are weights + }; + + } } diff --git a/src/training.h b/src/training.h index 67cee0b..77de03a 100644 --- a/src/training.h +++ b/src/training.h @@ -6,26 +6,26 @@ namespace gigagrad { -struct TrainingContext -{ - float *loss; - float *&training_example; - std::unique_ptr backend; + struct TrainingContext + { + float *loss; + double *&training_example; + std::unique_ptr backend; - void Execute() { backend->Execute(); } -}; + void Execute() { backend->Execute(); } + }; -// TODO: Allow dynamic learning rate -TrainingContext CompileTrainingGraph( - nn::Module &network, - GraphNodeHandle model_output, - std::unique_ptr backend, - float learning_rate = 0.1f); + // TODO: Allow dynamic learning rate + TrainingContext CompileTrainingGraph( + nn::Module &network, + GraphNodeHandle model_output, + std::unique_ptr backend, + float learning_rate = 0.1f); -template -TrainingContext CompileTrainingGraph(nn::Module &network, GraphNodeHandle model_output, float learning_rate = 0.1f) -{ - return CompileTrainingGraph(network, model_output, std::make_unique(), learning_rate); -} + template + TrainingContext CompileTrainingGraph(nn::Module &network, GraphNodeHandle model_output, float learning_rate = 0.1f) + { + return CompileTrainingGraph(network, model_output, std::make_unique(), learning_rate); + } } diff --git a/test/gigagrad-emnist.cpp b/test/gigagrad-emnist.cpp index f759c89..460d6ed 100644 --- a/test/gigagrad-emnist.cpp +++ b/test/gigagrad-emnist.cpp @@ -24,7 +24,7 @@ enum class DataType : uint8_t size_t SizeOf(DataType dtype) { - switch(dtype) + switch (dtype) { case DataType::U8: return sizeof(uint8_t); @@ -53,23 +53,23 @@ struct ParsedDataFile ParsedDataFile LoadDataFile(const char *filename) { FILE *file = fopen(filename, "rb"); - if(!file) + if (!file) { fprintf(stderr, "File %s not found\n", filename); exit(1); } uint32_t magic = 0; - if(fseek(file, 0, SEEK_SET)) + if (fseek(file, 0, SEEK_SET)) { fprintf(stderr, "Failed to seek: %d\n", ferror(file)); exit(1); } - if(fread(&magic, sizeof(uint32_t), 1, file) < 1) + if (fread(&magic, sizeof(uint32_t), 1, file) < 1) { fprintf(stderr, "Failed to read from file: %d\n", ferror(file)); exit(1); } - if(static_cast(magic & 0xFFFF) != 0) + if (static_cast(magic & 0xFFFF) != 0) { fprintf(stderr, "Invalid magic bytes at beginning of file %s (0x%x), exiting\n", filename, magic); exit(1); @@ -77,7 +77,7 @@ ParsedDataFile LoadDataFile(const char *filename) uint8_t raw_dtype = (magic >> 16) & 0xFF; DataType dtype; - switch(raw_dtype) + switch (raw_dtype) { case (uint8_t)DataType::U8: dtype = DataType::U8; @@ -103,21 +103,20 @@ ParsedDataFile LoadDataFile(const char *filename) __builtin_unreachable(); } - uint8_t ndim = (magic >> 24) & 0xFF; std::vector shape(ndim); size_t offset = 4; size_t total_dataset_size = 1; - for(size_t i = 0; i < ndim; i++) + for (size_t i = 0; i < ndim; i++) { uint32_t dim = 0; - if(fseek(file, offset, SEEK_SET)) + if (fseek(file, offset, SEEK_SET)) { fprintf(stderr, "Failed to seek to offset %zu\n", offset); exit(1); } - if(fread(&dim, sizeof(dim), 1, file) < 1) + if (fread(&dim, sizeof(dim), 1, file) < 1) { fprintf(stderr, "Failed to read\n"); exit(1); @@ -128,36 +127,36 @@ ParsedDataFile LoadDataFile(const char *filename) } std::vector data(total_dataset_size * SizeOf(dtype), 0); - if(fseek(file, offset, SEEK_SET)) + if (fseek(file, offset, SEEK_SET)) { fprintf(stderr, "Failed to seek to offset %zu\n", offset); exit(1); } - if(fread(data.data(), SizeOf(dtype), total_dataset_size, file) < total_dataset_size) + if (fread(data.data(), SizeOf(dtype), total_dataset_size, file) < total_dataset_size) { fprintf(stderr, "Failed to read dataset\n"); exit(1); } - return { dtype, std::move(shape), std::move(data) }; + return {dtype, std::move(shape), std::move(data)}; } template -std::vector CastToFloatAndNormalize(const std::vector &input) +std::vector CastToFloatAndNormalize(const std::vector &input) { size_t num_elements = input.size() / sizeof(T); - std::vector result(num_elements); + std::vector result(num_elements); const T *input_data = reinterpret_cast(input.data()); - for(size_t i = 0; i < num_elements; i++) + for (size_t i = 0; i < num_elements; i++) { - result[i] = static_cast(input_data[i]) / 255.0; + result[i] = static_cast(input_data[i]) / 255.0; } return result; } -std::vector CastToFloatAndNormalize(DataType dtype, const std::vector &input) +std::vector CastToFloatAndNormalize(DataType dtype, const std::vector &input) { - switch(dtype) + switch (dtype) { case DataType::U8: return CastToFloatAndNormalize(input); @@ -177,22 +176,22 @@ std::vector CastToFloatAndNormalize(DataType dtype, const std::vector -std::vector ToOneHot(const std::vector &input) +std::vector ToOneHot(const std::vector &input) { size_t num_elements = input.size() / sizeof(T); const T *input_data = reinterpret_cast(input.data()); T max_val = 0; - for(size_t i = 0; i < num_elements; i++) + for (size_t i = 0; i < num_elements; i++) { - if(input_data[i] < 0) + if (input_data[i] < 0) { fprintf(stderr, "Tried to one-hot encode invalid value: %d\n", (int)input_data[i]); exit(1); } max_val = std::max(max_val, input_data[i]); } - std::vector result(max_val * num_elements, 0.0f); - for(size_t i = 0; i < num_elements; i++) + std::vector result(max_val * num_elements, 0.0f); + for (size_t i = 0; i < num_elements; i++) { T cur_val = input_data[i]; result[i * max_val + cur_val] = 1.0f; @@ -200,9 +199,9 @@ std::vector ToOneHot(const std::vector &input) return result; } -std::vector ToOneHot(DataType dtype, const std::vector &input) +std::vector ToOneHot(DataType dtype, const std::vector &input) { - switch(dtype) + switch (dtype) { case DataType::U8: return ToOneHot(input); @@ -221,8 +220,8 @@ std::vector ToOneHot(DataType dtype, const std::vector &input) struct Dataset { std::vector shape; - std::vector inputs; - std::vector labels; + std::vector inputs; + std::vector labels; }; Dataset LoadDataset(const char *directory, const char *dataset) @@ -231,22 +230,22 @@ Dataset LoadDataset(const char *directory, const char *dataset) std::string label_name = std::string(directory) + "/emnist-mnist-" + dataset + "-labels-idx1-ubyte"; ParsedDataFile images = LoadDataFile(image_name.c_str()); ParsedDataFile labels = LoadDataFile(label_name.c_str()); - std::vector images_float = CastToFloatAndNormalize(images.dtype, images.data); - std::vector labels_onehot = ToOneHot(labels.dtype, labels.data); - return { std::move(images.shape), std::move(images_float), std::move(labels_onehot) }; + std::vector images_float = CastToFloatAndNormalize(images.dtype, images.data); + std::vector labels_onehot = ToOneHot(labels.dtype, labels.data); + return {std::move(images.shape), std::move(images_float), std::move(labels_onehot)}; } -void InitializeWeights(float *weight, size_t size_elts) +void InitializeWeights(double *weight, size_t size_elts) { std::default_random_engine gen(0); - std::uniform_real_distribution dist(-0.1f, 0.1f); - for(size_t i = 0; i < size_elts; i++) + std::uniform_real_distribution dist(-0.1f, 0.1f); + for (size_t i = 0; i < size_elts; i++) weight[i] = dist(gen); } int main(int argc, const char **argv) { - if(argc != 2) + if (argc != 2) { fprintf(stderr, "Please specify exactly one argument: the directory of the EMNIST dataset\n"); exit(1); @@ -257,32 +256,32 @@ int main(int argc, const char **argv) constexpr size_t HiddenLayerSize = 40; gg::nn::Module network; - auto x = network.AddInput({ BatchSize, 28 * 28, 1 }); - auto w1 = network.AddWeight({ HiddenLayerSize, 28 * 28 }); - auto b1 = network.AddWeight({ HiddenLayerSize, 1 }); + auto x = network.AddInput({BatchSize, 28 * 28, 1}); + auto w1 = network.AddWeight({HiddenLayerSize, 28 * 28}); + auto b1 = network.AddWeight({HiddenLayerSize, 1}); auto x_bm = x.batchnorm(); auto z1 = (w1 % x_bm) + b1; auto a2 = z1.relu(); - auto w2 = network.AddWeight({ 10, HiddenLayerSize }); - auto b2 = network.AddWeight({ 10, 1 }); + auto w2 = network.AddWeight({10, HiddenLayerSize}); + auto b2 = network.AddWeight({10, 1}); auto z2 = (w2 % a2) + b2; auto result = z2.softmax(-2); gg::TrainingContext ctx = gg::CompileTrainingGraph(network, result, 0.005f); - w1.data() = new float[HiddenLayerSize * 28 * 28]; - b1.data() = new float[HiddenLayerSize * 1]; - w2.data() = new float[10 * HiddenLayerSize]; - b2.data() = new float[10 * 1]; + w1.data() = new double[HiddenLayerSize * 28 * 28]; + b1.data() = new double[HiddenLayerSize * 1]; + w2.data() = new double[10 * HiddenLayerSize]; + b2.data() = new double[10 * 1]; InitializeWeights(w1.data(), HiddenLayerSize * 28 * 28); InitializeWeights(b1.data(), HiddenLayerSize * 1); InitializeWeights(w2.data(), 10 * HiddenLayerSize); InitializeWeights(b2.data(), 10 * 1); size_t num_batches = train.shape[0] / BatchSize; - for(size_t iepoch = 0; iepoch < 100; iepoch++) + for (size_t iepoch = 0; iepoch < 100; iepoch++) { - for(size_t ibatch = 0; ibatch < num_batches; ibatch++) + for (size_t ibatch = 0; ibatch < num_batches; ibatch++) { x.data() = &train.inputs[BatchSize * 28 * 28 * ibatch]; ctx.training_example = &train.labels[BatchSize * 10 * ibatch]; diff --git a/test/graph-test.cpp b/test/graph-test.cpp index 68d6692..9219bf2 100644 --- a/test/graph-test.cpp +++ b/test/graph-test.cpp @@ -6,6 +6,7 @@ #include #include +#include namespace gg = gigagrad; @@ -13,27 +14,30 @@ void TestGradient( gg::nn::Module &network, gg::GraphNodeHandle w, gg::GraphNodeHandle result, - float expected) + double expected) { gg::TrainingContext ctx = gg::CompileTrainingGraph(network, result, 1.0); - float example = 0.0f; + double example = 0.0; ctx.training_example = &example; ctx.Execute(); - float pct_diff = std::abs(*w.data() - expected); + std::cout << "w.data(): " << *w.data() << std::endl; + std::cout << "expected: " << expected << std::endl; + double pct_diff = std::abs(*w.data() - expected); REQUIRE(pct_diff < 0.001); } TEST_CASE("TestGradients_EXP", "[Train]") { gg::nn::Module network; - auto w = network.AddWeight(1); + auto dtype = gg::NumericDataType::FLOAT32; + auto w = network.AddWeight(1, dtype); auto result = exp(w); - float w_data = 0.0f; + double w_data = 0.0; w.data() = &w_data; // ∂/∂w (E - exp(w))^2 = 2(E - exp(w)) * ∂/∂w(E - exp(w)) = 2(E - exp(w)) * -exp(w) // If E = 0, above equals 2exp(2w). If w = 0, above equals 2. So after gradient update, // w should be 0 - 2 = -2. - TestGradient(network, w, result, -2.0f); + TestGradient(network, w, result, -2.0); } TEST_CASE("TestGradients_LOG", "[Train]") @@ -41,12 +45,12 @@ TEST_CASE("TestGradients_LOG", "[Train]") gg::nn::Module network; auto w = network.AddWeight(1); auto result = log(w); - float w_data = 1.0f; + double w_data = 1.0f; w.data() = &w_data; // ∂/∂w (E - log(w))^2 = 2(E - log(w)) * ∂/∂w(E - log(w)) = 2(E - log(w)) * -1/w // If E = 0, above equals log(w)/w. If w = 1, above equals 0. So after gradient update, // w should be 1 - 0 = 1. - TestGradient(network, w, result, 1.0f); + TestGradient(network, w, result, 1.0); } TEST_CASE("TestGradients_SIN", "[Train]") @@ -54,12 +58,12 @@ TEST_CASE("TestGradients_SIN", "[Train]") gg::nn::Module network; auto w = network.AddWeight(1); auto result = sin(w); - float w_data = 0.0f; + double w_data = 0.0; w.data() = &w_data; // ∂/∂w (E - sin(w))^2 = 2(E - sin(w)) * ∂/∂w(E - sin(w)) = 2(E - sin(w)) * -cos(w) // If E = 0, above equals 2sin(w)cos(w). If w = 0, above equals 0. So after gradient update, // w should be 0 - 0 = 0. - TestGradient(network, w, result, 0.0f); + TestGradient(network, w, result, 0.0); } TEST_CASE("TestGradients_SQRT", "[Train]") @@ -67,12 +71,12 @@ TEST_CASE("TestGradients_SQRT", "[Train]") gg::nn::Module network; auto w = network.AddWeight(1); auto result = sqrt(w); - float w_data = 1.0f; + double w_data = 1.0; w.data() = &w_data; // ∂/∂w (E - sqrt(w))^2 = 2(E - sqrt(w)) * ∂/∂w(E - sqrt(w)) = 2(E - sqrt(w)) * (0 - 1/2 * 1/sqrt(w)) // If E = 0, above equals -2sqrt(w)/-2sqrt(w). If w = 1, above equals 1. So after gradient update, // w should be 1 - 1 = 0. - TestGradient(network, w, result, 0.0f); + TestGradient(network, w, result, 0.0); } TEST_CASE("TestGradients_ADD", "[Train]") @@ -81,30 +85,30 @@ TEST_CASE("TestGradients_ADD", "[Train]") auto x = network.AddInput(1); auto w = network.AddWeight(1); auto result = x + w; - float x_data = 1.0f; - float w_data = 1.0f; + double x_data = 1.0; + double w_data = 1.0; x.data() = &x_data; w.data() = &w_data; // ∂/∂w (E - (x + w))^2 = 2(E - x - w) * ∂/∂w(E - x - w) = 2(E - x - w) * (0 - 0 - 1) = -2(E - x - w) // If E = 0, above equals 2(x + w). If x,w = 1, above equals 4. So after gradient update, // w should be 1 - 4 = -3.0f. - TestGradient(network, w, result, -3.0f); + TestGradient(network, w, result, -3.0); } TEST_CASE("TestGradients_SUB", "[Train]") { gg::nn::Module network; - auto x = network.AddInput(1); - auto w = network.AddWeight(1); + auto x = network.AddInput(1, gg::NumericDataType::FLOAT64); + auto w = network.AddWeight(1, gg::NumericDataType::FLOAT64); auto result = x - w; - float x_data = 0.0f; - float w_data = 1.0f; + double x_data = 0.0; + double w_data = 1.0; x.data() = &x_data; w.data() = &w_data; // ∂/∂w (E - (x - w))^2 = 2(E - x + w) * ∂/∂w(E - x + w) = 2(E - x + w) * (0 - 0 + 1) = 2(E - x + w) // If E = 0, above equals 2(-x + w). If x = 0, w = 1, above equals 2. So after gradient update, // w should be 1 - 2 = -1.0f. - TestGradient(network, w, result, -1.0f); + TestGradient(network, w, result, -1.0); } TEST_CASE("TestTrainSimple", "[Train]") @@ -114,21 +118,21 @@ TEST_CASE("TestTrainSimple", "[Train]") auto w = network.AddWeight(4); auto L1 = w - x; gg::TrainingContext ctx = gg::CompileTrainingGraph(network, L1); - float x_data[] = { 1.0, 2.0, 3.0, 4.0 }; - float w_data[] = { -0.1, 0.1, -0.001, 0.0001 }; - float training_example_data[] = { 0.0, 0.0, 0.0, 0.0 }; + double x_data[] = {1.0, 2.0, 3.0, 4.0}; + double w_data[] = {-0.1, 0.1, -0.001, 0.0001}; + double training_example_data[] = {0.0, 0.0, 0.0, 0.0}; x.data() = x_data; w.data() = w_data; ctx.training_example = training_example_data; - float prev_loss = 1000; - for(int i = 0; i < 50; i++) + double prev_loss = 1000; + for (int i = 0; i < 50; i++) { ctx.Execute(); REQUIRE(*ctx.loss < prev_loss); } - for(int i = 0; i < 4; i++) + for (int i = 0; i < 4; i++) { - float pct_diff = (std::abs(w_data[i] - x_data[i]) / x_data[i]) * 100.0f; + double pct_diff = (std::abs(w_data[i] - x_data[i]) / x_data[i]) * 100.0; REQUIRE(pct_diff < 1); } } @@ -137,34 +141,34 @@ TEST_CASE("TestXor", "[Codegen]") { gg::Graph graph; auto x = graph.AddInput(2); - auto w1 = graph.AddInput({ 2, 2 }); - auto w2 = graph.AddInput({ 1, 2 }); - auto b1 = graph.AddInput({ 2, 1 }); + auto w1 = graph.AddInput({2, 2}); + auto w2 = graph.AddInput({1, 2}); + auto b1 = graph.AddInput({2, 1}); auto L1 = (w1 % x) > b1; - auto L2 = (w2 % L1) > 1.5f; + auto L2 = (w2 % L1) > 1.5; auto result = L2.Compile(); REQUIRE(L1.shape() == gg::Shape{2, 1}); REQUIRE(L2.shape() == gg::Shape{1, 1}); - float x_data[] = { 1.0, 1.0 }; - float w1_data[] = { 1.0, 1.0, -1.0, -1.0 }; - float b1_data[] = { 0.5, -1.5 }; - float w2_data[] = { 1.0, 1.0 }; + double x_data[] = {1.0, 1.0}; + double w1_data[] = {1.0, 1.0, -1.0, -1.0}; + double b1_data[] = {0.5, -1.5}; + double w2_data[] = {1.0, 1.0}; x.data() = x_data; w1.data() = w1_data; b1.data() = b1_data; w2.data() = w2_data; - for(bool x1 : { false, true }) + for (bool x1 : {false, true}) { - x_data[0] = x1 ? 1.0f : 0.0f; - for(bool x2 : { false, true }) + x_data[0] = x1 ? 1.0 : 0.0; + for (bool x2 : {false, true}) { - x_data[1] = x2 ? 1.0f : 0.0f; + x_data[1] = x2 ? 1.0f : 0.0; result.Execute(); - float expected = (x1 ^ x2) ? 1.0f : 0.0f; + double expected = (x1 ^ x2) ? 1.0 : 0.0; REQUIRE(result.data[0] == expected); } } @@ -172,24 +176,24 @@ TEST_CASE("TestXor", "[Codegen]") static std::default_random_engine Gen(0); -void RandomMatrix(float *m, size_t size_elts) +void RandomMatrix(double *m, size_t size_elts) { - std::uniform_real_distribution dist(-2.0f, 2.0f); - for(size_t i = 0; i < size_elts; i++) + std::uniform_real_distribution dist(-2.0, 2.0); + for (size_t i = 0; i < size_elts; i++) m[i] = dist(Gen); } -void NaiveMatmul(float *x, float *y, size_t A, size_t B, size_t C, float *result) +void NaiveMatmul(double *x, double *y, size_t A, size_t B, size_t C, double *result) { - for(size_t irow = 0; irow < A; irow++) + for (size_t irow = 0; irow < A; irow++) { - float *row = x + B * irow; - for(size_t icol = 0; icol < C; icol++) + double *row = x + B * irow; + for (size_t icol = 0; icol < C; icol++) { - float res = 0.0f; + double res = 0.0; - float *col = y + icol; - for(size_t i = 0; i < B; i++) + double *col = y + icol; + for (size_t i = 0; i < B; i++) { res += row[i] * col[C * i]; } @@ -201,7 +205,7 @@ void NaiveMatmul(float *x, float *y, size_t A, size_t B, size_t C, float *result TEST_CASE("TestMatmul", "[Codegen]") { constexpr size_t NumTrials = 10; - for(size_t itrial = 0; itrial < NumTrials; itrial++) + for (size_t itrial = 0; itrial < NumTrials; itrial++) { std::uniform_int_distribution dim_dist(1, 128); gg::dim_t A = dim_dist(Gen); @@ -210,40 +214,40 @@ TEST_CASE("TestMatmul", "[Codegen]") std::printf("Trial %zu: (%zu x %zu) * (%zu x %zu)\n", itrial, A, B, B, C); gg::Graph graph; - auto x = graph.AddInput({ A, B }); - auto y = graph.AddInput({ B, C }); + auto x = graph.AddInput({A, B}); + auto y = graph.AddInput({B, C}); auto result = (x % y).Compile(); - - x.data() = new float[A * B]; - y.data() = new float[B * C]; + + x.data() = new double[A * B]; + y.data() = new double[B * C]; RandomMatrix(x.data(), A * B); RandomMatrix(y.data(), B * C); result.Execute(); auto actual = result.data; - auto expected = new float[A * C]; + auto expected = new double[A * C]; NaiveMatmul(x.data(), y.data(), A, B, C, expected); - for(gg::dim_t i = 0; i < A * C; i++) + for (gg::dim_t i = 0; i < A * C; i++) { REQUIRE(std::abs(actual[i] - expected[i]) / actual[i] <= 0.02f); } // Make LeakSanitizer happy - delete [] x.data(); - delete [] y.data(); - delete [] expected; + delete[] x.data(); + delete[] y.data(); + delete[] expected; } } TEST_CASE("TestLogisticRegressionShape", "[Graph]") { gg::Graph graph; - auto x = graph.AddInput({ 28, 28 }).reshape({ 28 * 28, 1 }); - auto w1 = graph.AddInput({ 800, 28 * 28 }); - auto b1 = graph.AddInput({ 800, 1 }); + auto x = graph.AddInput({28, 28}).reshape({28 * 28, 1}); + auto w1 = graph.AddInput({800, 28 * 28}); + auto b1 = graph.AddInput({800, 1}); auto z1 = (w1 % x) + b1; auto a2 = gg::sigmoid(z1); - auto w2 = graph.AddInput({ 10, 800 }); - auto b2 = graph.AddInput({ 10, 1 }); + auto w2 = graph.AddInput({10, 800}); + auto b2 = graph.AddInput({10, 1}); auto result = (w2 % a2) + b2; REQUIRE(x.shape() == gg::Shape{28 * 28, 1}); REQUIRE(w1.shape() == gg::Shape{800, 28 * 28}); @@ -258,8 +262,8 @@ TEST_CASE("TestLogisticRegressionShape", "[Graph]") TEST_CASE("TestSimpleGraphShape", "[Graph]") { gigagrad::Graph graph; - auto tensor1 = graph.AddInput({ 2, 2 }); - auto tensor2 = graph.AddInput({ 2, 2 }); + auto tensor1 = graph.AddInput({2, 2}); + auto tensor2 = graph.AddInput({2, 2}); auto addition = tensor1 + tensor2; REQUIRE(addition->Kind() == gg::GraphNode::Kind::BinaryOp);