diff --git a/doc/excuter/op-mem-cuda/list.md b/doc/excuter/op-mem-cuda/list.md
index f487b09c..9c4d4395 100644
--- a/doc/excuter/op-mem-cuda/list.md
+++ b/doc/excuter/op-mem-cuda/list.md
@@ -49,7 +49,7 @@
 | Operation | Author | Math Formula | IR Instruction |
 |-----------|--------|--------------|----------------|
-| switch | miaobyte | C=switch(tensors,cases) | switch(listtensor tensors, tensor cases)->(tensor result) |
+| switch | miaobyte | C=switch(tensors,cases) | switch(listtensor tensors, tensor cases)->(tensor result) |
 | greaterscalar | miaobyte | mask=compare(T1, scalar) | greaterscalar(tensor A, var scalar)->(tensor mask) |
 | notequal | miaobyte | T1!=T2->mask | notequal(tensor A, tensor B, var epsilon)->(tensor mask) |
 | equalscalar | miaobyte | T1==scalar->mask | equalscalar(tensor A, var scalar, var epsilon)->(tensor mask) |
diff --git a/doc/excuter/op-mem-ompsimd/list.md b/doc/excuter/op-mem-ompsimd/list.md
index d2ba1745..03296c08 100644
--- a/doc/excuter/op-mem-ompsimd/list.md
+++ b/doc/excuter/op-mem-ompsimd/list.md
@@ -50,7 +50,7 @@
 | Operation | Author | Math Formula | IR Instruction |
 |-----------|--------|--------------|----------------|
-| switch | miaobyte | C=switch([tensors],case) | switch(listtensor tensors, tensor cases)->(tensor C) |
+| switch | miaobyte | C=switch([tensors],case) | switch(listtensor tensors, tensor cases)->(tensor C) |
 | greaterscalar | miaobyte | mask=greater(T1,scalar) | greaterscalar(tensor A, var scalar)->(tensor mask) |
 | notequal | miaobyte | notequal(T1,T2)->mask | notequal(tensor A, tensor B, var epsilon)->(tensor mask) |
 | equalscalar | miaobyte | mask=equal(T1,scalar) | equalscalar(tensor A, var scalar, var eposilon)->(tensor mask) |
@@ -63,7 +63,7 @@
 | lessscalar | miaobyte | mask=less(T1,scalar) | lessscalar(tensor A, var scalar)->(tensor mask) |
 | notequalscalar | miaobyte | mask=notequal(T1,scalar) | notequalscalar(tensor A, var scalar, var epsilon)->(tensor mask) |
 | minscalar | miaobyte | T3=min(T1,scalar) | minscalar(tensor A, var scalar)->(tensor C) |
-| rpowscalar | miaobyte | T3=scalar^T1 | rpowscalar(var scalar, tensor A)->(tensor C) |
+| rpowscalar | miaobyte | T3=scalar^T1 | rpowscalar(var scalar, tensor A)->(tensor C) |
 | rdivscalar | miaobyte | T3=scalar/T1 | rdivscalar(var scalar, tensor A)->(tensor C) |
 | less | miaobyte | mask=less(T1,T2) | less(tensor A, tensor B)->(tensor mask) |
 | powscalar | miaobyte | T3=T1^scalar | powscalar(tensor A, var scalar)->(tensor C) |
diff --git a/excuter/op-mem-cuda/src/client/tfs.cpp b/excuter/op-mem-cuda/src/client/tfs.cpp
index 0f64b87b..1c1156b3 100644
--- a/excuter/op-mem-cuda/src/client/tfs.cpp
+++ b/excuter/op-mem-cuda/src/client/tfs.cpp
@@ -486,7 +486,7 @@ namespace deepx::tf
         tffactory.add_tf(std::make_shared<Switch<miaobyte>>(vector<Param>(
             {
                 Param("tensors", DataCategory::ListTensor, Precision::Any),
-                Param("cases", DataCategory::Tensor, Precision::Int8),
+                Param("cases", DataCategory::Tensor, Precision::Int32|Precision::Bool),
             }),
             vector<Param>(
             {
diff --git a/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_compare.cu b/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_compare.cu
index 72c86352..a33a8722 100644
--- a/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_compare.cu
+++ b/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_compare.cu
@@ -524,15 +524,25 @@ namespace deepx::tensorfunc
         {
         }
     }
-    template void launch_switch<double,int8_t>(const double **tensorsdata, const int numTensors, const int8_t *cases, double *C, const int size);
-    template void launch_switch<float,int8_t>(const float **tensorsdata, const int numTensors, const int8_t *cases, float *C, const int size);
-    template void launch_switch<nv_bfloat16,int8_t>(const nv_bfloat16 **tensorsdata, const int numTensors, const int8_t *cases, nv_bfloat16 *C, const int size);
-    template void launch_switch<__half,int8_t>(const __half **tensorsdata, const int numTensors, const int8_t *cases, __half *C, const int size);
-    template void launch_switch<int64_t,int8_t>(const int64_t **tensorsdata, const int numTensors, const int8_t *cases, int64_t *C, const int size);
-    template void launch_switch<int32_t,int8_t>(const int32_t **tensorsdata, const int numTensors, const int8_t *cases, int32_t *C, const int size);
-    template void launch_switch<int16_t,int8_t>(const int16_t **tensorsdata, const int numTensors, const int8_t *cases, int16_t *C, const int size);
-    template void launch_switch<int8_t,int8_t>(const int8_t **tensorsdata, const int numTensors, const int8_t *cases, int8_t *C, const int size);
-    template void launch_switch<bool,int8_t>(const bool **tensorsdata, const int numTensors, const int8_t *cases, bool *C, const int size);
+    template void launch_switch<double,int32_t>(const double **tensorsdata, const int numTensors, const int32_t *cases, double *C, const int size);
+    template void launch_switch<float,int32_t>(const float **tensorsdata, const int numTensors, const int32_t *cases, float *C, const int size);
+    template void launch_switch<nv_bfloat16,int32_t>(const nv_bfloat16 **tensorsdata, const int numTensors, const int32_t *cases, nv_bfloat16 *C, const int size);
+    template void launch_switch<__half,int32_t>(const __half **tensorsdata, const int numTensors, const int32_t *cases, __half *C, const int size);
+    template void launch_switch<int64_t,int32_t>(const int64_t **tensorsdata, const int numTensors, const int32_t *cases, int64_t *C, const int size);
+    template void launch_switch<int32_t,int32_t>(const int32_t **tensorsdata, const int numTensors, const int32_t *cases, int32_t *C, const int size);
+    template void launch_switch<int16_t,int32_t>(const int16_t **tensorsdata, const int numTensors, const int32_t *cases, int16_t *C, const int size);
+    template void launch_switch<int8_t,int32_t>(const int8_t **tensorsdata, const int numTensors, const int32_t *cases, int8_t *C, const int size);
+    template void launch_switch<bool,int32_t>(const bool **tensorsdata, const int numTensors, const int32_t *cases, bool *C, const int size);
+
+    template void launch_switch<double,bool>(const double **tensorsdata, const int numTensors, const bool *cases, double *C, const int size);
+    template void launch_switch<float,bool>(const float **tensorsdata, const int numTensors, const bool *cases, float *C, const int size);
+    template void launch_switch<nv_bfloat16,bool>(const nv_bfloat16 **tensorsdata, const int numTensors, const bool *cases, nv_bfloat16 *C, const int size);
+    template void launch_switch<__half,bool>(const __half **tensorsdata, const int numTensors, const bool *cases, __half *C, const int size);
+    template void launch_switch<int64_t,bool>(const int64_t **tensorsdata, const int numTensors, const bool *cases, int64_t *C, const int size);
+    template void launch_switch<int32_t,bool>(const int32_t **tensorsdata, const int numTensors, const bool *cases, int32_t *C, const int size);
+    template void launch_switch<int16_t,bool>(const int16_t **tensorsdata, const int numTensors, const bool *cases, int16_t *C, const int size);
+    template void launch_switch<int8_t,bool>(const int8_t **tensorsdata, const int numTensors, const bool *cases, int8_t *C, const int size);
+    template void launch_switch<bool,bool>(const bool **tensorsdata, const int numTensors, const bool *cases, bool *C, const int size);
 }
 
 #endif // DEEPX_TENSORFUNC_ELEMENTWISE_MIAO_BYTE_COMPARE_CU
diff --git a/excuter/op-mem-cuda/src/deepx/tf/elementwise_compare.hpp b/excuter/op-mem-cuda/src/deepx/tf/elementwise_compare.hpp
index cce24d81..77a4b8b1 100644
--- a/excuter/op-mem-cuda/src/deepx/tf/elementwise_compare.hpp
+++ b/excuter/op-mem-cuda/src/deepx/tf/elementwise_compare.hpp
@@ -835,35 +835,99 @@ namespace deepx::tf
         {
             Precision C_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype;
-
+            Precision cases_type = mem->gettensor(this->args[1].textvalue).get()->shape.dtype;
+
             switch (C_type)
             {
             case Precision::Float64:
-                tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
-                break;
-            case Precision::Float32:
-                tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                if (cases_type == Precision::Bool)
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
+                else
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
+                break;
+            case Precision::Float32:
+                if (cases_type == Precision::Bool)
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
+                else
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
                 break;
             case Precision::Float16:
-                tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                if (cases_type == Precision::Bool)
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
+                else
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
                 break;
             case Precision::BFloat16:
-                tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                if (cases_type == Precision::Bool)
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
+                else
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
                 break;
             case Precision::Int64:
-                tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                if (cases_type == Precision::Bool)
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
+                else
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
                 break;
             case Precision::Int32:
-                tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                if (cases_type == Precision::Bool)
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
+                else
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
                 break;
             case Precision::Int16:
-                tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                if (cases_type == Precision::Bool)
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
+                else
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
                 break;
             case Precision::Int8:
-                tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                if (cases_type == Precision::Bool)
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
+                else
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
                 break;
             case Precision::Bool:
-                tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                if (cases_type == Precision::Bool)
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)),*mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
+                else
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)),*mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
                 break;
             default:
                 error = "Unsupported type: " + precision_str(C_type);
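
The hunk above makes the CUDA Switch dispatch two-dimensional: the outer switch picks the output precision, and an inner if/else picks between a bool and an int32 case tensor, so the explicit instantiations in elementwise_miaobyte_compare.cu have to cover the same grid. A hypothetical Python sketch of that grid (illustration only; none of these names are deepx APIs):

from itertools import product

# output element types and case-tensor types covered by the diff
out_types = ["double", "float", "nv_bfloat16", "__half",
             "int64_t", "int32_t", "int16_t", "int8_t", "bool"]
case_types = ["int32_t", "bool"]

# each pair corresponds to one explicit launch_switch<T, CaseT> instantiation
grid = [f"launch_switch<{t},{c}>" for t, c in product(out_types, case_types)]
print(len(grid))  # 18 combinations: 9 output dtypes x 2 case dtypes
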
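
The two-line fix above makes the inner loop read the case index and the selected tensor at the same flattened offset i+j. A minimal NumPy sketch of the per-element selection this loop computes (illustrative only; the arrays and names are made up, not deepx API):

import numpy as np

# switch semantics: out[k] = tensors[cases[k]][k], decided element by element
tensors = [np.full(6, 10.0), np.arange(6, dtype=np.float64)]
cases = np.array([0, 1, 1, 0, 1, 0], dtype=np.int32)

out = np.empty_like(tensors[0])
for k in range(out.size):             # mirrors C.data[i+j] = tensors[which_tensor]->data[i+j]
    out[k] = tensors[cases[k]][k]

# np.choose performs the same per-element selection in one call
assert np.array_equal(out, np.choose(cases, tensors))
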
diff --git a/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp b/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp
index e0f00c9f..cc3117b9 100644
--- a/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp
+++ b/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp
@@ -1991,32 +1991,74 @@ namespace deepx::tf
         int run(shared_ptr mem, string &error) override
         {
             Precision cases_type = mem->gettensor(this->args[1].textvalue).get()->shape.dtype;
-            Precision C_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype;
-            if (cases_type != Precision::Int8 )
+            Precision output_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype;
+            if (cases_type != Precision::Int32 && cases_type != Precision::Bool)
             {
-                error = "Type mismatch: " + precision_str(cases_type) + " != int8";
+                error = "Type mismatch: " + precision_str(cases_type) + " != Int32 or Bool";
                 return 1;
             }
-            switch (cases_type)
+            switch (output_type)
             {
             case Precision::Float64:
-                tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                if (cases_type == Precision::Bool)
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
+                else
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
                 break;
             case Precision::Float32:
-                tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                if (cases_type == Precision::Bool)
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
+                else
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
                 break;
             case Precision::Int64:
-                tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                if (cases_type == Precision::Bool)
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
+                else
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
                 break;
             case Precision::Int32:
-                tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                if (cases_type == Precision::Bool)
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
+                else
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
                 break;
-            case Precision::Int16:
-                tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+            case Precision::Int16:
+                if (cases_type == Precision::Bool)
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
+                else
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
                 break;
             case Precision::Int8:
-                tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                if (cases_type == Precision::Bool)
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
+                else
+                {
+                    tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                }
                 break;
             default:
                 error = "Unsupported dtype: " + precision_str(cases_type);
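
The front-end hunks that follow register switch under the miaobyte author, route where through it, and fix two details: the output tensor now takes its dtype from the value tensors (X[0].dtype) rather than from the case tensor, and where(condition, x, y) now passes (y, x), because a False/0 case selects tensors[0] and a True/1 case selects tensors[1]. A hedged NumPy illustration of that argument order (it does not call the deepx front end):

import numpy as np

condition = np.array([True, False, True, False])
x = np.array([1.0, 2.0, 3.0, 4.0])       # wanted where condition is True
y = np.array([10.0, 20.0, 30.0, 40.0])   # wanted where condition is False

# switch((y, x), condition): index 0 (False) -> y, index 1 (True) -> x
out = np.choose(condition.astype(np.int32), (y, x))
assert np.array_equal(out, np.where(condition, x, y))   # [1., 20., 3., 40.]
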
error = "Unsupported dtype: " + precision_str(cases_type); diff --git a/front/py/deepx/nn/functional/authormap.py b/front/py/deepx/nn/functional/authormap.py index 89142df2..834bda0a 100644 --- a/front/py/deepx/nn/functional/authormap.py +++ b/front/py/deepx/nn/functional/authormap.py @@ -33,6 +33,7 @@ 'equalscalar': 'miaobyte', 'notequal': 'miaobyte', 'notequalscalar': 'miaobyte', + 'switch':'miaobyte', # 'exp':'miaobyte', 'log':'miaobyte', diff --git a/front/py/deepx/nn/functional/elementwise.py b/front/py/deepx/nn/functional/elementwise.py index 52a6dc52..71b372bc 100644 --- a/front/py/deepx/nn/functional/elementwise.py +++ b/front/py/deepx/nn/functional/elementwise.py @@ -56,5 +56,5 @@ def bool(input:Tensor)->Tensor: return todtype(input,dest) def where(condition:Tensor,x:Tensor,y:Tensor)->Tensor: - from .leaffunc_elementwise import switch_func - return switch_func((x,y),condition) \ No newline at end of file + from .leaffunc_elementwise import switch as switch_func + return switch_func((y,x),condition) \ No newline at end of file diff --git a/front/py/deepx/nn/functional/leaffunc_elementwise.py b/front/py/deepx/nn/functional/leaffunc_elementwise.py index 4a6e26fd..a883a6e8 100644 --- a/front/py/deepx/nn/functional/leaffunc_elementwise.py +++ b/front/py/deepx/nn/functional/leaffunc_elementwise.py @@ -69,7 +69,7 @@ def switch(X:tuple[Tensor,...], cases:Tensor, out:Union[Tensor,str]=None)->Tenso assert isinstance(x,Tensor) and x.shape==cases.shape outtensor=out if isinstance(out,str) or out is None: - outtensor=newtensor(cases.shape,dtype=cases.dtype,name=out) + outtensor=newtensor(cases.shape,dtype=X[0].dtype,name=out) assert isinstance(outtensor,Tensor) and outtensor.shape==cases.shape from .rtf_elementwise import rtf_switch diff --git a/front/py/examples/2_ir/2_elementwise_switchwhere.py b/front/py/examples/2_ir/2_elementwise_switchwhere.py new file mode 100644 index 00000000..36db4a3a --- /dev/null +++ b/front/py/examples/2_ir/2_elementwise_switchwhere.py @@ -0,0 +1,22 @@ +############-------PyTorch-------################ + +print() +import torch +torch_t1 = torch.full((2,3,4, ), 10, dtype=torch.float32) +torch_t2 = torch.arange(24,dtype=torch.float32).reshape(2,3,4) +torch_t3= torch.where(torch_t2