Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/excuter/op-mem-cuda/list.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

| Operation | Author | Func Def | Math Formula | IR Instruction |
|-----------|--------|------------|--------------|----------------|
| reducemax | miaobyte | reducemax(tensor<any> A, vector<int32> dims, var<bool> keepdims)->(tensor<any> B) | B = reducemax(A, axis=[1 2], keepdims=false) | reducemax(tensor<any> A, vector<int32> dims, var<bool> keepdims)->(tensor<any> B) |
| broadcastTo | miaobyte | broadcastTo(tensor<any> A, vector<int32> new_shape)->(tensor<any> B) | T2 = T1.broadcastTo(new_shape=[4,3,2]) | broadcastTo(tensor<any> A, vector<int32> new_shape)->(tensor<any> B) |
| concat | miaobyte | concat(listtensor<any> tensors, var<int32> axis)->(tensor<any> result) | Tresult = concat([T1, T2...], axis=3) | concat(listtensor<any> tensors, var<int32> axis)->(tensor<any> result) |
| transpose | miaobyte | transpose(tensor<any> A, vector<int32> dim_order)->(tensor<any> C) | T2 = T1.transpose(dimorder=[1,0]) | transpose(tensor<any> A, vector<int32> dim_order)->(tensor<any> C) |
Expand All @@ -24,8 +25,10 @@
| newtensor | none | newtensor(vector<int32> shape)->(tensor<any> tensor1) | T1 = zeros(shape) | newtensor(vector<int32> shape)->(tensor<any> tensor1) |
| newtensor | none | newtensor(var<string> shape)->(tensor<any> tensor1) | T1 = zeros(shape) | newtensor(var<string> shape)->(tensor<any> tensor1) |
| vecset | none | vecset(vector<any> value)->(vector<any> name) | shape = [3 4 5] | vecset(vector<any> value)->(vector<any> name) |
| reducemin | miaobyte | reducemin(tensor<any> A, vector<int32> dims, var<bool> keepdims)->(tensor<any> B) | B = reducemin(A, axis=[1 2], keepdims=false) | reducemin(tensor<any> A, vector<int32> dims, var<bool> keepdims)->(tensor<any> B) |
| subscalar | miaobyte | subscalar(tensor<any> A, var<any> b)->(tensor<any> C) | T3=T1-scalar | subscalar(tensor<any> A, var<any> b)->(tensor<any> C) |
| sqrt | miaobyte | sqrt(tensor<float64|float32|float16|bfloat16> A)->(tensor<float64|float32|float16|bfloat16> C) | T3=sqrt(T1) | sqrt(tensor<float64|float32|float16|bfloat16> A)->(tensor<float64|float32|float16|bfloat16> C) |
| sum | miaobyte | sum(tensor<any> A, vector<int32> dims, var<bool> keepdims)->(tensor<any> B) | B = sum(A, axis=[1 2], keepdims=false) | sum(tensor<any> A, vector<int32> dims, var<bool> keepdims)->(tensor<any> B) |
| argset | none | argset(var<any> value)->(var<any> name) | var argname = argvalue | argset(var<any> value)->(var<any> name) |
| sub | miaobyte | sub(tensor<any> A, tensor<any> B)->(tensor<any> C) | T3=T1-T2 | sub(tensor<any> A, tensor<any> B)->(tensor<any> C) |
| mulscalar | miaobyte | mulscalar(tensor<any> A, var<any> b)->(tensor<any> C) | T3=T1*scalar | mulscalar(tensor<any> A, var<any> b)->(tensor<any> C) |
Expand All @@ -40,5 +43,6 @@
| rdivscalar | miaobyte | rdivscalar(var<any> scalar, tensor<any> A)->(tensor<any> C) | T3=scalar/T1 | rdivscalar(var<any> scalar, tensor<any> A)->(tensor<any> C) |
| minscalar | miaobyte | minscalar(tensor<any> A, var<any> scalar)->(tensor<any> C) | T3=min(T1, scalar) | minscalar(tensor<any> A, var<any> scalar)->(tensor<any> C) |
| cos | miaobyte | cos(tensor<float64|float32|float16|bfloat16> A)->(tensor<float64|float32|float16|bfloat16> C) | T3=cos(T1) | cos(tensor<float64|float32|float16|bfloat16> A)->(tensor<float64|float32|float16|bfloat16> C) |
| prod | miaobyte | prod(tensor<any> A, vector<int32> dims, var<bool> keepdims)->(tensor<any> B) | B = prod(A, axis=[1 2], keepdims=false) | prod(tensor<any> A, vector<int32> dims, var<bool> keepdims)->(tensor<any> B) |
| min | miaobyte | min(tensor<any> A, tensor<any> B)->(tensor<any> C) | T3=min(T1, T2) | min(tensor<any> A, tensor<any> B)->(tensor<any> C) |
| compare | miaobyte | compare(tensor<any> A, tensor<any> B)->(tensor<int8> mask) | mask=compare(T1, T2) | compare(tensor<any> A, tensor<any> B)->(tensor<int8> mask) |
24 changes: 12 additions & 12 deletions excuter/cpp-common/src/deepx/tensorfunc/reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,46 +12,46 @@ namespace deepx::tensorfunc
template <typename Author, typename T>
struct reducemaxDispatcher
{
static void reducemax(const Tensor<T> &A, const std::vector<int> &dims,Tensor<T> &B,const bool keepdims=false) = delete;
static void reducemax(const Tensor<T> &A, const std::vector<int> &dims,const bool keepdims,Tensor<T> &B) = delete;
};
template <typename Author, typename T>
void reducemax(const Tensor<T> &A, const std::vector<int> &dims,Tensor<T> &B,const bool keepdims=false)
void reducemax(const Tensor<T> &A, const std::vector<int> &dims,const bool keepdims,Tensor<T> &B)
{
reducemaxDispatcher<Author, T>::reducemax(A, dims, B, keepdims);
reducemaxDispatcher<Author, T>::reducemax(A, dims, keepdims, B);
}

template <typename Author, typename T>
struct reduceminDispatcher
{
static void reducemin(const Tensor<T> &A, const std::vector<int> &dims,Tensor<T> &B,const bool keepdims=false) = delete;
static void reducemin(const Tensor<T> &A, const std::vector<int> &dims,const bool keepdims,Tensor<T> &B) = delete;
};
template <typename Author, typename T>
void reducemin(const Tensor<T> &A, const std::vector<int> &dims,Tensor<T> &B,const bool keepdims=false)
void reducemin(const Tensor<T> &A, const std::vector<int> &dims,const bool keepdims,Tensor<T> &B)
{
reduceminDispatcher<Author, T>::reducemin(A, dims, B, keepdims);
reduceminDispatcher<Author, T>::reducemin(A, dims, keepdims, B);
}

template <typename Author, typename T>
struct sumDispatcher
{
static void reducesum(const Tensor<T> &A, const std::vector<int> &dims,Tensor<T> &B,const bool keepdims=false) = delete;
static void sum(const Tensor<T> &A, const std::vector<int> &dims,const bool keepdims,Tensor<T> &B) = delete;
};
template <typename Author, typename T>
void sum(const Tensor<T> &A, const std::vector<int> &dims,Tensor<T> &B,const bool keepdims=false)
void sum(const Tensor<T> &A, const std::vector<int> &dims,const bool keepdims,Tensor<T> &B)
{
sumDispatcher<Author, T>::sum(A, dims, B, keepdims);
sumDispatcher<Author, T>::sum(A, dims, keepdims, B);
}

template <typename Author, typename T>
struct prodDispatcher
{
static void prod(const Tensor<T> &A, const std::vector<int> &dims,Tensor<T> &B,const bool keepdims=false) = delete;
static void prod(const Tensor<T> &A, const std::vector<int> &dims,const bool keepdims,Tensor<T> &B) = delete;
};

template <typename Author, typename T>
void prod(const Tensor<T> &A, const std::vector<int> &dims,Tensor<T> &B,const bool keepdims=false)
void prod(const Tensor<T> &A, const std::vector<int> &dims,const bool keepdims,Tensor<T> &B)
{
prodDispatcher<Author, T>::prod(A, dims, B, keepdims);
prodDispatcher<Author, T>::prod(A, dims, keepdims, B);
}
}
#endif // DEEPX_TENSORFUNC_REDUCE_HPP
67 changes: 52 additions & 15 deletions excuter/op-mem-cuda/src/client/tfs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "deepx/tf/elementwise_compare.hpp"
#include "deepx/tf/matmul.hpp"
#include "deepx/tf/changeshape.hpp"
#include "deepx/tf/reduce.hpp"
#include "deepx/dtype.hpp"
#include "deepx/tf/tffactory.hpp"
#include "deepx/tensorfunc/authors.hpp"
Expand Down Expand Up @@ -371,20 +372,56 @@ namespace deepx::tf
Param("B", DataCategory::Tensor, Precision::Any),
})));
}
// // reduce
// void register_reduce(OpFactory &opfactory)
// {
// opfactory.add_op(Max<float>());
// opfactory.add_op(Max<double>());
// opfactory.add_op(Maxscalar<float>());
// opfactory.add_op(Maxscalar<double>());
// opfactory.add_op(Min<float>());
// opfactory.add_op(Min<double>());
// opfactory.add_op(Minscalar<float>());
// opfactory.add_op(Minscalar<double>());
// opfactory.add_op(Sum<float>());
// opfactory.add_op(Sum<double>());
// }
// Registers the reduce tensor functions (Sum, Prod, ReduceMax, ReduceMin) for
// the miaobyte author with the given factory.
//
// All four share one signature:
//   inputs : tensor A (any precision), int32 vector dims, bool var keepdims
//   outputs: tensor B (any precision)
void register_reduce(TfFactory &tffactory)
{
    // Builds a fresh input-parameter list for each registration.
    const auto reduce_args = []
    {
        return vector<Param>(
            {
                Param("A", DataCategory::Tensor, Precision::Any),
                Param("dims", DataCategory::Vector, Precision::Int32),
                Param("keepdims", DataCategory::Var, Precision::Bool),
            });
    };
    // Builds a fresh output-parameter list for each registration.
    const auto reduce_returns = []
    {
        return vector<Param>(
            {
                Param("B", DataCategory::Tensor, Precision::Any),
            });
    };

    // sum
    tffactory.add_tf(std::make_shared<Sum<miaobyte>>(reduce_args(), reduce_returns()));
    // prod
    tffactory.add_tf(std::make_shared<Prod<miaobyte>>(reduce_args(), reduce_returns()));
    // max
    tffactory.add_tf(std::make_shared<ReduceMax<miaobyte>>(reduce_args(), reduce_returns()));
    // min
    tffactory.add_tf(std::make_shared<ReduceMin<miaobyte>>(reduce_args(), reduce_returns()));
}

int register_all(TfFactory &tffactory)
{
register_lifecycle(tffactory);
Expand All @@ -393,7 +430,7 @@ namespace deepx::tf
register_elementwise(tffactory);
register_matmul(tffactory);
register_changeshape(tffactory);
// register_reduce(opfactory);
register_reduce(tffactory);
return 0;
}
}
37 changes: 11 additions & 26 deletions excuter/op-mem-cuda/src/deepx/tensorfunc/changeshape_miaobyte.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,25 +41,10 @@ namespace deepx::tensorfunc
}
}

// Returns the smallest power of two that is >= n.
// Any n <= 0 yields 1; a value that is already a power of two is returned
// unchanged.
inline int nextPowerOf2(int n)
{
    if (n <= 0)
        return 1;

    // Propagate the highest set bit of (n - 1) into every lower position, so
    // that adding one lands exactly on a power of two. Starting from n - 1
    // keeps exact powers of two unchanged.
    int filled = n - 1;
    for (int shift = 1; shift <= 16; shift *= 2)
        filled |= filled >> shift;
    return filled + 1;
}


template <typename T>
void launch_transpose(const int numBlocks, const int blockSize,
const T *input,
void launch_transpose(const T *input,
const int *inputStrides,
T *output,
const int *outputStrides,
Expand All @@ -72,7 +57,7 @@ namespace deepx::tensorfunc
cudaVector<int> dimOrder_d(dimOrder, dim);

int powDim = nextPowerOf2(dim);

auto [numBlocks, blockSize] = BestDims(len);
// 根据计算出的2的幂次选择对应的模板实例
switch (powDim)
{
Expand Down Expand Up @@ -105,14 +90,14 @@ namespace deepx::tensorfunc
}
}

template void launch_transpose<double>(const int numBlocks, const int blockSize, const double *input, const int *inputStrides, double *output, const int *outputStrides, const int dim, const int len, const int *dimOrder);
template void launch_transpose<float>(const int numBlocks, const int blockSize, const float *input, const int *inputStrides, float *output, const int *outputStrides, const int dim, const int len, const int *dimOrder);
template void launch_transpose<nv_bfloat16>(const int numBlocks, const int blockSize, const nv_bfloat16 *input, const int *inputStrides, nv_bfloat16 *output, const int *outputStrides, const int dim, const int len, const int *dimOrder);
template void launch_transpose<__half>(const int numBlocks, const int blockSize, const __half *input, const int *inputStrides, __half *output, const int *outputStrides, const int dim, const int len, const int *dimOrder);
template void launch_transpose<int64_t>(const int numBlocks, const int blockSize, const int64_t *input, const int *inputStrides, int64_t *output, const int *outputStrides, const int dim, const int len, const int *dimOrder);
template void launch_transpose<int32_t>(const int numBlocks, const int blockSize, const int32_t *input, const int *inputStrides, int32_t *output, const int *outputStrides, const int dim, const int len, const int *dimOrder);
template void launch_transpose<int16_t>(const int numBlocks, const int blockSize, const int16_t *input, const int *inputStrides, int16_t *output, const int *outputStrides, const int dim, const int len, const int *dimOrder);
template void launch_transpose<int8_t>(const int numBlocks, const int blockSize, const int8_t *input, const int *inputStrides, int8_t *output, const int *outputStrides, const int dim, const int len, const int *dimOrder);
template void launch_transpose<double>(const double *input, const int *inputStrides, double *output, const int *outputStrides, const int dim, const int len, const int *dimOrder);
template void launch_transpose<float>(const float *input, const int *inputStrides, float *output, const int *outputStrides, const int dim, const int len, const int *dimOrder);
template void launch_transpose<nv_bfloat16>(const nv_bfloat16 *input, const int *inputStrides, nv_bfloat16 *output, const int *outputStrides, const int dim, const int len, const int *dimOrder);
template void launch_transpose<__half>(const __half *input, const int *inputStrides, __half *output, const int *outputStrides, const int dim, const int len, const int *dimOrder);
template void launch_transpose<int64_t>(const int64_t *input, const int *inputStrides, int64_t *output, const int *outputStrides, const int dim, const int len, const int *dimOrder);
template void launch_transpose<int32_t>(const int32_t *input, const int *inputStrides, int32_t *output, const int *outputStrides, const int dim, const int len, const int *dimOrder);
template void launch_transpose<int16_t>(const int16_t *input, const int *inputStrides, int16_t *output, const int *outputStrides, const int dim, const int len, const int *dimOrder);
template void launch_transpose<int8_t>(const int8_t *input, const int *inputStrides, int8_t *output, const int *outputStrides, const int dim, const int len, const int *dimOrder);

// concat
template <int DIM, typename T>
Expand Down
Loading