From d79ed89c784ad0e819fdc464e70dc3f860cbb074 Mon Sep 17 00:00:00 2001
From: lipeng <734991033@qq.com>
Date: Sun, 18 May 2025 17:38:56 +0800
Subject: [PATCH] rsubscalar: add rsubscalar
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 doc/excuter/op-mem-cuda/list.md               |  5 +-
 doc/excuter/op-mem-ompsimd/list.md            |  5 +-
 .../src/deepx/tensorfunc/elementwise.hpp      | 19 ++++-
 excuter/op-mem-cuda/src/client/tfs.cpp        | 10 +++
 .../tensorfunc/elementwise_miaobyte_basic.cu  | 31 +++++++++
 .../tensorfunc/elementwise_miaobyte_basic.cuh | 10 ++-
 .../tensorfunc/elementwise_miaobyte_basic.hpp | 12 ++++
 .../src/deepx/tf/elementwise_basic.hpp        | 69 +++++++++++++++++++
 excuter/op-mem-ompsimd/src/client/tfs.cpp     | 12 ++++
 .../deepx/tensorfunc/elementwise_miaobyte.hpp | 21 ++++++
 .../src/deepx/tf/elementwise.hpp              | 61 ++++++++++++++++
 front/py/deepx/nn/functional/authormap.py    |  1 +
 .../nn/functional/leaffunc_elementwise.py     | 35 +++++++---
 .../py/deepx/nn/functional/rtf_elementwise.py |  7 ++
 front/py/deepx/scheduler/client/udpconn.py    |  2 +-
 front/py/deepx/tensor/elementwise.py          | 12 ++++
 front/py/examples/2_ir/2_elementwise_add.py   | 12 +++-
 17 files changed, 305 insertions(+), 19 deletions(-)

diff --git a/doc/excuter/op-mem-cuda/list.md b/doc/excuter/op-mem-cuda/list.md
index f11918d9..e5995829 100644
--- a/doc/excuter/op-mem-cuda/list.md
+++ b/doc/excuter/op-mem-cuda/list.md
@@ -57,6 +57,9 @@
 | maxscalar | miaobyte | T3=max(T1, scalar) | maxscalar(tensor A, var scalar)->(tensor C) |
 | tan | miaobyte | T3=tan(T1) | tan(tensor A)->(tensor C) |
 | sin | miaobyte | T3=sin(T1) | sin(tensor A)->(tensor C) |
+| less | miaobyte | mask=compare(T1, T2) | less(tensor A, tensor B)->(tensor mask) |
+| powscalar | miaobyte | T3=pow(T1, scalar) | powscalar(tensor A, var scalar)->(tensor C) |
+| rsubscalar | miaobyte | T3=scalar-T1 | rsubscalar(var scalar, tensor A)->(tensor C) |
 | divscalar | miaobyte | T3=scalar/T1 | divscalar(tensor A, var scalar)->(tensor C) |
 | log | miaobyte | T3=log(T1) | log(tensor A)->(tensor C) |
 | addscalar | miaobyte | T3=T1+scalar | addscalar(tensor A, var b)->(tensor C) |
@@ -67,8 +70,6 @@
 | minscalar | miaobyte | T3=min(T1, scalar) | minscalar(tensor A, var scalar)->(tensor C) |
 | rpowscalar | miaobyte | T3=pow(scalar, T1) | rpowscalar(var scalar, tensor A)->(tensor C) |
 | rdivscalar | miaobyte | T3=scalar/T1 | rdivscalar(var scalar, tensor A)->(tensor C) |
-| less | miaobyte | mask=compare(T1, T2) | less(tensor A, tensor B)->(tensor mask) |
-| powscalar | miaobyte | T3=pow(T1, scalar) | powscalar(tensor A, var scalar)->(tensor C) |
 | todtype | none | T3(dtypeA)->T1(dtypeB) | todtype(tensor a)->(tensor b) |
 | add | cublas | T3=T1+T2 | add(tensor a, tensor b)->(tensor c) |
 | add | miaobyte | T3=T1+T2 | add(tensor a, tensor b)->(tensor c) |
diff --git a/doc/excuter/op-mem-ompsimd/list.md b/doc/excuter/op-mem-ompsimd/list.md
index 9e16aea4..fe44dd52 100644
--- a/doc/excuter/op-mem-ompsimd/list.md
+++ b/doc/excuter/op-mem-ompsimd/list.md
@@ -58,6 +58,9 @@
 | maxscalar | miaobyte | T3=max(T1,scalar) | maxscalar(tensor A, var scalar)->(tensor C) |
 | tan | miaobyte | T3=tan(T1) | tan(tensor A)->(tensor C) |
 | sin | miaobyte | T3=sin(T1) | sin(tensor A)->(tensor C) |
+| less | miaobyte | mask=less(T1,T2) | less(tensor A, tensor B)->(tensor mask) |
+| powscalar | miaobyte | T3=T1^scalar | powscalar(tensor A, var scalar)->(tensor C) |
+| rsubscalar | miaobyte | T3=scalar-T1 | rsubscalar(var scalar, tensor a)->(tensor c) |
 | divscalar | miaobyte | T3=T1/scalar | divscalar(tensor A, var scalar)->(tensor C) |
 | log | miaobyte | T3=log(T1) | log(tensor A)->(tensor C) |
 | addscalar | miaobyte | T3=T1+scalar | addscalar(tensor a, var scalar)->(tensor c) |
@@ -68,8 +71,6 @@
 | minscalar | miaobyte | T3=min(T1,scalar) | minscalar(tensor A, var scalar)->(tensor C) |
 | rpowscalar | miaobyte | T3=scalar^T1 | rpowscalar(var scalar, tensor A)->(tensor C) |
 | rdivscalar | miaobyte | T3=scalar/T1 | rdivscalar(var scalar, tensor A)->(tensor C) |
-| less | miaobyte | mask=less(T1,T2) | less(tensor A, tensor B)->(tensor mask) |
-| powscalar | miaobyte | T3=T1^scalar | powscalar(tensor A, var scalar)->(tensor C) |
 | todtype | none | T3(dtypeA)->T1(dtypeB) | todtype(tensor A)->(tensor C) |
 | add | cblas | T3=T1+T2 | add(tensor a, tensor b)->(tensor c) |
 | add | miaobyte | T3=T1+T2 | add(tensor a, tensor b)->(tensor c) |
diff --git a/excuter/cpp-common/src/deepx/tensorfunc/elementwise.hpp b/excuter/cpp-common/src/deepx/tensorfunc/elementwise.hpp
index a33ca60c..a708268b 100644
--- a/excuter/cpp-common/src/deepx/tensorfunc/elementwise.hpp
+++ b/excuter/cpp-common/src/deepx/tensorfunc/elementwise.hpp
@@ -58,6 +58,7 @@ namespace deepx::tensorfunc
         subDispatcher<Author, T>::sub(A, B, C);
     }

+    // A-scalar=>C
     template <typename Author, typename T>
     struct subscalarDispatcher
     {
@@ -66,20 +67,34 @@ namespace deepx::tensorfunc
             throw NotImplementError("subscalar");
         }
     };
-
-    // A-scalar=>C
     template <typename Author, typename T>
     void subscalar(const Tensor<T> &input, const T value, Tensor<T> &output)
     {
         subscalarDispatcher<Author, T>::subscalar(input, value, output);
     }

+
+    // scalar-A=>C
+    template <typename Author, typename T>
+    struct rsubscalarDispatcher
+    {
+        static void rsubscalar(const T value, const Tensor<T> &input, Tensor<T> &output) = delete;
+    };
+
+    template <typename Author, typename T>
+    void rsubscalar(const T value, const Tensor<T> &input, Tensor<T> &output)
+    {
+        rsubscalarDispatcher<Author, T>::rsubscalar(value, input, output);
+    }
+
+
     template <typename Author, typename T>
     struct mulDispatcher
     {
         static void mul(const Tensor<T> &A, const Tensor<T> &B, Tensor<T> &C) = delete;
     };
+    // A*B=>C
     template <typename Author, typename T>
     void mul(const Tensor<T> &A, const Tensor<T> &B, Tensor<T> &C)
diff --git a/excuter/op-mem-cuda/src/client/tfs.cpp b/excuter/op-mem-cuda/src/client/tfs.cpp
index 1192bb74..85fd5be6 100644
--- a/excuter/op-mem-cuda/src/client/tfs.cpp
+++ b/excuter/op-mem-cuda/src/client/tfs.cpp
@@ -232,6 +232,16 @@ namespace deepx::tf
             {
                 Param("C", DataCategory::Tensor, Precision::Any),
             })));
+        tffactory.add_tf(std::make_shared<RSubScalar<miaobyte>>(vector<Param>(
+            {
+                Param("scalar", DataCategory::Var, Precision::Any),
+                Param("A", DataCategory::Tensor, Precision::Any),
+            }),
+            vector<Param>(
+            {
+                Param("C", DataCategory::Tensor, Precision::Any),
+            })));
+
         tffactory.add_tf(std::make_shared>(vector(
             {
                 Param("A", DataCategory::Tensor, Precision::Any),
diff --git a/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.cu b/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.cu
index d0c996a1..16011b37 100644
--- a/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.cu
+++ b/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.cu
@@ -225,6 +225,37 @@ namespace deepx::tensorfunc
     template void launch_subscalar<int16_t>(const int16_t *a, const int16_t scalar, int16_t *c, const int size);
     template void launch_subscalar<int8_t>(const int8_t *a, const int8_t scalar, int8_t *c, const int size);

+    // rsubscalar
+    template <typename T>
+    __global__ void rsubscalar_kernel(const T scalar, const T* A, T* C, const int size){
+        int stride = blockDim.x * gridDim.x;
+        for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += stride)
+        {
+            C[idx] = scalar - A[idx];
+        }
+    }
+
+    template <typename T>
+    void launch_rsubscalar(const T scalar, const T* a, T* c, const int size){
+        auto [numBlocks, blockSize] = BestDims(size);
+        rsubscalar_kernel<<<numBlocks, blockSize>>>(scalar, a, c, size);
+        cudaError_t err = cudaGetLastError();
+        if (err != cudaSuccess)
+        {
+            throw std::runtime_error("Failed to launch rsubscalar kernel: " + std::string(cudaGetErrorString(err)));
+        }
+    }
+    template void launch_rsubscalar<double>(const double scalar, const double* a, double* c, const int size);
+    template void launch_rsubscalar<float>(const float scalar, const float* a, float* c, const int size);
+    template void launch_rsubscalar<half>(const half scalar, const half* a, half* c, const int size);
+    template void launch_rsubscalar<nv_bfloat16>(const nv_bfloat16 scalar, const nv_bfloat16* a, nv_bfloat16* c, const int size);
+    template void launch_rsubscalar<int64_t>(const int64_t scalar, const int64_t* a, int64_t* c, const int size);
+    template void launch_rsubscalar<int32_t>(const int32_t scalar, const int32_t* a, int32_t* c, const int size);
+    template void launch_rsubscalar<int16_t>(const int16_t scalar, const int16_t* a, int16_t* c, const int size);
+    template void launch_rsubscalar<int8_t>(const int8_t scalar, const int8_t* a, int8_t* c, const int size);
+
+
     // mul
     template <typename T>
     __global__ void mul_kernel(const T *A, const T *B, T *C, const int size)
diff --git a/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.cuh b/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.cuh
index 4100f38d..a943fcfc 100644
--- a/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.cuh
+++ b/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.cuh
@@ -44,7 +44,15 @@ namespace deepx::tensorfunc
     template <typename T>
     void launch_subscalar(const T* a, const T scalar, T* c, const int size);

-
+
+
+    // rsubscalar
+    template <typename T>
+    __global__ void rsubscalar_kernel(const T scalar, const T* A, T* C, const int size);
+
+    template <typename T>
+    void launch_rsubscalar(const T scalar, const T* a, T* c, const int size);
+
     // mul
     template <typename T>
     __global__ void mul_kernel(const T* A, const T* B, T* C, const int size);
diff --git a/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.hpp b/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.hpp
index b7d3a680..50fd1ade 100644
--- a/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.hpp
+++ b/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.hpp
@@ -69,6 +69,18 @@ namespace deepx::tensorfunc
         }
     };

+    template <typename T>
+    struct rsubscalarDispatcher<miaobyte, T>
+    {
+        static void rsubscalar(const T scalar, const Tensor<T> &A, Tensor<T> &C)
+        {
+            if (A.shape.size != C.shape.size) {
+                throw TensorShapeError("rsubscalar");
+            }
+            launch_rsubscalar(scalar, A.data, C.data, A.shape.size);
+        }
+    };
+
     template <typename T>
     struct mulDispatcher<miaobyte, T>
     {
diff --git a/excuter/op-mem-cuda/src/deepx/tf/elementwise_basic.hpp b/excuter/op-mem-cuda/src/deepx/tf/elementwise_basic.hpp
index d3659af3..c0472001 100644
--- a/excuter/op-mem-cuda/src/deepx/tf/elementwise_basic.hpp
+++ b/excuter/op-mem-cuda/src/deepx/tf/elementwise_basic.hpp
@@ -636,6 +636,75 @@ namespace deepx::tf
         }
     };

+    // rsubscalar
+    template <typename Author>
+    class RSubScalar : public TF
+    {
+    public:
+        RSubScalar(const vector<Param> &args, const vector<Param> &returns)
+        {
+            this->name = "rsubscalar";
+            this->metadata.author = Author::name();
+            this->tftype = "elementwise";
+            this->args = args;
+            this->returns = returns;
+        }
+
+        string math_formula() const override
+        {
+            return "T3=scalar-T1";
+        }
+        shared_ptr<TF> clone() const override
+        {
+            return make_shared<RSubScalar<Author>>(*this);
+        }
+        int run(shared_ptr<MemBase> mem, string &error) override
+        {
+            if (!checktensors({this->args[1].textvalue, this->returns[0].textvalue}, mem, error))
+            {
+                return 1;
+            }
+            Precision a_type = mem->gettensor(this->args[1].textvalue).get()->shape.dtype;
+            Precision c_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype;
+            if (a_type != c_type)
+            {
+                error = "Type mismatch: " + precision_str(a_type) + " != " + precision_str(c_type);
+                return 1;
+            }
+            switch (a_type)
+            {
+            case Precision::Float64:
+                tensorfunc::rsubscalar<Author>(this->getvar<double>(0, mem), *mem->gettensor<double>(this->args[1].textvalue), *mem->gettensor<double>(this->returns[0].textvalue));
+                break;
+            case Precision::Float32:
+                tensorfunc::rsubscalar<Author>(this->getvar<float>(0, mem), *mem->gettensor<float>(this->args[1].textvalue), *mem->gettensor<float>(this->returns[0].textvalue));
+                break;
+            case Precision::Float16:
+                tensorfunc::rsubscalar<Author>(this->getvar<half>(0, mem), *mem->gettensor<half>(this->args[1].textvalue), *mem->gettensor<half>(this->returns[0].textvalue));
+                break;
+            case Precision::BFloat16:
+                tensorfunc::rsubscalar<Author>(this->getvar<nv_bfloat16>(0, mem), *mem->gettensor<nv_bfloat16>(this->args[1].textvalue), *mem->gettensor<nv_bfloat16>(this->returns[0].textvalue));
+                break;
+            case Precision::Int64:
+                tensorfunc::rsubscalar<Author>(this->getvar<int64_t>(0, mem), *mem->gettensor<int64_t>(this->args[1].textvalue), *mem->gettensor<int64_t>(this->returns[0].textvalue));
+                break;
+            case Precision::Int32:
+                tensorfunc::rsubscalar<Author>(this->getvar<int32_t>(0, mem), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<int32_t>(this->returns[0].textvalue));
+                break;
+            case Precision::Int16:
+                tensorfunc::rsubscalar<Author>(this->getvar<int16_t>(0, mem), *mem->gettensor<int16_t>(this->args[1].textvalue), *mem->gettensor<int16_t>(this->returns[0].textvalue));
+                break;
+            case Precision::Int8:
+                tensorfunc::rsubscalar<Author>(this->getvar<int8_t>(0, mem), *mem->gettensor<int8_t>(this->args[1].textvalue), *mem->gettensor<int8_t>(this->returns[0].textvalue));
+                break;
+            default:
+                error = "Unsupported dtype: " + precision_str(a_type);
+                return 1;
+            }
+            return 0;
+        }
+    };
+
     template <typename Author>
     class Mul : public TF
     {
diff --git a/excuter/op-mem-ompsimd/src/client/tfs.cpp b/excuter/op-mem-ompsimd/src/client/tfs.cpp
index b76faac8..04b09aeb 100644
--- a/excuter/op-mem-ompsimd/src/client/tfs.cpp
+++ b/excuter/op-mem-ompsimd/src/client/tfs.cpp
@@ -246,6 +246,18 @@ namespace deepx::tf
             {
                 Param("c", DataCategory::Tensor, Precision::Any),
             })));
+        // rsubscalar author=miaobyte
+        tffactory.add_tf(std::make_shared<RSubScalar<miaobyte>>(vector<Param>(
+            {
+                Param("scalar", DataCategory::Var, Precision::Any),
+                Param("a", DataCategory::Tensor, Precision::Any),
+            }),
+            vector<Param>(
+            {
+                Param("c", DataCategory::Tensor, Precision::Any),
+            })));
+
+
         // mul author=miaobyte
         tffactory.add_tf(std::make_shared<Mul<miaobyte>>(vector<Param>(
             {
diff --git a/excuter/op-mem-ompsimd/src/deepx/tensorfunc/elementwise_miaobyte.hpp b/excuter/op-mem-ompsimd/src/deepx/tensorfunc/elementwise_miaobyte.hpp
index 517cbd9e..7f470467 100644
--- a/excuter/op-mem-ompsimd/src/deepx/tensorfunc/elementwise_miaobyte.hpp
+++ b/excuter/op-mem-ompsimd/src/deepx/tensorfunc/elementwise_miaobyte.hpp
@@ -194,6 +194,27 @@ namespace deepx::tensorfunc
         }
     };

+    template <typename T>
+    struct rsubscalarDispatcher<miaobyte, T>
+    {
+        static void rsubscalar(const T scalar, const Tensor<T> &A, Tensor<T> &C)
+        {
+            elementwise_A_b_C(A, scalar, C,
+                              // scalar operation
+                              [](const T &a, const T &scalar, T &c)
+                              { c = scalar - a; },
+                              // SIMD operation
+                              [](const T *a, const T scalar, T *c, size_t size)
+                              {
+                                  const ScalableTag<T> tag;
+                                  auto vec1 = Load(tag, a);
+                                  auto vec_scalar = Set(tag, scalar);
+                                  auto vec_result = Sub(vec_scalar, vec1);
+                                  Store(vec_result, tag, c);
+                              });
+        }
+    };
+
     // 添加 mul 的模板特化实现
     template <typename T>
     struct mulDispatcher<miaobyte, T>
diff --git a/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp b/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp
index 4cddae4d..08964ee9 100644
--- a/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp
+++ b/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp
@@ -495,6 +495,67 @@ namespace deepx::tf
             return 0;
         }
     };
+
+    template <typename Author>
+    class RSubScalar : public TF
+    {
+    public:
+        RSubScalar(vector<Param> args, vector<Param> returns)
+        {
+            this->name = "rsubscalar";
+            this->metadata.author = Author::name();
+            this->tftype = "elementwise";
+            this->args = args;
+            this->returns = returns;
+        }
+        string math_formula() const override
+        {
+            return "T3=scalar-T1";
+        }
+        shared_ptr<TF> clone() const override
+        {
+            return make_shared<RSubScalar<Author>>(*this);
+        }
+        int run(shared_ptr<MemBase> mem, string &error) override
+        {
+            if (!checktensors({this->args[1].textvalue, this->returns[0].textvalue}, mem, error))
+            {
+                return 1;
+            }
+            Precision a_type = mem->gettensor(this->args[1].textvalue).get()->shape.dtype;
+            Precision c_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype;
+            if (a_type != c_type)
+            {
+                error = "Type mismatch: " + precision_str(a_type) + " != " + precision_str(c_type);
+                return 1;
+            }
+            switch (a_type)
+            {
+            case Precision::Float64:
+                tensorfunc::rsubscalar<Author>(this->getvar<double>(0, mem), *mem->gettensor<double>(this->args[1].textvalue), *mem->gettensor<double>(this->returns[0].textvalue));
+                break;
+            case Precision::Float32:
+                tensorfunc::rsubscalar<Author>(this->getvar<float>(0, mem), *mem->gettensor<float>(this->args[1].textvalue), *mem->gettensor<float>(this->returns[0].textvalue));
+                break;
+            case Precision::Int64:
+                tensorfunc::rsubscalar<Author>(this->getvar<int64_t>(0, mem), *mem->gettensor<int64_t>(this->args[1].textvalue), *mem->gettensor<int64_t>(this->returns[0].textvalue));
+                break;
+            case Precision::Int32:
+                tensorfunc::rsubscalar<Author>(this->getvar<int32_t>(0, mem), *mem->gettensor<int32_t>(this->args[1].textvalue), *mem->gettensor<int32_t>(this->returns[0].textvalue));
+                break;
+            case Precision::Int16:
+                tensorfunc::rsubscalar<Author>(this->getvar<int16_t>(0, mem), *mem->gettensor<int16_t>(this->args[1].textvalue), *mem->gettensor<int16_t>(this->returns[0].textvalue));
+                break;
+            case Precision::Int8:
+                tensorfunc::rsubscalar<Author>(this->getvar<int8_t>(0, mem), *mem->gettensor<int8_t>(this->args[1].textvalue), *mem->gettensor<int8_t>(this->returns[0].textvalue));
+                break;
+            default:
+                error = "Unsupported dtype: " + precision_str(a_type);
+                return 1;
+            }
+            return 0;
+        }
+    };

     template <typename Author>
     class Mul : public TF
diff --git a/front/py/deepx/nn/functional/authormap.py b/front/py/deepx/nn/functional/authormap.py
index 9972f304..696a8d5d 100644
--- a/front/py/deepx/nn/functional/authormap.py
+++ b/front/py/deepx/nn/functional/authormap.py
@@ -11,6 +11,7 @@
     'addscalar':'miaobyte',
     'sub':'miaobyte',
     'subscalar':'miaobyte',
+    'rsubscalar':'miaobyte',
     'mul':'miaobyte',
     'mulscalar':'miaobyte',
     'div':'miaobyte',
diff --git a/front/py/deepx/nn/functional/leaffunc_elementwise.py b/front/py/deepx/nn/functional/leaffunc_elementwise.py
index a883a6e8..26956952 100644
--- a/front/py/deepx/nn/functional/leaffunc_elementwise.py
+++ b/front/py/deepx/nn/functional/leaffunc_elementwise.py
@@ -7,18 +7,26 @@

 #四则运算
 add = create_A_B_tf_C('add')
-sub = create_A_B_tf_C('sub')
-mul = create_A_B_tf_C('mul')
-_div=create_A_B_tf_C('div')
-
-def div(
-        a: Union[Tensor, float, int],
-        b: Union[Tensor, float, int],
+_sub=create_A_B_tf_C('sub')
+def rsub(a:float,b:Tensor,out:Union[Tensor,str]=None)->Tensor:
+    outtensor=out
+    if isinstance(out,str) or out is None:
+        outtensor=newtensor(b.shape,dtype=b.dtype,name=out)
+    from .rtf_elementwise import rtf_rsubscalar
+    rtf_rsubscalar(a,b,outtensor,defaultauthor['rsubscalar'])
+    return outtensor
+def sub(a:Union[Tensor,float,int],
+        b:Union[Tensor,float,int],
         out:Union[Tensor,str]=None)->Tensor:
     if isinstance(a,Tensor):
-        return _div(a,b,out)
+        return _sub(a,b,out)
     elif isinstance(a,float) or isinstance(a,int):
-        return rdiv(a,b,out)
+        return rsub(a,b,out)
+
+
+mul = create_A_B_tf_C('mul')
+_div=create_A_B_tf_C('div')
+
 def rdiv(
         a: Union[float, int],
         b: Tensor,
@@ -29,6 +37,15 @@ def rdiv(
     from .rtf_elementwise import rtf_rdivscalar
     rtf_rdivscalar(a,b,outtensor,defaultauthor['rdivscalar'])
     return outtensor
+def div(
+        a: Union[Tensor, float, int],
+        b: Union[Tensor, float, int],
+        out:Union[Tensor,str]=None)->Tensor:
+    if isinstance(a,Tensor):
+        return _div(a,b,out)
+    elif isinstance(a,float) or isinstance(a,int):
+        return rdiv(a,b,out)
+

 ## 幂、指数 运算
diff --git a/front/py/deepx/nn/functional/rtf_elementwise.py b/front/py/deepx/nn/functional/rtf_elementwise.py
index 34ab28e9..fe707f84 100644
--- a/front/py/deepx/nn/functional/rtf_elementwise.py
+++ b/front/py/deepx/nn/functional/rtf_elementwise.py
@@ -20,6 +20,13 @@ def rtf_subscalar(a:Tensor, b:float, out:Tensor, author='miaobyte')->Tensor:
     A_scalar_op_C("subscalar",a,b,out,author)
     return out

+def rtf_rsubscalar(a:float, b:Tensor, out:Tensor, author='miaobyte')->Tensor:
+    args = [ Param.varnum(a),Param.tensor(b)]
+    returns = [Param.tensor(out)]
+    ir = DeepxIR("rsubscalar", args, returns, author)
+    send(ir)
+    return out
+
 def rtf_mul(a:Tensor, b:Tensor, out:Tensor, author='miaobyte')->Tensor:
     A_B_op_C("mul",a,b,out,author)
     return out
diff --git a/front/py/deepx/scheduler/client/udpconn.py b/front/py/deepx/scheduler/client/udpconn.py
index a25b0963..6a12c26a 100644
--- a/front/py/deepx/scheduler/client/udpconn.py
+++ b/front/py/deepx/scheduler/client/udpconn.py
@@ -3,7 +3,7 @@
 import select

 class UDPConn:
-    def __init__(self, endpoint: str = "localhost:8080"):
+    def __init__(self, endpoint: str = "localhost:9090"):
         # 解析endpoint
         self._host, port_str = endpoint.split(':')
         self._port = int(port_str)
diff --git a/front/py/deepx/tensor/elementwise.py b/front/py/deepx/tensor/elementwise.py
index acf587da..3aafd7ab 100644
--- a/front/py/deepx/tensor/elementwise.py
+++ b/front/py/deepx/tensor/elementwise.py
@@ -27,6 +27,18 @@ def sub_(self, other:Union[Tensor,float,int]):
     from deepx.nn.functional import sub as sub_func
     sub_func(self,other,self)

+@tensor_method
+def rsub(self,other:Union[Tensor,float,int],
+        out:Union[Tensor,str]='')->Tensor:
+    from deepx.nn.functional import sub as sub_func
+    return sub_func(other,self,out)
+
+@tensor_method
+def rsub_(self,other:Union[Tensor,float,int]):
+    from deepx.nn.functional import sub as sub_func
+    sub_func(other,self,self)
+    return self
+
 @tensor_method
 def mul(self, other:Union[Tensor,float,int],
         out:Union[Tensor,str]='')->Tensor:
diff --git a/front/py/examples/2_ir/2_elementwise_add.py b/front/py/examples/2_ir/2_elementwise_add.py
index d0b016bf..0b42fc3d 100644
--- a/front/py/examples/2_ir/2_elementwise_add.py
+++ b/front/py/examples/2_ir/2_elementwise_add.py
@@ -1,3 +1,4 @@
+print()
 ############-------PyTorch-------################
 import torch

@@ -5,8 +6,11 @@
 torch_t2 = torch_t1.clone()
 torch_t3 = torch_t1 + torch_t2
 torch_t3.add_(0.5)
-print()
+
 print(torch_t3)
+torch_t4 = torch.full((2,3,4), 1.5, dtype=torch.float32)
+torch_t5 = 2-torch_t4
+print(torch_t5)


 ############-------DEEPX-------################
@@ -18,4 +22,8 @@
 t2 = t1.clone()
 t3 = t1+t2
 t3.add_(0.5)
-t3.print()
\ No newline at end of file
+t3.print()
+
+t4 = full((2,3,4), value=1.5,dtype="float32")
+t5 = 2-t4
+t5.print()
\ No newline at end of file
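
Usage sketch (not applied by git am): a minimal example of driving the new scalar-minus-tensor path from the Python front end, assuming a running executor and that full is importable from deepx as in the example script above; the variable names are illustrative only.

    from deepx import full

    t = full((2, 3, 4), value=1.5, dtype="float32")
    u = t.rsub(2.0)   # u = 2.0 - t, lowered to rsubscalar(var scalar, tensor A)->(tensor C)
    t.rsub_(2.0)      # in-place variant: t = 2.0 - t
    u.print()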