diff --git a/README.md b/README.md index a590faa9..8c24b057 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,19 @@ # deepx -deepx提出了一种原生分布式自动并行的训推一体化的深度学习框架。 +deepx提出了一种以IR计算图为核心的原生分布式自动并行的训推一体化的深度学习框架,经过多层级等价替换,实现从简单的数学形式的计算图,自适应等价替换为分布式、并行、自动反向的工程系统架构。 ## 一.deepx概述 -deepx的执行支持eager和auto两种模式 +deepx分为前端表达侧,编译替换层,执行器层 -+ eager立即执行函数 -+ auto则会经过计算图编译器优化器 ++ 前端表达侧,交由算法工程师、用接近数学的表达方式,设计其数学计算过程。只表示为单节点、单线程的简洁数学表达过程,不涉及复杂的device类型、计算节点数量等。 ++ 编译替换层:注册了多轮不同类型的IR编译器,实现等价替换,可以以插件的形式增加自定义能力如定制kvcache,实现对计算图进行局部替换,获得新的能力。 ++ 执行器层:实现真正的tensor运算,大规模并行化。 ### 前端 python sdk提供接近pytorch的API -也容许其他语言的sdk接入 +也容许其他语言的sdk接入。 + IR通信调度。不同于pytorch或其他py+bind c++这种单一进程的栈上函数调度执行的方式。deepx各个程序(如front的python sdk,back的计算图编译器优化器、excuter如ompsimd)之间,通过IR实现网络通信调度,需要各自启动对应进程。 @@ -21,36 +22,43 @@ python sdk提供接近pytorch的API |--------------|-----------------------|-------------------------| | 执行模式 | 单进程内函数栈调度 | 多进程分布式协同 | | 通信方式 | 内存直接访问 | IR网络计算调度协议交换 | -| 组件耦合度 | 紧耦合(Python绑定C++)| 松耦合(gRPC/自定义协议)| +| 组件耦合度 | 紧耦合(Python绑定C++)| 松耦合| +| tensor生命周期管理 | 由python侧控制 | 由deltensor这个IR指令,显式管理tensor| -### 调度面 +### 编译替换层 + 注册中心:收集当前已就绪的执行器的算子列表,收集算子时耗和空间占用信息 + 计算图编译器优化器:fusion算子,计算图节点消除,自动生成tensor拆分并行的计算子图并替代原节点 + 执行调度器:数据并行,流水线并行(前向反向并行),模型并行。 + front生成基础IR,编译器负责进行fusion成excuter注册的高级算子。 -### 执行器 +### 执行层 -负责低级的算子计算操作,以Op为执行的核心单元 +执行层包括op和mem两种执行器,但实际实现时,当前只设计了一个程序同时负责op和mem的管理。 + +负责低级的算子计算操作,以IR为执行的核心单元 ``` -Op{args(args_grad),returns(returns_grad)|func forward,backward} +Op{args(args_grad),returns(returns_grad)|func run} ``` -大部分Op都需要同时实现forward和backward,但也有部分只为推理设计的fusionOp可以根据需要实现forward。 +Op需要实现run方法 关于excuter,只要能按deepxIR序列执行,并返回结果,就可以接入deepx分布式调度框架,因此,从硬件、指令、加速库、高级框架包括训练、推理引擎,都可以稍作修改,就接入deepx体系。 +当前的 + #### 默认执行器 + cpu执行器,已实现ompsimd。其支持的算子列表[ompsimd](doc/excuter/op-mem-ompsimd/list.md) #### GPU执行器 -+ cuda执行器【实现中状态】 ++ cuda执行器,其支持的算子列表[cuda](doc/excuter/op-mem-cuda/list.md) + 欢迎大家提交cuda代码 + rocm - ++ apple ++ 其他硬件加速器 #### 张量计算框架or函数级执行器 diff --git a/doc/excuter/op-mem-cuda/list.md 
b/doc/excuter/op-mem-cuda/list.md index 475d0f43..f25b6973 100644 --- a/doc/excuter/op-mem-cuda/list.md +++ b/doc/excuter/op-mem-cuda/list.md @@ -40,6 +40,7 @@ | Operation | Author | Math Formula | IR Instruction | |-----------|--------|--------------|----------------| | normal | miaobyte | normal(mean,stddev,seed)->T1 | normal(var mean, var stddev, var seed)->(tensor t) | +| dropout | miaobyte | dropout(p,seed)->A | dropout(var p, var seed)->(tensor A) | | uniform | miaobyte | uniform(low,high,seed)->T1 | uniform(var low, var high, var seed)->(tensor t) | | arange | miaobyte | arange(start,step)->T1 | arange(var start, var step)->(tensor t) | | constant | miaobyte | constant(value)->T1 | constant(var value)->(tensor t) | @@ -50,18 +51,19 @@ |-----------|--------|--------------|----------------| | switch | miaobyte | C=switch(tensors,cases) | switch(listtensor tensors, tensor cases)->(tensor result) | | greaterscalar | miaobyte | mask=compare(T1, scalar) | greaterscalar(tensor A, var scalar)->(tensor mask) | -| equalscalar | miaobyte | mask=compare(T1, scalar) | equalscalar(tensor A, var scalar, var epsilon)->(tensor mask) | +| notequal | miaobyte | T1!=T2->mask | notequal(tensor A, tensor B, var epsilon)->(tensor mask) | +| equalscalar | miaobyte | T1==scalar->mask | equalscalar(tensor A, var scalar, var epsilon)->(tensor mask) | | min | miaobyte | T3=min(T1, T2) | min(tensor A, tensor B)->(tensor C) | | maxscalar | miaobyte | T3=max(T1, scalar) | maxscalar(tensor A, var scalar)->(tensor C) | | tan | miaobyte | T3=tan(T1) | tan(tensor A)->(tensor C) | | sin | miaobyte | T3=sin(T1) | sin(tensor A)->(tensor C) | -| dropout | miaobyte | dropout(p,seed)->A | dropout(var p, var seed)->(tensor A) | | divscalar | miaobyte | T3=scalar/T1 | divscalar(tensor A, var scalar)->(tensor C) | | log | miaobyte | T3=log(T1) | log(tensor A)->(tensor C) | | addscalar | miaobyte | T3=T1+scalar | addscalar(tensor A, var b)->(tensor C) | | greater | miaobyte | mask=compare(T1, T2) | 
greater(tensor A, tensor B)->(tensor mask) | | lessscalar | miaobyte | mask=compare(T1, scalar) | lessscalar(tensor A, var scalar)->(tensor mask) | | cos | miaobyte | T3=cos(T1) | cos(tensor A)->(tensor C) | +| notequalscalar | miaobyte | T1!=scalar->mask | notequalscalar(tensor A, var scalar, var epsilon)->(tensor mask) | | minscalar | miaobyte | T3=min(T1, scalar) | minscalar(tensor A, var scalar)->(tensor C) | | rpowscalar | miaobyte | T3=pow(scalar, T1) | rpowscalar(var scalar, tensor A)->(tensor C) | | rdivscalar | miaobyte | T3=scalar/T1 | rdivscalar(var scalar, tensor A)->(tensor C) | @@ -75,7 +77,7 @@ | subscalar | miaobyte | T3=T1-scalar | subscalar(tensor A, var b)->(tensor C) | | exp | miaobyte | T3=exp(T1) | exp(tensor A)->(tensor C) | | mul | miaobyte | T3=T1*T2 | mul(tensor A, tensor B)->(tensor C) | -| equal | miaobyte | mask=compare(T1, T2) | equal(tensor A, tensor B, var epsilon)->(tensor mask) | +| equal | miaobyte | T1==T2->mask | equal(tensor A, tensor B, var epsilon)->(tensor mask) | | mulscalar | miaobyte | T3=T1*scalar | mulscalar(tensor A, var b)->(tensor C) | | div | miaobyte | T3=T1/T2 | div(tensor A, tensor B)->(tensor C) | | invert | miaobyte | T3=~T1 | invert(tensor A)->(tensor C) | diff --git a/doc/excuter/op-mem-ompsimd/list.md b/doc/excuter/op-mem-ompsimd/list.md index 35a3a99e..d2ba1745 100644 --- a/doc/excuter/op-mem-ompsimd/list.md +++ b/doc/excuter/op-mem-ompsimd/list.md @@ -41,6 +41,7 @@ | Operation | Author | Math Formula | IR Instruction | |-----------|--------|--------------|----------------| | normal | miaobyte | normal(mean,stddev,seed)->T1 | normal(var mean, var std, var seed)->(tensor t) | +| dropout | miaobyte | dropout(p,seed)->A | dropout(var p, var seed)->(tensor A) | | uniform | miaobyte | uniform(low,high,seed)->T1 | uniform(var low, var high, var seed)->(tensor t) | | arange | miaobyte | arange(start,step)->T1 | arange(var start, var step)->(tensor t) | | constant | miaobyte | constant(value)->T1 | constant(var 
value)->(tensor t) | @@ -51,15 +52,16 @@ |-----------|--------|--------------|----------------| | switch | miaobyte | C=switch([tensors],case) | switch(listtensor tensors, tensor cases)->(tensor C) | | greaterscalar | miaobyte | mask=greater(T1,scalar) | greaterscalar(tensor A, var scalar)->(tensor mask) | -| equalscalar | miaobyte | mask=equal(T1,scalar) | equalscalar(tensor A, var scalar)->(tensor mask) | +| notequal | miaobyte | notequal(T1,T2)->mask | notequal(tensor A, tensor B, var epsilon)->(tensor mask) | +| equalscalar | miaobyte | mask=equal(T1,scalar) | equalscalar(tensor A, var scalar, var eposilon)->(tensor mask) | | min | miaobyte | T3=min(T1,T2) | min(tensor A, tensor B)->(tensor C) | | maxscalar | miaobyte | T3=max(T1,scalar) | maxscalar(tensor A, var scalar)->(tensor C) | -| dropout | miaobyte | dropout(p,seed)->A | dropout(var p, var seed)->(tensor A) | | divscalar | miaobyte | T3=T1/scalar | divscalar(tensor A, var scalar)->(tensor C) | | log | miaobyte | T3=log(T1) | log(tensor A)->(tensor C) | | addscalar | miaobyte | T3=T1+scalar | addscalar(tensor a, var scalar)->(tensor c) | | greater | miaobyte | mask=greater(T1,T2) | greater(tensor A, tensor B)->(tensor mask) | | lessscalar | miaobyte | mask=less(T1,scalar) | lessscalar(tensor A, var scalar)->(tensor mask) | +| notequalscalar | miaobyte | mask=notequal(T1,scalar) | notequalscalar(tensor A, var scalar, var epsilon)->(tensor mask) | | minscalar | miaobyte | T3=min(T1,scalar) | minscalar(tensor A, var scalar)->(tensor C) | | rpowscalar | miaobyte | T3=scalar^T1 | rpowscalar(var scalar, tensor A)->(tensor C) | | rdivscalar | miaobyte | T3=scalar/T1 | rdivscalar(var scalar, tensor A)->(tensor C) | @@ -73,7 +75,7 @@ | subscalar | miaobyte | T3=T1-scalar | subscalar(tensor a, var scalar)->(tensor c) | | exp | miaobyte | T3=exp(T1) | exp(tensor A)->(tensor C) | | mul | miaobyte | T3=T1*T2 | mul(tensor A, tensor B)->(tensor C) | -| equal | miaobyte | mask=equal(T1,T2) | equal(tensor A, tensor 
B)->(tensor mask) | +| equal | miaobyte | equal(T1,T2)->mask | equal(tensor A, tensor B, var eposilon)->(tensor mask) | | mulscalar | miaobyte | T3=T1*scalar | mulscalar(tensor A, var b)->(tensor C) | | div | miaobyte | T3=T1/T2 | div(tensor A, tensor B)->(tensor C) | | invert | miaobyte | T3=~T1 | invert(tensor A)->(tensor C) | diff --git a/excuter/cpp-common/src/deepx/tensorfunc/elementwise.hpp b/excuter/cpp-common/src/deepx/tensorfunc/elementwise.hpp index eb63f674..a33ca60c 100644 --- a/excuter/cpp-common/src/deepx/tensorfunc/elementwise.hpp +++ b/excuter/cpp-common/src/deepx/tensorfunc/elementwise.hpp @@ -333,6 +333,31 @@ namespace deepx::tensorfunc { equalscalarDispatcher::equalscalar(A, scalar, epsilon, mask); } + //notequal(A,B)=>mask + template + struct notequalDispatcher + { + static void notequal(const Tensor &A, const Tensor &B,const float epsilon, Tensor &mask) = delete; + }; + + template + void notequal(const Tensor &A, const Tensor &B,const float epsilon, Tensor &mask) + { + notequalDispatcher::notequal(A, B, epsilon, mask); + } + + // notequal(A,scalar)=>mask + template + struct notequalscalarDispatcher + { + static void notequalscalar(const Tensor &A, const T scalar,const float epsilon, Tensor &mask) = delete; + }; + + template + void notequalscalar(const Tensor &A, const T scalar,const float epsilon, Tensor &mask) + { + notequalscalarDispatcher::notequalscalar(A, scalar, epsilon, mask); + } // less(A,B)=>mask template diff --git a/excuter/op-mem-cuda/src/client/tfs.cpp b/excuter/op-mem-cuda/src/client/tfs.cpp index 0c0ca44e..fa6cdcc6 100644 --- a/excuter/op-mem-cuda/src/client/tfs.cpp +++ b/excuter/op-mem-cuda/src/client/tfs.cpp @@ -406,7 +406,7 @@ namespace deepx::tf { Param("A", DataCategory::Tensor, Precision::Any), Param("B", DataCategory::Tensor, Precision::Any), - Param("epsilon", DataCategory::Var, Precision::Float64), + Param("epsilon", DataCategory::Var, Precision::Float32), }), vector( { @@ -416,7 +416,27 @@ namespace deepx::tf { 
Param("A", DataCategory::Tensor, Precision::Any), Param("scalar", DataCategory::Var, Precision::Any), - Param("epsilon", DataCategory::Var, Precision::Float64), + Param("epsilon", DataCategory::Var, Precision::Float32), + }), + vector( + { + Param("mask", DataCategory::Tensor, Precision::Bool), + }))); + tffactory.add_tf(std::make_shared>(vector( + { + Param("A", DataCategory::Tensor, Precision::Any), + Param("B", DataCategory::Tensor, Precision::Any), + Param("epsilon", DataCategory::Var, Precision::Float32), + }), + vector( + { + Param("mask", DataCategory::Tensor, Precision::Bool), + }))); + tffactory.add_tf(std::make_shared>(vector( + { + Param("A", DataCategory::Tensor, Precision::Any), + Param("scalar", DataCategory::Var, Precision::Any), + Param("epsilon", DataCategory::Var, Precision::Float32), }), vector( { diff --git a/excuter/op-mem-cuda/src/deepx/mem/mem_cuda.hpp b/excuter/op-mem-cuda/src/deepx/mem/mem_cuda.hpp index 556a4aac..384eda3a 100644 --- a/excuter/op-mem-cuda/src/deepx/mem/mem_cuda.hpp +++ b/excuter/op-mem-cuda/src/deepx/mem/mem_cuda.hpp @@ -107,7 +107,12 @@ namespace deepx::mem result->data = ptr_tensor->data; break; } - + case Precision::Bool: + { + auto ptr_tensor = std::static_pointer_cast>(ptr); + result->data = ptr_tensor->data; + break; + } default: throw std::runtime_error("Unsupported dtype: " + precision_str(ptr->shape.dtype)); } diff --git a/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_compare.cu b/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_compare.cu index f5e93fc5..72c86352 100644 --- a/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_compare.cu +++ b/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_compare.cu @@ -250,6 +250,122 @@ namespace deepx::tensorfunc template void launch_equalscalar(const int16_t *A, const int16_t scalar, const float epsilon, bool *mask, const int size); template void launch_equalscalar(const int8_t *A, const int8_t scalar, const float epsilon, 
bool *mask, const int size); + // not equal: mask[i] = (A[i] != B[i]) when epsilon == 0, otherwise mask[i] = (|A[i] - B[i]| >= epsilon) + template + __global__ void notequalwithepsilon_kernel(const T *A, const T *B, const float epsilon, MaskT *mask, const int size) + { + int stride = blockDim.x * gridDim.x; + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += stride) + { + float diff = fabsf(static_cast(A[idx]) - static_cast(B[idx])); + if (diff < epsilon) + { + mask[idx] = 0; + } + else + { + mask[idx] = 1; + } + } + } + + template + __global__ void notequal_kernel(const T *A, const T *B, MaskT *mask, const int size) + { + int stride = blockDim.x * gridDim.x; + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += stride) + { + mask[idx] = (A[idx] != B[idx]); + } + } + + template + void launch_notequal(const T *A, const T *B, const float epsilon, MaskT *mask, const int size) + { + auto [numBlocks, blockSize] = BestDims(size); + if (epsilon == 0) + { + notequal_kernel<<>>(A, B, mask, size); + } + else + { + notequalwithepsilon_kernel<<>>(A, B, epsilon, mask, size); + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + throw std::runtime_error("Failed to launch notequal kernel: " + + std::string(cudaGetErrorString(err))); + } + } + + template void launch_notequal(const double *A, const double *B, const float epsilon, bool *mask, const int size); + template void launch_notequal(const float *A, const float *B, const float epsilon, bool *mask, const int size); + template void launch_notequal(const nv_bfloat16 *A, const nv_bfloat16 *B, const float epsilon, bool *mask, const int size); + template void launch_notequal<__half,bool>(const __half *A, const __half *B, const float epsilon, bool *mask, const int size); + template void launch_notequal(const int64_t *A, const int64_t *B, const float epsilon, bool *mask, const int size); + template void launch_notequal(const int32_t *A, const int32_t *B, const float epsilon, bool *mask, const int size); + template void launch_notequal(const int16_t *A, const 
int16_t *B, const float epsilon, bool *mask, const int size); + template void launch_notequal(const int8_t *A, const int8_t *B, const float epsilon, bool *mask, const int size); + + // notequalscalar: mask[i] = (A[i] != scalar) when epsilon == 0, otherwise mask[i] = (|A[i] - scalar| >= epsilon) + template + __global__ void notequalscalarwithepsilon_kernel(const T *A, const T scalar, const float epsilon, MaskT *mask, const int size) + { + int stride = blockDim.x * gridDim.x; + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += stride) + { + float diff = fabsf(static_cast(A[idx]) - static_cast(scalar)); + if (diff < epsilon) + { + mask[idx] = 0; + } + else + { + mask[idx] = 1; + } + } + } + + template + __global__ void notequalscalar_kernel(const T *A, const T scalar, MaskT *mask, const int size) + { + int stride = blockDim.x * gridDim.x; + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += stride) + { + mask[idx] = (A[idx] != scalar); + } + } + + template + void launch_notequalscalar(const T *A, const T scalar, const float epsilon, MaskT *mask, const int size) + { + auto [numBlocks, blockSize] = BestDims(size); + if (epsilon == 0) + { + notequalscalar_kernel<<>>(A, scalar, mask, size); + } + else + { + notequalscalarwithepsilon_kernel<<>>(A, scalar, epsilon, mask, size); + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + throw std::runtime_error("Failed to launch notequalscalar kernel: " + + std::string(cudaGetErrorString(err))); + } + } + + template void launch_notequalscalar(const double *A, const double scalar, const float epsilon, bool *mask, const int size); + template void launch_notequalscalar(const float *A, const float scalar, const float epsilon, bool *mask, const int size); + template void launch_notequalscalar(const nv_bfloat16 *A, const nv_bfloat16 scalar, const float epsilon, bool *mask, const int size); + template void launch_notequalscalar<__half,bool>(const __half *A, const __half scalar, const float epsilon, bool *mask, const int size); + template void launch_notequalscalar(const 
int64_t *A, const int64_t scalar, const float epsilon, bool *mask, const int size); + template void launch_notequalscalar(const int32_t *A, const int32_t scalar, const float epsilon, bool *mask, const int size); + template void launch_notequalscalar(const int16_t *A, const int16_t scalar, const float epsilon, bool *mask, const int size); + template void launch_notequalscalar(const int8_t *A, const int8_t scalar, const float epsilon, bool *mask, const int size); + // less template __global__ void less_kernel(const T *A, const T *B, MaskT *mask, const int size) diff --git a/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_compare.cuh b/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_compare.cuh index ee9ea259..c813acb2 100644 --- a/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_compare.cuh +++ b/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_compare.cuh @@ -39,7 +39,7 @@ namespace deepx::tensorfunc //equal template - __global__ void equal_kernel(const T* A, const T* B,const float epsilon, MaskT* mask, const int size); + __global__ void equalwithepsilon_kernel(const T* A, const T* B,const float epsilon, MaskT* mask, const int size); template __global__ void equal_kernel(const T* A, const T* B, float* mask, const int size); @@ -49,11 +49,34 @@ namespace deepx::tensorfunc //equalscalar template - __global__ void equalscalar_kernel(const T* A, const T scalar,const float epsilon, MaskT* mask, const int size); + __global__ void equalscalarwithepsilon_kernel(const T* A, const T scalar,const float epsilon, MaskT* mask, const int size); + + template + __global__ void equalscalar_kernel(const T* A, const T scalar, float* mask, const int size); template void launch_equalscalar(const T* A, const T scalar,const float epsilon, MaskT* mask, const int size); + //notequal + template + __global__ void notequalwithepsilon_kernel(const T* A, const T* B,const float epsilon, MaskT* mask, const int size); + + template + __global__ 
void notequal_kernel(const T* A, const T* B, MaskT* mask, const int size); + + template + void launch_notequal(const T* A, const T* B,const float epsilon, MaskT* mask, const int size); + + //notequalscalar + template + __global__ void notequalscalarwithepsilon_kernel(const T* A, const T scalar,const float epsilon, MaskT* mask, const int size); + + template + __global__ void notequalscalar_kernel(const T* A, const T scalar, MaskT* mask, const int size); + + template + void launch_notequalscalar(const T* A, const T scalar,const float epsilon, MaskT* mask, const int size); + //less template __global__ void less_kernel(const T* A, const T* B, MaskT* mask, const int size); diff --git a/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_compare.hpp b/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_compare.hpp index ed58ac6a..060feb66 100644 --- a/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_compare.hpp +++ b/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_compare.hpp @@ -98,6 +98,41 @@ namespace deepx::tensorfunc launch_equalscalar(A.data, scalar, epsilon, mask.data, A.shape.size); } }; + // notequal(A,B)=>mask; throws unless A, B and mask have the same element count and epsilon >= 0 (0 is accepted) + template + struct notequalDispatcher + { + static void notequal(const Tensor &A, const Tensor &B,const float epsilon, Tensor &mask) + { + if (A.shape.size != B.shape.size || A.shape.size != mask.shape.size) + { + throw TensorShapeError("notequal"); + } + if (epsilon < 0) + { + throw std::invalid_argument("notequal epsilon must be non-negative"); + } + launch_notequal(A.data, B.data, epsilon, mask.data, A.shape.size); + } + }; + // notequalscalar(A,scalar)=>mask; throws unless A and mask have the same element count and epsilon >= 0 (0 is accepted) + template + struct notequalscalarDispatcher + { + static void notequalscalar(const Tensor &A, const T scalar,const float epsilon, Tensor &mask) + { + if (A.shape.size != mask.shape.size) + { + throw TensorShapeError("notequalscalar"); + } + if (epsilon < 0) + { + throw std::invalid_argument("notequalscalar epsilon must be non-negative"); + } + 
launch_notequalscalar(A.data, scalar, epsilon, mask.data, A.shape.size); + } + }; + // less(A,B)=>C template diff --git a/excuter/op-mem-cuda/src/deepx/tf/elementwise_compare.hpp b/excuter/op-mem-cuda/src/deepx/tf/elementwise_compare.hpp index fe3734a1..cce24d81 100644 --- a/excuter/op-mem-cuda/src/deepx/tf/elementwise_compare.hpp +++ b/excuter/op-mem-cuda/src/deepx/tf/elementwise_compare.hpp @@ -284,7 +284,7 @@ namespace deepx::tf string math_formula() const override { - return "mask=compare(T1, T2)"; + return "T1==T2->mask"; } shared_ptr clone() const override { @@ -295,7 +295,10 @@ namespace deepx::tf { Precision a_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; Precision b_type = mem->gettensor(this->args[1].textvalue).get()->shape.dtype; - float epsilon = this->getvar(2, mem); + float epsilon =0; + if (this->args.size()>2){ + epsilon=this->getvar(2,mem,true); + } Precision mask_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype; if (a_type != b_type || mask_type != Precision::Bool) { @@ -351,7 +354,7 @@ namespace deepx::tf string math_formula() const override { - return "mask=compare(T1, scalar)"; + return "T1==scalar->mask"; } shared_ptr clone() const override { @@ -363,7 +366,7 @@ namespace deepx::tf Precision a_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; Precision mask_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype; float epsilon = this->getvar(2, mem); - if (a_type != mask_type || mask_type != Precision::Bool) + if ( mask_type != Precision::Bool) { error = "Type mismatch: " + precision_str(a_type) + " != " + precision_str(mask_type); return 1; @@ -402,6 +405,143 @@ namespace deepx::tf } }; + + template + class NotEqual : public TF + { + public: + NotEqual(const vector &args, const vector &returns) + { + this->name = "notequal"; + this->metadata.author=Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + + string math_formula() 
const override + { + return "T1!=T2->mask"; + } + shared_ptr clone() const override + { + return make_shared>(*this); + } + + int run(shared_ptr mem, string &error) override + { + Precision a_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + Precision b_type = mem->gettensor(this->args[1].textvalue).get()->shape.dtype; + float epsilon =0; + if (this->args.size()>2){ + epsilon=this->getvar(2,mem,true); + } + Precision mask_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype; + if (a_type != b_type || mask_type != Precision::Bool) + { + error = "Type mismatch: " + precision_str(a_type) + " != " + precision_str(b_type) + " or " + precision_str(a_type) + " != " + precision_str(mask_type); + return 1; + } + switch (a_type) + { + case Precision::Float64: + tensorfunc::notequal(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Float32: + tensorfunc::notequal(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Float16: + tensorfunc::notequal(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::BFloat16: + tensorfunc::notequal(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int64: + tensorfunc::notequal(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int32: + tensorfunc::notequal(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int16: + 
tensorfunc::notequal(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int8: + tensorfunc::notequal(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + default: + error = "Unsupported type: " + precision_str(a_type); + return 1; + } + return 0; + } + }; + + template + class NotEqualScalar : public TF + { + public: + NotEqualScalar(const vector &args, const vector &returns) + { + this->name = "notequalscalar"; + this->metadata.author=Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + + string math_formula() const override + { + return "T1!=scalar->mask"; + } + shared_ptr clone() const override + { + return make_shared>(*this); + } + + int run(shared_ptr mem, string &error) override + { + Precision a_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + Precision mask_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype; + float epsilon = this->getvar(2, mem); + if (mask_type != Precision::Bool) + { + error = "Type mismatch: " + precision_str(a_type) + " != " + precision_str(mask_type); + return 1; + } + switch (a_type) + { + case Precision::Float64: + tensorfunc::notequalscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Float32: + tensorfunc::notequalscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Float16: + tensorfunc::notequalscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::BFloat16: + tensorfunc::notequalscalar(*mem->gettensor(this->args[0].textvalue), 
this->getvar(1, mem), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int64: + tensorfunc::notequalscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int32: + tensorfunc::notequalscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int16: + tensorfunc::notequalscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int8: + tensorfunc::notequalscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + default: + error = "Unsupported type: " + precision_str(a_type); + return 1; + } + return 0; + } + }; + // less template class Less : public TF diff --git a/excuter/op-mem-ompsimd/src/client/tfs.cpp b/excuter/op-mem-ompsimd/src/client/tfs.cpp index fcb2b1c8..af50289a 100644 --- a/excuter/op-mem-ompsimd/src/client/tfs.cpp +++ b/excuter/op-mem-ompsimd/src/client/tfs.cpp @@ -409,6 +409,7 @@ namespace deepx::tf { Param("A", DataCategory::Tensor, Precision::Any), Param("B", DataCategory::Tensor, Precision::Any), + Param("eposilon", DataCategory::Var, Precision::Float32), }), vector( { @@ -419,6 +420,29 @@ namespace deepx::tf { Param("A", DataCategory::Tensor, Precision::Any), Param("scalar", DataCategory::Var, Precision::Any), + Param("eposilon", DataCategory::Var, Precision::Float32), + }), + vector( + { + Param("mask", DataCategory::Tensor, Precision::Bool), + }))); + // notequal author=miaobyte + tffactory.add_tf(std::make_shared>(vector( + { + Param("A", DataCategory::Tensor, Precision::Any), + Param("B", DataCategory::Tensor, Precision::Any), + Param("epsilon", DataCategory::Var, Precision::Float32), + }), + vector( + { + Param("mask", DataCategory::Tensor, 
Precision::Bool), + }))); + // notequal scalar author=miaobyte + tffactory.add_tf(std::make_shared>(vector( + { + Param("A", DataCategory::Tensor, Precision::Any), + Param("scalar", DataCategory::Var, Precision::Any), + Param("epsilon", DataCategory::Var, Precision::Float32), }), vector( { diff --git a/excuter/op-mem-ompsimd/src/deepx/mem/mem_ompsimd.hpp b/excuter/op-mem-ompsimd/src/deepx/mem/mem_ompsimd.hpp index 8e4710b0..c58c4395 100644 --- a/excuter/op-mem-ompsimd/src/deepx/mem/mem_ompsimd.hpp +++ b/excuter/op-mem-ompsimd/src/deepx/mem/mem_ompsimd.hpp @@ -92,7 +92,12 @@ namespace deepx::mem result->data = ptr_tensor->data; break; } - + case Precision::Bool: + { + auto ptr_tensor = std::static_pointer_cast>(ptr); + result->data = ptr_tensor->data; + break; + } default: throw std::runtime_error("Unsupported dtype: " + precision_str(ptr->shape.dtype)); } diff --git a/excuter/op-mem-ompsimd/src/deepx/tensorfunc/elementwise_miaobyte.hpp b/excuter/op-mem-ompsimd/src/deepx/tensorfunc/elementwise_miaobyte.hpp index de9e70a6..9be9f3a4 100644 --- a/excuter/op-mem-ompsimd/src/deepx/tensorfunc/elementwise_miaobyte.hpp +++ b/excuter/op-mem-ompsimd/src/deepx/tensorfunc/elementwise_miaobyte.hpp @@ -829,6 +829,71 @@ namespace deepx::tensorfunc } }; }; + // notequal + template + struct notequalDispatcher + { + static void notequal(const Tensor &A, const Tensor &B, const float epsilon, Tensor &mask) + { + if (A.shape == B.shape && mask.shape == A.shape) + { + A.shape.rangeElementwiseParallel([&A, &B, &mask, epsilon](int i, int i_end) + { + if (epsilon == 0) + { + for (int j = 0; j < i_end; j++) + { + + mask.data[i + j] = A.data[i + j] != B.data[i + j]; + } + } + else + { + for (int j = 0; j < i_end; j++) + { + mask.data[i + j] = std::abs(A.data[i + j] - B.data[i + j]) > epsilon; + }; + } }); + } + else + { + throw std::invalid_argument("shape mismatch"); + } + } + }; + + // notequalscalar + template + struct notequalscalarDispatcher + { + static void notequalscalar(const Tensor 
&A, const T scalar, const float epsilon, Tensor &mask) + { + if (A.shape == mask.shape) + { + A.shape.rangeElementwiseParallel([&A, &mask, &scalar, epsilon](int i, int i_end) + { + if (epsilon == 0) + { + for (int j = 0; j < i_end; j++) + { + + mask.data[i + j] = A.data[i + j] != scalar; + } + } + else + { + for (int j = 0; j < i_end; j++) + { + mask.data[i + j] = std::abs(A.data[i + j] - scalar) > epsilon; + } + } }); + } + else + { + throw std::invalid_argument("shape mismatch"); + } + }; + }; // less template @@ -940,7 +1005,5 @@ namespace deepx::tensorfunc } } }; - - }; #endif // DEEPX_OP_CPU_ELEMENTWISE_HPP \ No newline at end of file diff --git a/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp b/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp index bf579a7a..e0f00c9f 100644 --- a/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp +++ b/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp @@ -1504,7 +1504,7 @@ namespace deepx::tf } string math_formula() const override { - return "mask=equal(T1,T2)"; + return "equal(T1,T2)->mask"; } shared_ptr clone() const override { @@ -1514,7 +1514,10 @@ namespace deepx::tf { Precision a_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; Precision b_type = mem->gettensor(this->args[1].textvalue).get()->shape.dtype; - float epsilon = this->getvar(2,mem,true); + float epsilon =0; + if (this->args.size()>2){ + epsilon=this->getvar(2,mem,true); + } Precision mask_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype; if (a_type != b_type || mask_type!=Precision::Bool) { @@ -1575,9 +1578,9 @@ namespace deepx::tf Precision a_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; Precision mask_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype; float epsilon = this->getvar(2,mem,true); - if (a_type != mask_type) + if (mask_type !=Precision::Bool) { - error = "Type mismatch: " + precision_str(a_type) + " != " + precision_str(mask_type); + error = "Type mismatch: " 
+ precision_str(mask_type)+"!=bool"; return 1; } switch (a_type) @@ -1608,6 +1611,128 @@ namespace deepx::tf } }; + //notequal + template + class NotEqual : public TF + { + public: + NotEqual(vector args, vector returns) + { + this->name = "notequal"; + this->metadata.author = Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + string math_formula() const override + { + return "notequal(T1,T2)->mask"; + } + shared_ptr clone() const override + { + return make_shared>(*this); + } + int run(shared_ptr mem, string &error) override + { + Precision a_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + Precision b_type = mem->gettensor(this->args[1].textvalue).get()->shape.dtype; + float epsilon =0; + if (this->args.size()>2){ + epsilon=this->getvar(2,mem,true); + } + Precision mask_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype; + if (a_type != b_type || mask_type!=Precision::Bool) + { + error = "Type mismatch: " + precision_str(a_type) + " != " + precision_str(b_type) + " " + precision_str(mask_type)+"!=bool"; + return 1; + } + switch (a_type) + { + case Precision::Float64: + tensorfunc::notequal(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Float32: + tensorfunc::notequal(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int64: + tensorfunc::notequal(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int32: + tensorfunc::notequal(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int16: + 
tensorfunc::notequal(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int8: + tensorfunc::notequal(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + default: + error = "Unsupported dtype: " + precision_str(a_type); + return 1; + } + return 0; + } + }; + + + template + class NotEqualScalar : public TF + { + public: + NotEqualScalar(vector args, vector returns) + { + this->name = "notequalscalar"; + this->metadata.author = Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + string math_formula() const override + { + return "mask=notequal(T1,scalar)"; + } + shared_ptr clone() const override + { + return make_shared>(*this); + } + int run(shared_ptr mem, string &error) override + { + Precision a_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + Precision mask_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype; + float epsilon = this->getvar(2,mem,true); + if (mask_type !=Precision::Bool) + { + error = "Type mismatch: " + precision_str(mask_type)+"!=bool"; + return 1; + } + switch (a_type) + { + case Precision::Float64: + tensorfunc::notequalscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem,true), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Float32: + tensorfunc::notequalscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem,true), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int64: + tensorfunc::notequalscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem,true), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int32: + tensorfunc::notequalscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem,true), 
epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int16: + tensorfunc::notequalscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem,true), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + case Precision::Int8: + tensorfunc::notequalscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem,true), epsilon, *mem->gettensor(this->returns[0].textvalue)); + break; + default: + error = "Unsupported dtype: " + precision_str(a_type); + return 1; + } + return 0; + } + }; + //less template class Less : public TF diff --git a/front/py/deepx/nn/functional/__init__.py b/front/py/deepx/nn/functional/__init__.py index 051f8d68..2c5d517b 100644 --- a/front/py/deepx/nn/functional/__init__.py +++ b/front/py/deepx/nn/functional/__init__.py @@ -19,8 +19,15 @@ #leaffunc "newtensor","rnewtensor","printtensor","load", #life "printtensor","save",#io - "constant","constant_","full","zeros","ones","uniform","uniform_","arange","arange_","kaiming_uniform","kaiming_uniform_", - "add","sub","mul","div","sqrt","pow","exp","log","invert","todtype","dropout", + "constant","constant_","dropout","full","zeros","ones","uniform","uniform_","arange","arange_", + "kaiming_uniform","kaiming_uniform_", + "add","sub","mul","div", + "sqrt","pow","exp","log", + "min","max", + "less","greater","equal","notequal", + "switch", + "todtype", + "invert", "matmul", "reducemax","reducemin","sum","prod", "reshape","permute","transpose","concat","broadcastTo","indexselect", diff --git a/front/py/deepx/nn/functional/authormap.py b/front/py/deepx/nn/functional/authormap.py index 5c6b492b..89142df2 100644 --- a/front/py/deepx/nn/functional/authormap.py +++ b/front/py/deepx/nn/functional/authormap.py @@ -17,18 +17,30 @@ 'divscalar':'miaobyte', 'rdiv':'miaobyte', 'rdivscalar':'miaobyte', + # 'invert':'miaobyte', - 'compare':'miaobyte', + # 'min':'miaobyte', 'minscalar':'miaobyte', 'max':'miaobyte', 'maxscalar':'miaobyte', + # + 'less': 'miaobyte', + 
'lessscalar': 'miaobyte', + 'greater': 'miaobyte', + 'greaterscalar': 'miaobyte', + 'equal': 'miaobyte', + 'equalscalar': 'miaobyte', + 'notequal': 'miaobyte', + 'notequalscalar': 'miaobyte', + # 'exp':'miaobyte', 'log':'miaobyte', 'pow':'miaobyte', 'powscalar':'miaobyte', 'rpowscalar':'miaobyte', 'sqrt':'miaobyte', + # 'dropout':'miaobyte', #changeshape 'reshape':'miaobyte', diff --git a/front/py/deepx/nn/functional/elementwise.py b/front/py/deepx/nn/functional/elementwise.py index 5ab5d7b9..9b502473 100644 --- a/front/py/deepx/nn/functional/elementwise.py +++ b/front/py/deepx/nn/functional/elementwise.py @@ -1,7 +1,56 @@ from deepx.tensor import Tensor from deepx.nn.functional import newtensor +# 幂运算 def rsqrt(input:Tensor)->Tensor: from .leaffunc_elementwise import sqrt return 1/sqrt(input) - \ No newline at end of file + +# 比较 +def clamp(input:Tensor,min:float,max:float)->Tensor: + from .leaffunc_elementwise import max,min + return max(min(input,max),min) + +# 类型转换 +def double(input:Tensor)->Tensor: + from .leaffunc_elementwise import todtype + dest=newtensor(input.shape,dtype='float64',name=input.name) + return todtype(input,dest) + +def float(input:Tensor)->Tensor: + from .leaffunc_elementwise import todtype + dest=newtensor(input.shape,dtype='float32',name=input.name) + return todtype(input,dest) + +def float16(input:Tensor)->Tensor: + from .leaffunc_elementwise import todtype + dest=newtensor(input.shape,dtype='float16',name=input.name) + return todtype(input,dest) +def bfloat16(input:Tensor)->Tensor: + from .leaffunc_elementwise import todtype + dest=newtensor(input.shape,dtype='bfloat16',name=input.name) + return todtype(input,dest) + +def int64(input:Tensor)->Tensor: + from .leaffunc_elementwise import todtype + dest=newtensor(input.shape,dtype='int64',name=input.name) + return todtype(input,dest) +def int32(input:Tensor)->Tensor: + from .leaffunc_elementwise import todtype + dest=newtensor(input.shape,dtype='int32',name=input.name) + return 
todtype(input,dest) + +def int16(input:Tensor)->Tensor: + from .leaffunc_elementwise import todtype + dest=newtensor(input.shape,dtype='int16',name=input.name) + return todtype(input,dest) + +def int8(input:Tensor)->Tensor: + from .leaffunc_elementwise import todtype + dest=newtensor(input.shape,dtype='int8',name=input.name) + return todtype(input,dest) + +def bool(input:Tensor)->Tensor: + from .leaffunc_elementwise import todtype + dest=newtensor(input.shape,dtype='bool',name=input.name) + return todtype(input,dest) diff --git a/front/py/deepx/nn/functional/leaffunc.py b/front/py/deepx/nn/functional/leaffunc.py index 89cbde68..d846f9ea 100644 --- a/front/py/deepx/nn/functional/leaffunc.py +++ b/front/py/deepx/nn/functional/leaffunc.py @@ -1,7 +1,6 @@ from typing import Union import importlib - from deepx.tensor import Tensor,Shape from .leaffunc_life import newtensor from .authormap import defaultauthor @@ -9,13 +8,15 @@ # inplace操作的函数,其名为_后缀, 返回值为空 # 非inplace操作的函数,其名为_后缀, 返回值为Tensor - -def create_A_B_tf_C(op_name): +def create_A_B_tf_C(op_name,outtype=None): """创建元素级操作函数""" def op_func( a: Tensor, b: Union[Tensor, float, int] = None, out: Union[Tensor, str] = None) -> Tensor: + outtype=a.dtype + if op_func.__outtype__ is not None: + outtype=op_func.__outtype__ outtensor = out rtf_module = importlib.import_module('deepx.nn.functional.rtf_elementwise') if isinstance(b, Tensor): @@ -26,20 +27,58 @@ def op_func( an = a.broadcastTo(newshape) bn = b.broadcastTo(newshape) if isinstance(out,str) or out is None: - outtensor=newtensor(newshape,dtype=a.dtype,name=out) + outtensor=newtensor(newshape,dtype=outtype,name=out) else: if isinstance(out,str) or out is None: - outtensor=newtensor(a.shape,dtype=a.dtype,name=out) + outtensor=newtensor(a.shape,dtype=outtype,name=out) rtf_func = getattr(rtf_module, f'rtf_{op_name}') rtf_func(an, bn, outtensor, defaultauthor[op_name]) else: if isinstance(out,str) or out is None: - outtensor=newtensor(a.shape,dtype=a.dtype,name=out) + 
outtensor=newtensor(a.shape,dtype=outtype,name=out) rtf_func = getattr(rtf_module, f'rtf_{op_name}scalar') rtf_func(a, b, outtensor, defaultauthor[f'{op_name}scalar']) return outtensor op_func.__name__ = op_name op_func.__qualname__ = op_name + op_func.__outtype__ = outtype + return op_func + +def create_A_B_c_tf_D(op_name,outtype=None): + """创建元素级操作函数""" + def op_func( + A: Tensor, + B: Union[Tensor, float, int] = None, + c: float=0, + out: Union[Tensor, str] = None) -> Tensor: + outtype='bool' + if op_func.__outtype__ is not None: + outtype=op_func.__outtype__ + outtensor = out + rtf_module = importlib.import_module('deepx.nn.functional.rtf_elementwise') + if isinstance(B, Tensor): + an=A + bn=B + if A.shape != B.shape: + newshape = Shape.broadcast_shape(A.shape, B.shape) + an = A.broadcastTo(newshape) + bn = B.broadcastTo(newshape) + if isinstance(out,str) or out is None: + outtensor=newtensor(newshape,dtype=outtype,name=out) + else: + if isinstance(out,str) or out is None: + outtensor=newtensor(A.shape,dtype=outtype,name=out) + rtf_func = getattr(rtf_module, f'rtf_{op_name}') + rtf_func(an, bn,c, outtensor, defaultauthor[op_name]) + else: + if isinstance(out,str) or out is None: + outtensor=newtensor(A.shape,dtype=outtype,name=out) + rtf_func = getattr(rtf_module, f'rtf_{op_name}scalar') + rtf_func(A,B,c, outtensor, defaultauthor[f'{op_name}scalar']) + return outtensor + op_func.__name__ = op_name + op_func.__qualname__ = op_name + op_func.__outtype__ = outtype return op_func def create_A_tf_C(op_name): diff --git a/front/py/deepx/nn/functional/leaffunc_elementwise.py b/front/py/deepx/nn/functional/leaffunc_elementwise.py index 6475105a..4a6e26fd 100644 --- a/front/py/deepx/nn/functional/leaffunc_elementwise.py +++ b/front/py/deepx/nn/functional/leaffunc_elementwise.py @@ -1,11 +1,11 @@ from typing import Optional, Union from deepx import Tensor,Shape,Number -from .leaffunc import create_A_B_tf_C,create_A_tf_C +from .leaffunc import 
create_A_B_tf_C,create_A_tf_C,create_A_B_c_tf_D from .leaffunc_life import newtensor from .authormap import defaultauthor -# 创建具体操作函数 +#四则运算 add = create_A_B_tf_C('add') sub = create_A_B_tf_C('sub') mul = create_A_B_tf_C('mul') @@ -19,10 +19,6 @@ def div( return _div(a,b,out) elif isinstance(a,float) or isinstance(a,int): return rdiv(a,b,out) - else: - raise ValueError(f"Invalid type for a: {type(a)}") - -#div def rdiv( a: Union[float, int], b: Tensor, @@ -33,11 +29,9 @@ def rdiv( from .rtf_elementwise import rtf_rdivscalar rtf_rdivscalar(a,b,outtensor,defaultauthor['rdivscalar']) return outtensor - -max=create_A_B_tf_C('max') -min=create_A_B_tf_C('min') -#pow +## 幂、指数 运算 + pow=create_A_B_tf_C('pow') def rpow(a:Number,b:Tensor,out:Union[Tensor,str]=None)->Tensor: outtensor=out @@ -46,15 +40,42 @@ def rpow(a:Number,b:Tensor,out:Union[Tensor,str]=None)->Tensor: from .rtf_elementwise import rtf_rpowscalar rtf_rpowscalar(a,b,outtensor,defaultauthor['rpowscalar']) return outtensor -#sqrt - sqrt=create_A_tf_C('sqrt') exp=create_A_tf_C('exp') log=create_A_tf_C('log') -#invert +# 三角函数 +sin=create_A_tf_C('sin') +cos=create_A_tf_C('cos') +tan=create_A_tf_C('tan') + +#取大小值 +max=create_A_B_tf_C('max') +min=create_A_B_tf_C('min') + +#位运算 invert=create_A_tf_C('invert') +#比较 +less=create_A_B_tf_C('less',outtype='bool') +greater=create_A_B_tf_C('greater',outtype='bool') +equal=create_A_B_c_tf_D('equal',outtype='bool') +notequal=create_A_B_c_tf_D('notequal',outtype='bool') + +#分支 +def switch(X:tuple[Tensor,...], cases:Tensor, out:Union[Tensor,str]=None)->Tensor: + assert isinstance(X,tuple) + for x in X: + assert isinstance(x,Tensor) and x.shape==cases.shape + outtensor=out + if isinstance(out,str) or out is None: + outtensor=newtensor(cases.shape,dtype=cases.dtype,name=out) + assert isinstance(outtensor,Tensor) and outtensor.shape==cases.shape + + from .rtf_elementwise import rtf_switch + rtf_switch(X,cases,outtensor,defaultauthor['switch']) + return outtensor + #todtype def 
todtype(t:Tensor,dest:Tensor): assert isinstance(t,Tensor) diff --git a/front/py/deepx/nn/functional/leaffunc_init.py b/front/py/deepx/nn/functional/leaffunc_init.py index 6ff30350..08189a24 100644 --- a/front/py/deepx/nn/functional/leaffunc_init.py +++ b/front/py/deepx/nn/functional/leaffunc_init.py @@ -40,7 +40,7 @@ def dropout(a:Tensor, p:float=0.5, seed:int=None)->Tensor: # 初始化 def arange(start:Number,end:Number,step:Number=1,dtype:str='float32',name:str=None)->Tensor: s =[int((end-start)/step)] - outtensor=newtensor(s,dtype=dtype,name=name) + outtensor=newtensor(tuple(s),dtype=dtype,name=name) arange_(outtensor,start,step) return outtensor diff --git a/front/py/deepx/nn/functional/rtf.py b/front/py/deepx/nn/functional/rtf.py index d3c46d39..acd99249 100644 --- a/front/py/deepx/nn/functional/rtf.py +++ b/front/py/deepx/nn/functional/rtf.py @@ -8,6 +8,17 @@ def A_B_op_C(op:str,a:Tensor,b:Tensor,out:Tensor,author='miaobyte'): ir=DeepxIR(op, args, returns,author) send(ir) +def A_B_c_op_D(op:str,a:Tensor,b:Tensor,c:Union[float,int],out:Tensor,author='miaobyte'): + args=[Param.tensor(a),Param.tensor(b),Param.varnum(c)] + returns=[Param.tensor(out)] + ir=DeepxIR(op, args, returns,author) + send(ir) +def A_scalar_c_op_D(op:str,a:Tensor,scalar:Union[float,int],c:Union[float,int],out:Tensor,author='miaobyte'): + args=[Param.tensor(a),Param.varnum(scalar),Param.varnum(c)] + returns=[Param.tensor(out)] + ir=DeepxIR(op, args, returns,author) + send(ir) + def A_scalar_op(op:str,a:Tensor,b:Union[float,int],author='miaobyte'): args=[Param.tensor(a),Param.varnum(b)] returns=[] diff --git a/front/py/deepx/nn/functional/rtf_elementwise.py b/front/py/deepx/nn/functional/rtf_elementwise.py index 4e20429d..34ab28e9 100644 --- a/front/py/deepx/nn/functional/rtf_elementwise.py +++ b/front/py/deepx/nn/functional/rtf_elementwise.py @@ -1,7 +1,7 @@ from deepx.tensor import Tensor,Number from deepx.nn.deepxir import DeepxIR,Param from deepx.scheduler import send -from .rtf import 
A_B_op_C,A_scalar_op_C,A_op_C +from .rtf import A_B_op_C,A_B_c_op_D,A_scalar_op_C,A_scalar_c_op_D,A_op_C # 四则运算 def rtf_add(a:Tensor, b:Tensor, out:Tensor, author='miaobyte')->Tensor: @@ -71,10 +71,6 @@ def rtf_log(a:Tensor, out:Tensor, author='miaobyte')->Tensor: A_op_C("log",a,out,author) return out -def rtf_rsqrt(a:Tensor, out:Tensor, author='miaobyte')->Tensor: - A_op_C("rsqrt",a,out,author) - return out - # 三角函数 def rtf_sin(a:Tensor, out:Tensor, author='miaobyte')->Tensor: A_op_C("sin",a,out,author) @@ -88,11 +84,7 @@ def rtf_tan(a:Tensor, out:Tensor, author='miaobyte')->Tensor: A_op_C("tan",a,out,author) return out -# 比较 -def rtf_compare(a:Tensor, b:Tensor, out:Tensor, author='miaobyte')->Tensor: - A_B_op_C("compare",a,b,out,author) - return out - +# 取大小值 def rtf_max(a:Tensor, b:Tensor, out:Tensor, author='miaobyte')->Tensor: A_B_op_C("max",a,b,out,author) return out @@ -109,10 +101,55 @@ def rtf_minscalar(a:Tensor, b:float, out:Tensor, author='miaobyte')->Tensor: A_scalar_op_C("minscalar",a,b,out,author) return out +# 位运算 def rtf_invert(a:Tensor, out:Tensor, author='miaobyte')->Tensor: A_op_C("invert",a,out,author) return out +#比较 +# A C 等价于 B>=A -> C +def rtf_less(a:Tensor, b:Tensor,out:Tensor, author='miaobyte')->Tensor: + A_B_op_C("less",a,b,out,author) + return out +# A>B -> C +def rtf_greater(a:Tensor, b:Tensor, out:Tensor, author='miaobyte')->Tensor: + A_B_op_C("greater",a,b,out,author) + return out +# A C 等价于 b>=A -> C +def rtf_lessscalar(a:Tensor, b:float,out:Tensor, author='miaobyte')->Tensor: + A_scalar_op_C("lessscalar",a,b,out,author) + return out +# A>b -> C +def rtf_greaterscalar(a:Tensor, b:float, out:Tensor, author='miaobyte')->Tensor: + A_scalar_op_C("greaterscalar",a,b,out,author) + return out + +# A==B -> C +def rtf_equal(a:Tensor, b:Tensor,epsilon:float, out:Tensor, author='miaobyte')->Tensor: + A_B_c_op_D("equal",a,b,epsilon,out,author) + return out +# A==b -> C +def rtf_equalscalar(a:Tensor, b:float,epsilon:float, out:Tensor, 
author='miaobyte')->Tensor: + A_scalar_c_op_D("equalscalar",a,b,epsilon,out,author) + return out +# A!=B -> C +def rtf_notequal(a:Tensor, b:Tensor,epsilon:float, out:Tensor, author='miaobyte')->Tensor: + A_B_c_op_D("notequal",a,b,epsilon,out,author) + return out +# A!=b -> C +def rtf_notequalscalar(a:Tensor, b:float,epsilon:float, out:Tensor, author='miaobyte')->Tensor: + A_scalar_c_op_D("notequalscalar",a,b,epsilon,out,author) + return out + +# 根据cases[index]的值tensoridx,从X[tensoridx]这个Tensor[index],赋值给out[index] +def rtf_switch(X:tuple[Tensor,...], cases:Tensor, out:Tensor, author='miaobyte')->Tensor: + args = [Param.listtensor(X),Param.tensor(cases)] + returns = [Param.tensor(out)] + ir = DeepxIR("switch", args, returns, author) + send(ir) + return out + + # 类型转换 def rtf_todtype(t:Tensor,dest:Tensor): assert isinstance(t,Tensor) diff --git a/front/py/deepx/scheduler/client/udpconn.py b/front/py/deepx/scheduler/client/udpconn.py index 6a12c26a..a25b0963 100644 --- a/front/py/deepx/scheduler/client/udpconn.py +++ b/front/py/deepx/scheduler/client/udpconn.py @@ -3,7 +3,7 @@ import select class UDPConn: - def __init__(self, endpoint: str = "localhost:9090"): + def __init__(self, endpoint: str = "localhost:8080"): # 解析endpoint self._host, port_str = endpoint.split(':') self._port = int(port_str) diff --git a/front/py/deepx/tensor/elementwise.py b/front/py/deepx/tensor/elementwise.py index 16ceb1dd..25de7456 100644 --- a/front/py/deepx/tensor/elementwise.py +++ b/front/py/deepx/tensor/elementwise.py @@ -2,6 +2,7 @@ from deepx.tensor import Tensor,tensor_method,Number +# 四则运算 @tensor_method def add(self, other:Union[Tensor,float,int], @@ -43,8 +44,6 @@ def div(self, other:Union[Tensor,float,int], from deepx.nn.functional import div as div_func return div_func(self,other,out) - - @tensor_method def div_(self, other:Union[Tensor,float,int]): from deepx.nn.functional import div as div_func @@ -64,6 +63,8 @@ def rdiv_(self, other:Union[float,int]): div_func(other,self,self) 
return self + +# 取最值 @tensor_method def min(self, other:Union[Tensor,float,int], out:Union[Tensor,str]='')->Tensor: @@ -102,7 +103,7 @@ def clamp_(self, min:Union[float,int], max:Union[float,int]): #todo pass - +# 幂指运算 @tensor_method def exp(self,out:Union[Tensor,str]='')->Tensor: from deepx.nn.functional import exp as exp_func @@ -157,15 +158,63 @@ def rsqrt_(self): from deepx.nn.functional import rsqrt as rsqrt_func rsqrt_func(self,self) +# 三角函数 +@tensor_method +def sin(self,out:Union[Tensor,str]='')->Tensor: + from deepx.nn.functional import sin as sin_func + return sin_func(self,out) + +@tensor_method +def cos(self,out:Union[Tensor,str]='')->Tensor: + from deepx.nn.functional import cos as cos_func + return cos_func(self,out) + +@tensor_method +def tan(self,out:Union[Tensor,str]='')->Tensor: + from deepx.nn.functional import tan as tan_func + return tan_func(self,out) + + +# 位运算 @tensor_method def invert(self,out:Union[Tensor,str]='')->Tensor: from deepx.nn.functional import invert as invert_func return invert_func(self,out) +# 比较 +@tensor_method +def less(self,other:Union[Tensor,float,int],out:Union[Tensor,str]='')->Tensor: + from deepx.nn.functional import less as less_func + return less_func(self,other,out) @tensor_method -def dropout_(self,p:float=0.5,seed:int=None): - from deepx.nn.functional import dropout as dropout_func - dropout_func(self,p,seed) - return self +def greater(self,other:Union[Tensor,float,int],out:Union[Tensor,str]='')->Tensor: + from deepx.nn.functional import greater as greater_func + return greater_func(self,other,out) + +@tensor_method +def equal(self,other:Union[Tensor,float,int],epsilon:float=1e-6,out:Union[Tensor,str]='')->Tensor: + from deepx.nn.functional import equal as equal_func + return equal_func(self,other,epsilon,out) + +@tensor_method +def notequal(self,other:Union[Tensor,float,int],epsilon:float=1e-6,out:Union[Tensor,str]='')->Tensor: + from deepx.nn.functional import notequal as notequal_func + return 
notequal_func(self,other,epsilon,out) +# 分支 +@tensor_method +def switch(self,cases:Union[Tensor,float,int],out:Union[Tensor,str]='')->Tensor: + from deepx.nn.functional import switch as switch_func + return switch_func(self,cases,out) + +@tensor_method +def switch_(self,cases:Union[Tensor,float,int]): + from deepx.nn.functional import switch as switch_func + switch_func(self,cases,self) + +# 类型转换 +@tensor_method +def todtype(self,dest:Union[Tensor,float,int]): + from deepx.nn.functional import todtype as todtype_func + return todtype_func(self,dest) \ No newline at end of file diff --git a/front/py/deepx/tensor/init.py b/front/py/deepx/tensor/init.py index 803ba866..cef6358c 100644 --- a/front/py/deepx/tensor/init.py +++ b/front/py/deepx/tensor/init.py @@ -1,11 +1,19 @@ from typing import Union from deepx.tensor import tensor_method +# 填充 @tensor_method def full_(self,value:Union[float,int]): from deepx.nn.functional import constant_ as constant_func constant_func(self,value=value) +@tensor_method +def dropout_(self,p:float=0.5,seed:int=None): + from deepx.nn.functional import dropout as dropout_func + dropout_func(self,p,seed) + return self + + @tensor_method def zeros_(self): from deepx.nn.functional import constant_ as constant_func diff --git a/front/py/deepx/tensor/tensor.py b/front/py/deepx/tensor/tensor.py index 11144d7a..18bd0caa 100644 --- a/front/py/deepx/tensor/tensor.py +++ b/front/py/deepx/tensor/tensor.py @@ -120,25 +120,37 @@ def __sub__(self, other:Union[Number,'Tensor']): return self.sub(other) def __rsub__(self, other:Union[Number,'Tensor']): return self.sub(other) - def __mul__(self, other:Union[Number,'Tensor']): return self.mul(other) def __rmul__(self, other:Union[Number,'Tensor']): return self.mul(other) def __truediv__(self, other:Union[Number,'Tensor']): return self.div(other) - def __rtruediv__(self, other:Union[Number,'Tensor']): return self.rdiv(other) - + # 幂指 def __pow__(self, other:Union[Number,'Tensor']): return self.pow(other) def 
__rpow__(self, other:Union[Number,'Tensor']): return self.rpow(other) - + # 位 def __invert__(self): return self.invert() + # 比较 + def __eq__(self, other:Union[Number,'Tensor']): + return self.equal(other) + def __ne__(self, other:Union[Number,'Tensor']): + return self.notequal(other) + def __gt__(self, other:Union[Number,'Tensor']): + return self.greater(other) + def __ge__(self, other:Union[Number,'Tensor']): + return other.less(self) + def __lt__(self, other:Union[Number,'Tensor']): + return self.less(other) + def __le__(self, other:Union[Number,'Tensor']): + return other.greater(self) + #矩阵乘法 def __matmul__(self, other:'Tensor'): return self.matmul(other) diff --git a/front/py/examples/2_ir/2_elementwise_compare.py b/front/py/examples/2_ir/2_elementwise_bit.py similarity index 100% rename from front/py/examples/2_ir/2_elementwise_compare.py rename to front/py/examples/2_ir/2_elementwise_bit.py diff --git a/front/py/examples/2_ir/2_elementwise_lessgreater.py b/front/py/examples/2_ir/2_elementwise_lessgreater.py new file mode 100644 index 00000000..6a6fde89 --- /dev/null +++ b/front/py/examples/2_ir/2_elementwise_lessgreater.py @@ -0,0 +1,39 @@ +############-------PyTorch-------################ + +print() +import torch +torch_t1 = torch.full((2,3,4, ), 10, dtype=torch.float32) +torch_t2 = torch.arange(24,dtype=torch.float32).reshape(2,3,4) +torch_t3= torch.less(torch_t2,torch_t1) +print("t1t2") +print(torch_t4) +torch_t5= torch.equal(torch_t2,torch_t1) +print("t1==t2") +print(torch_t5) +torch_t6= torch.not_equal(torch_t2,torch_t1) +print("t1!=t2") +print(torch_t6) + + +############-------DEEPX-------################ + +from deepx import Tensor,full,arange,less,greater + +print() + +t1 = full((2,3,4), value=10,dtype="float32") +equalmask=t1==10 +equalmask.print() +t2 = arange(0,24,dtype="float32").reshape_((2,3,4)) +t3_= t2t1 +t4_.print() + +t5_= t2==t1 +t5_.print() +t6_= t2!=t1 +t6_.print() diff --git a/front/py/examples/2_ir/2_elementwise_minmax.py 
b/front/py/examples/2_ir/2_elementwise_minmax.py new file mode 100644 index 00000000..a09baf14 --- /dev/null +++ b/front/py/examples/2_ir/2_elementwise_minmax.py @@ -0,0 +1,24 @@ +############-------PyTorch-------################ + +print() +import torch +torch_t1 = torch.full((2,3,4, ), 10, dtype=torch.int8) +torch_t2 = torch.arange(24,dtype=torch.int8).reshape(2,3,4) +torch_t3= torch.min(torch_t2,torch_t1) +print(torch_t3) +torch_t4= torch.max(torch_t2,torch_t1) +print(torch_t4) + + +############-------DEEPX-------################ + +from deepx import Tensor,full,arange,min,max + +print() + +t1 = full((2,3,4), value=10,dtype="int8") +t2 = arange(0,24,dtype="int8").reshape_((2,3,4)) +t3 = min(t2,t1) +t3.print() +t4 = max(t2,t1) +t4.print() \ No newline at end of file