From f71f809bfd2b3ce3280d5ff440c854ad891d23b0 Mon Sep 17 00:00:00 2001 From: miaobyte <734991033@qq.com> Date: Thu, 20 Mar 2025 16:04:06 +0800 Subject: [PATCH] excuter(cpu/cuda):subscalar --- doc/excuter/op-mem-cuda/list.md | 2 + doc/excuter/op-mem-ompsimd/list.md | 1 + .../src/deepx/tensorfunc/elementwise.hpp | 12 +- excuter/op-mem-cuda/src/client/tfs.cpp | 13 +- .../tensorfunc/elementwise_miaobyte_basic.cu | 29 ++- .../tensorfunc/elementwise_miaobyte_basic.cuh | 33 ++- .../tensorfunc/elementwise_miaobyte_basic.hpp | 14 ++ .../src/deepx/tf/elementwise_basic.hpp | 82 ++++++- excuter/op-mem-ompsimd/src/client/tfs.cpp | 10 + .../src/deepx/tf/elementwise.hpp | 222 +++++++++++------- 10 files changed, 320 insertions(+), 98 deletions(-) diff --git a/doc/excuter/op-mem-cuda/list.md b/doc/excuter/op-mem-cuda/list.md index 171779a4..9913a248 100644 --- a/doc/excuter/op-mem-cuda/list.md +++ b/doc/excuter/op-mem-cuda/list.md @@ -5,8 +5,10 @@ | Operation | Author | Func Def | Math Formula | IR Instruction | |-----------|--------|------------|--------------|----------------| | addscalar | miaobyte | addscalar(tensor A, var b)->(tensor C) | T3=T1+scalar | addscalar(tensor A, var b)->(tensor C) | +| add | cublas | add(tensor a, tensor b)->(tensor c) | T3=T1+T2 | add(tensor a, tensor b)->(tensor c) | | add | miaobyte | add(tensor a, tensor b)->(tensor c) | T3=T1+T2 | add(tensor a, tensor b)->(tensor c) | | uniform | miaobyte | uniform(tensor t, var low, var high, var seed)->() | uniform(T1,low,high,seed) | uniform(tensor t, var low, var high, var seed)->() | +| subscalar | miaobyte | subscalar(tensor A, var b)->(tensor C) | T3=T1-scalar | subscalar(tensor A, var b)->(tensor C) | | arange | miaobyte | arange(tensor t, var start, var step)->() | arange(T1,start,step) | arange(tensor t, var start, var step)->() | | constant | miaobyte | constant(tensor t, var value)->() | constant(T1) | constant(tensor t, var value)->() | | print | miaobyte | print(tensor )->() | print(T1) | print(tensor )->() | diff --git a/doc/excuter/op-mem-ompsimd/list.md b/doc/excuter/op-mem-ompsimd/list.md index 581ab8f9..f10183f4 100644 --- a/doc/excuter/op-mem-ompsimd/list.md +++ b/doc/excuter/op-mem-ompsimd/list.md @@ -9,6 +9,7 @@ | add | cblas | add(tensor a, tensor b)->(tensor c) | T3=T1+T2 | add(tensor a, tensor b)->(tensor c) | | add | miaobyte | add(tensor a, tensor b)->(tensor c) | T3=T1+T2 | add(tensor a, tensor b)->(tensor c) | | uniform | miaobyte | uniform(tensor t, var low, var high, var seed)->() | uniform(T1,low,high,seed) | uniform(tensor t, var low, var high, var seed)->() | +| subscalar | miaobyte | subscalar(tensor a, var scalar)->(tensor c) | T3=T1-scalar | subscalar(tensor a, var scalar)->(tensor c) | | arange | miaobyte | arange(tensor t, var start, var step)->() | arange(T1,start,step) | arange(tensor t, var start, var step)->() | | constant | miaobyte | constant(tensor t, var value)->() | constant(T1,value) | constant(tensor t, var value)->() | | print | miaobyte | print(tensor )->() | print(T1) | print(tensor )->() | diff --git a/excuter/cpp-common/src/deepx/tensorfunc/elementwise.hpp b/excuter/cpp-common/src/deepx/tensorfunc/elementwise.hpp index bf6fd053..e05506f7 100644 --- a/excuter/cpp-common/src/deepx/tensorfunc/elementwise.hpp +++ b/excuter/cpp-common/src/deepx/tensorfunc/elementwise.hpp @@ -24,7 +24,9 @@ namespace deepx::tensorfunc template struct addscalarDispatcher { - static void addscalar(const Tensor &input, const T value, Tensor &output) = delete; + static void addscalar(const Tensor &input, const T value, Tensor &output){ + throw NotImplementError("addscalar"); + } }; template @@ -36,7 +38,9 @@ namespace deepx::tensorfunc template struct subDispatcher { - static void sub(const Tensor &A, const Tensor &B, Tensor &C) = delete; + static void sub(const Tensor &A, const Tensor &B, Tensor &C){ + throw NotImplementError("sub"); + } }; template @@ -48,7 +52,9 @@ namespace deepx::tensorfunc template struct subscalarDispatcher { - static void subscalar(const Tensor &input, const T value, Tensor &output) = delete; + static void subscalar(const Tensor &input, const T value, Tensor &output){ + throw NotImplementError("subscalar"); + } }; template diff --git a/excuter/op-mem-cuda/src/client/tfs.cpp b/excuter/op-mem-cuda/src/client/tfs.cpp index 41108a7f..cfcbec3b 100644 --- a/excuter/op-mem-cuda/src/client/tfs.cpp +++ b/excuter/op-mem-cuda/src/client/tfs.cpp @@ -114,7 +114,7 @@ namespace deepx::tf { Param("c", DataCategory::Tensor, Precision::Any), }))); - tffactory.add_tf(std::make_shared>(vector( + tffactory.add_tf(std::make_shared>(vector( { Param("A", DataCategory::Tensor, Precision::Any), Param("b", DataCategory::Var, Precision::Any), @@ -133,7 +133,16 @@ namespace deepx::tf { Param("C", DataCategory::Tensor, Precision::Any), }))); - + tffactory.add_tf(std::make_shared>(vector( + { + Param("A", DataCategory::Tensor, Precision::Any), + Param("b", DataCategory::Var, Precision::Any), + }), + vector( + { + Param("C", DataCategory::Tensor, Precision::Any), + }))); + // opfactory.add_op(Sub_cblas()); // opfactory.add_op(Sub_cblas()); diff --git a/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.cu b/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.cu index f66950ac..f4836cd6 100644 --- a/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.cu +++ b/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.cu @@ -105,7 +105,34 @@ namespace deepx::tensorfunc template void launch_sub(const int numBlocks, const int blockSize, const int16_t* a, const int16_t* b, int16_t* c, const int size); template void launch_sub(const int numBlocks, const int blockSize, const int8_t* a, const int8_t* b, int8_t* c, const int size); - + template + __global__ void subscalar_kernel(const T* A, const T scalar, T* C,const int size){ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < size) { + C[idx] = A[idx] - scalar; + } + } + template __global__ void subscalar_kernel(const double* A, const double scalar, double* C,const int size); + template __global__ void subscalar_kernel(const float* A, const float scalar, float* C,const int size); + template __global__ void subscalar_kernel(const half* A, const half scalar, half* C,const int size); + template __global__ void subscalar_kernel(const nv_bfloat16* A, const nv_bfloat16 scalar, nv_bfloat16* C,const int size); + template __global__ void subscalar_kernel(const int64_t* A, const int64_t scalar, int64_t* C,const int size); + template __global__ void subscalar_kernel(const int32_t* A, const int32_t scalar, int32_t* C,const int size); + template __global__ void subscalar_kernel(const int16_t* A, const int16_t scalar, int16_t* C,const int size); + template __global__ void subscalar_kernel(const int8_t* A, const int8_t scalar, int8_t* C,const int size); + + template + void launch_subscalar(const int numBlocks, const int blockSize, const T* a, const T scalar, T* c, const int size) { + subscalar_kernel<<>>(a, scalar, c, size); + } + template void launch_subscalar(const int numBlocks, const int blockSize, const double* a, const double scalar, double* c, const int size); + template void launch_subscalar(const int numBlocks, const int blockSize, const float* a, const float scalar, float* c, const int size); + template void launch_subscalar(const int numBlocks, const int blockSize, const half* a, const half scalar, half* c, const int size); + template void launch_subscalar(const int numBlocks, const int blockSize, const nv_bfloat16* a, const nv_bfloat16 scalar, nv_bfloat16* c, const int size); + template void launch_subscalar(const int numBlocks, const int blockSize, const int64_t* a, const int64_t scalar, int64_t* c, const int size); + template void launch_subscalar(const int numBlocks, const int blockSize, const int32_t* a, const int32_t scalar, int32_t* c, const int size); + template void launch_subscalar(const int numBlocks, const int blockSize, const int16_t* a, const int16_t scalar, int16_t* c, const int size); + template void launch_subscalar(const int numBlocks, const int blockSize, const int8_t* a, const int8_t scalar, int8_t* c, const int size); } #endif // DEEPX_TENSORFUNC_ELEMENTWISE_MIAO_BYTE_BASIC_CUH diff --git a/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.cuh b/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.cuh index 77102fc9..966cfa1c 100644 --- a/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.cuh +++ b/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.cuh @@ -103,7 +103,38 @@ namespace deepx::tensorfunc template <> void launch_sub(int numBlocks, int blockSize, const int8_t* a, const int8_t* b, int8_t* c,const int size); - + + // subscalar + template + __global__ void subscalar_kernel(const T* A, const T scalar, T* C,const int size); + + template + void launch_subscalar(const int numBlocks, const int blockSize, const T* a, const T scalar, T* c,const int size); + + template <> + void launch_subscalar(const int numBlocks, const int blockSize, const double* a, const double scalar, double* c,const int size); + + template <> + void launch_subscalar(const int numBlocks, const int blockSize, const float* a, const float scalar, float* c,const int size); + + template <> + void launch_subscalar(const int numBlocks, const int blockSize, const nv_bfloat16* a, const nv_bfloat16 scalar, nv_bfloat16* c,const int size); + + template <> + void launch_subscalar<__half>(const int numBlocks, const int blockSize, const __half* a, const __half scalar, __half* c,const int size); + + template <> + void launch_subscalar(const int numBlocks, const int blockSize, const int64_t* a, const int64_t scalar, int64_t* c,const int size); + + template <> + void launch_subscalar(const int numBlocks, const int blockSize, const int32_t* a, const int32_t scalar, int32_t* c,const int size); + + template <> + void launch_subscalar(const int numBlocks, const int blockSize, const int16_t* a, const int16_t scalar, int16_t* c,const int size); + + template <> + void launch_subscalar(const int numBlocks, const int blockSize, const int8_t* a, const int8_t scalar, int8_t* c,const int size); + } #endif // DEEPX_TENSORFUNC_ELEMENTWISE_MIAO_BYTE_BASIC_CUH diff --git a/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.hpp b/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.hpp index 2da8ec9c..0500dd60 100644 --- a/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.hpp +++ b/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.hpp @@ -55,6 +55,20 @@ namespace deepx::tensorfunc launch_sub(numBlocks, blockSize, A.data, B.data, C.data, A.shape.size); } }; + + template + struct subscalarDispatcher + { + static void subscalar(const Tensor &A, const T scalar, Tensor &C) + { + if (A.shape.size != C.shape.size) { + throw TensorShapeError("subscalar"); + } + const int blockSize = A.shape.size > 256 ? 256 : A.shape.size; + int numBlocks = (A.shape.size + blockSize - 1) / blockSize; + launch_subscalar(numBlocks, blockSize, A.data, scalar, C.data, A.shape.size); + } + }; } #endif // DEEPX_TENSORFUNC_ELEMENTWISE_MIAO_BYTE_BASIC_HPP diff --git a/excuter/op-mem-cuda/src/deepx/tf/elementwise_basic.hpp b/excuter/op-mem-cuda/src/deepx/tf/elementwise_basic.hpp index 218432a8..c0910a99 100644 --- a/excuter/op-mem-cuda/src/deepx/tf/elementwise_basic.hpp +++ b/excuter/op-mem-cuda/src/deepx/tf/elementwise_basic.hpp @@ -83,10 +83,10 @@ namespace deepx::tf }; template - class Addscalar : public TF + class AddScalar : public TF { public: - Addscalar(const vector &args, const vector &returns) + AddScalar(const vector &args, const vector &returns) { this->name = "addscalar"; this->author = Author::name(); @@ -94,7 +94,7 @@ namespace deepx::tf this->returns = returns; } - Addscalar(string text) + AddScalar(string text) { this->parse(text); this->author = Author::name(); @@ -109,7 +109,7 @@ namespace deepx::tf } shared_ptr clone() const override { - return make_shared>(*this); + return make_shared>(*this); } int run(shared_ptr mem, string &error) override { @@ -226,6 +226,80 @@ namespace deepx::tf return 0; } }; + + template + class SubScalar : public TF + { + public: + SubScalar(const vector &args, const vector &returns) + { + this->name = "subscalar"; + this->author = Author::name(); + this->args = args; + this->returns = returns; + } + + SubScalar(string text) + { + this->parse(text); + this->author = Author::name(); + if (this->name != "subscalar") + { + throw std::runtime_error("Invalid name: " + this->name); + } + } + string math_formula() const override + { + return "T3=T1-scalar"; + } + shared_ptr clone() const override + { + return make_shared>(*this); + } + int run(shared_ptr mem, string &error) override + { + Precision a_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + Precision c_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype; + if (a_type != c_type) + { + error = "Type mismatch: " + precision_str(a_type) + " != " + precision_str(c_type); + return 1; + } + switch (a_type) + { + case Precision::Float64: + tensorfunc::subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Float32: + tensorfunc::subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Float16: + tensorfunc::subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::BFloat16: + tensorfunc::subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int64: + tensorfunc::subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int32: + tensorfunc::subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int16: + tensorfunc::subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int8: + tensorfunc::subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); + break; + default: + error = "Unsupported dtype: " + precision_str(a_type); + return 1; + } + return 0; + } + }; + + }; #endif // DEEPX_TF_ELEMENTWISE_BASIC_HPP diff --git a/excuter/op-mem-ompsimd/src/client/tfs.cpp b/excuter/op-mem-ompsimd/src/client/tfs.cpp index 6bae8e79..b2de5145 100644 --- a/excuter/op-mem-ompsimd/src/client/tfs.cpp +++ b/excuter/op-mem-ompsimd/src/client/tfs.cpp @@ -140,6 +140,16 @@ namespace deepx::tf { Param("c", DataCategory::Tensor, Precision::Any), }))); + + tffactory.add_tf(std::make_shared>(vector( + { + Param("a", DataCategory::Tensor, Precision::Any), + Param("scalar", DataCategory::Var, Precision::Any), + }), + vector( + { + Param("c", DataCategory::Tensor, Precision::Any), + }))); // opfactory.add_op(Addscalar_miaobyte()); // opfactory.add_op(Addscalar_miaobyte()); diff --git a/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp b/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp index 1702c644..5487a2a7 100644 --- a/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp +++ b/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp @@ -8,11 +8,13 @@ #include "deepx/tensorfunc/authors.hpp" #include "deepx/tensorfunc/elementwise_miaobyte.hpp" #include "deepx/tensorfunc/elementwise_cblas.hpp" -namespace deepx::tf { +namespace deepx::tf +{ template - class Add : public TF { - public: + class Add : public TF + { + public: Add(vector args, vector returns) { this->name = "add"; @@ -23,7 +25,7 @@ namespace deepx::tf { string math_formula() const override { return "T3=T1+T2"; - } + } shared_ptr clone() const override { return make_shared>(*this); @@ -38,38 +40,38 @@ namespace deepx::tf { error = "Type mismatch: " + precision_str(a_type) + " != " + precision_str(b_type) + " != " + precision_str(c_type); return 1; } - switch (a_type) + switch (a_type) { - case Precision::Float64: - tensorfunc::add(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); - break; - case Precision::Float32: - tensorfunc::add(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); - break; - case Precision::Int64: - tensorfunc::add(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); - break; - case Precision::Int32: - tensorfunc::add(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); - break; - case Precision::Int16: - tensorfunc::add(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); - break; - case Precision::Int8: - tensorfunc::add(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); - break; - default: - error = "Unsupported dtype: " + precision_str(a_type); - return 1; + case Precision::Float64: + tensorfunc::add(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Float32: + tensorfunc::add(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int64: + tensorfunc::add(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int32: + tensorfunc::add(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int16: + tensorfunc::add(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int8: + tensorfunc::add(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); + break; + default: + error = "Unsupported dtype: " + precision_str(a_type); + return 1; } return 0; } - }; template - class AddScalar : public TF { - public: + class AddScalar : public TF + { + public: AddScalar(vector args, vector returns) { this->name = "addscalar"; @@ -80,7 +82,7 @@ namespace deepx::tf { string math_formula() const override { return "T3=T1+scalar"; - } + } shared_ptr clone() const override { return make_shared>(*this); @@ -94,37 +96,37 @@ namespace deepx::tf { error = "Type mismatch: " + precision_str(a_type) + " != " + precision_str(c_type); return 1; } - switch (a_type) + switch (a_type) { - case Precision::Float64: - tensorfunc::addscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); - break; - case Precision::Float32: - tensorfunc::addscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); - break; - case Precision::Int64: - tensorfunc::addscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); - break; - case Precision::Int32: - tensorfunc::addscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); - break; - case Precision::Int16: - tensorfunc::addscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); - break; - case Precision::Int8: - tensorfunc::addscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); - break; - default: - error = "Unsupported dtype: " + precision_str(a_type); - return 1; + case Precision::Float64: + tensorfunc::addscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Float32: + tensorfunc::addscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int64: + tensorfunc::addscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int32: + tensorfunc::addscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int16: + tensorfunc::addscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int8: + tensorfunc::addscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); + break; + default: + error = "Unsupported dtype: " + precision_str(a_type); + return 1; } return 0; } - }; template - class Sub : public TF { - public: + class Sub : public TF + { + public: Sub(vector args, vector returns) { this->name = "sub"; @@ -135,7 +137,7 @@ namespace deepx::tf { string math_formula() const override { return "T3=T1-T2"; - } + } shared_ptr clone() const override { return make_shared>(*this); @@ -150,43 +152,89 @@ namespace deepx::tf { error = "Type mismatch: " + precision_str(a_type) + " != " + precision_str(b_type) + " != " + precision_str(c_type); return 1; } - switch (a_type) + switch (a_type) { - case Precision::Float64: - tensorfunc::sub(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); - break; - case Precision::Float32: - tensorfunc::sub(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); - break; - case Precision::Int64: - tensorfunc::sub(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); - break; - case Precision::Int32: - tensorfunc::sub(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); - break; - case Precision::Int16: - tensorfunc::sub(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); - break; - case Precision::Int8: - tensorfunc::sub(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); - break; - default: - error = "Unsupported dtype: " + precision_str(a_type); - return 1; + case Precision::Float64: + tensorfunc::sub(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Float32: + tensorfunc::sub(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int64: + tensorfunc::sub(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int32: + tensorfunc::sub(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int16: + tensorfunc::sub(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int8: + tensorfunc::sub(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); + break; + default: + error = "Unsupported dtype: " + precision_str(a_type); + return 1; } return 0; } - }; - + template + class SubScalar : public TF + { + public: + SubScalar(vector args, vector returns) + { + this->name = "subscalar"; + this->author = Author::name(); + this->args = args; + this->returns = returns; + } + string math_formula() const override + { + return "T3=T1-scalar"; + } + shared_ptr clone() const override + { + return make_shared>(*this); + } + int run(shared_ptr mem, string &error) override + { + Precision a_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + Precision c_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype; + if (a_type != c_type) + { + error = "Type mismatch: " + precision_str(a_type) + " != " + precision_str(c_type); + return 1; + } + switch (a_type) + { + case Precision::Float64: + tensorfunc::subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Float32: + tensorfunc::subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int64: + tensorfunc::subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int32: + tensorfunc::subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int16: + tensorfunc::subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int8: + tensorfunc::subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); + break; + default: + error = "Unsupported dtype: " + precision_str(a_type); + return 1; + } + return 0; + } + }; } - - - - - - - #endif