34 changes: 21 additions & 13 deletions README.md
@@ -1,18 +1,19 @@
# deepx

deepx proposes a natively distributed, auto-parallel deep learning framework that unifies training and inference
deepx proposes a natively distributed, auto-parallel, training-and-inference-unified deep learning framework centered on an IR compute graph: through multiple levels of equivalent substitution, a simple mathematical compute graph is adaptively rewritten into a distributed, parallel, automatically differentiated engineering system architecture

## 1. deepx overview

deepx execution supports two modes, eager and auto
deepx is organized into a frontend expression layer, a compile-and-rewrite layer, and an executor layer

+ eager executes each function immediately
+ auto routes execution through the compute-graph compiler and optimizer
+ Frontend expression layer: algorithm engineers describe the mathematical computation in a near-mathematical notation. It is expressed only as a concise, single-node, single-threaded mathematical process, without concern for device types, the number of compute nodes, or other such complexity.
+ Compile-and-rewrite layer: multiple rounds of IR compilers of different kinds are registered to perform equivalent substitution; custom capabilities (for example a bespoke kvcache) can be added as plugins that locally rewrite the compute graph to obtain new behavior.
+ Executor layer: performs the actual tensor computation, parallelized at large scale.

### Frontend

The Python SDK provides a PyTorch-like API
SDKs in other languages can also plug in

+ IR communication scheduling. Unlike PyTorch and other py+bind C++ designs, where execution is function-call scheduling on the stack of a single process, the deepx programs (for example the front Python SDK, the back compute-graph compiler/optimizer, and excuters such as ompsimd) coordinate over the network via IR, and each must be started as its own process.

@@ -21,36 +22,43 @@ The Python SDK provides a PyTorch-like API
|--------------|-----------------------|-------------------------|
| Execution model | In-process function-stack scheduling | Multi-process distributed coordination |
| Communication | Direct memory access | IR network compute-scheduling protocol exchange |
| Component coupling | Tight coupling (Python binding to C++) | Loose coupling (gRPC / custom protocol) |
| Component coupling | Tight coupling (Python binding to C++) | Loose coupling |
| Tensor lifecycle management | Controlled from the Python side | Managed explicitly via the deltensor IR instruction |
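
To make the IR-driven model concrete, the sketch below writes out a short IR sequence a front could hand to an excuter, ending with explicit deltensor releases as in the table above. The instruction spellings follow the signature style of the operator lists linked further down; the exact wire format, the argument syntax, and the `send_ir` helper are assumptions for illustration, not deepx's actual protocol.

```cpp
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in: in deepx these IR lines would travel over the network
// to an excuter process; here they are only printed.
void send_ir(const std::string &instruction) {
    std::cout << instruction << '\n';
}

int main() {
    // Create two tensors, compare them, then release them explicitly.
    // Lifecycle ends with deltensor instructions rather than Python scope;
    // the argument syntax below is illustrative only.
    const std::vector<std::string> program = {
        "constant(0.5)->(A)",
        "uniform(0.0, 1.0, 42)->(B)",
        "equal(A, B, 1e-6)->(mask)",
        "deltensor(A)", // exact deltensor signature assumed
        "deltensor(B)",
    };
    for (const auto &ir : program) send_ir(ir);
    return 0;
}
```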

### Scheduling plane
### Compile-and-rewrite layer

+ Registry: collects the operator lists of the executors that are currently ready, along with per-operator timing and memory-footprint information
+ Compute-graph compiler/optimizer: fuses operators, eliminates graph nodes, and automatically generates tensor-split parallel subgraphs that replace the original nodes
+ Execution scheduler: data parallelism, pipeline parallelism (overlapping forward and backward), and model parallelism.
+ The front emits basic IR, and the compiler fuses it into the higher-level operators registered by the excuter (a toy sketch of such a rewrite follows this list).
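
Below is a toy sketch of the kind of local, equivalence-preserving rewrite described above: a mul node whose result feeds an add node is collapsed into a single fused node. The `Node` type and the pass are hypothetical and far simpler than deepx's real IR graph; they only illustrate the mechanism.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Toy graph node: an op name, its input tensor names, and its output name.
// This is not deepx's IR; it only illustrates equivalence-preserving rewriting.
struct Node {
    std::string op;                  // e.g. "mul", "add", "fused_muladd"
    std::vector<std::string> inputs; // input tensor names
    std::string output;              // output tensor name
};

// Replace each adjacent mul->add pair (where the add consumes the mul's output)
// with one fused node, preserving the subgraph's inputs and output.
void fuse_mul_add(std::vector<Node> &graph) {
    for (std::size_t i = 0; i + 1 < graph.size(); ++i) {
        Node &m = graph[i];
        Node &a = graph[i + 1];
        bool feeds = m.inputs.size() == 2 && a.inputs.size() == 2 &&
                     (a.inputs[0] == m.output || a.inputs[1] == m.output);
        if (m.op == "mul" && a.op == "add" && feeds) {
            const std::string other =
                (a.inputs[0] == m.output) ? a.inputs[1] : a.inputs[0];
            graph[i] = Node{"fused_muladd", {m.inputs[0], m.inputs[1], other}, a.output};
            graph.erase(graph.begin() + i + 1);
        }
    }
}
```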

### Executor
### Execution layer

Responsible for low-level operator computation, with Op as the core unit of execution
The execution layer comprises two kinds of executors, op and mem; in the current implementation, however, a single program manages both op and mem.

Responsible for low-level operator computation, with IR as the core unit of execution
```
Op{args(args_grad),returns(returns_grad)|func forward,backward}
Op{args(args_grad),returns(returns_grad)|func run}
```

Most Ops must implement both forward and backward, though some fusionOps designed only for inference may implement forward alone.
An Op must implement the run method

As for the excuter: anything that can execute a deepx IR sequence and return the results can plug into the deepx distributed scheduling framework. Hardware, instruction sets, acceleration libraries, and higher-level frameworks, including training and inference engines, can therefore join the deepx ecosystem with only minor modifications.
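
A minimal sketch of the `Op{args(args_grad),returns(returns_grad)|func run}` contract: an Op carries named arguments and returns and exposes a single run entry point. The class layout here is hypothetical; a real excuter resolves tensor names through its mem component and registers each op's IR signature, both of which are omitted.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical Op shape mirroring Op{args(args_grad),returns(returns_grad)|func run}.
struct Op {
    std::vector<std::string> args;    // input tensor names (grads omitted here)
    std::vector<std::string> returns; // output tensor names
    virtual void run() = 0;           // every Op must implement run
    virtual ~Op() = default;
};

// Toy element-wise add over raw buffers; a real executor would look the tensors
// up by name in its memory manager and dispatch on dtype instead.
struct AddOp : Op {
    const float *a = nullptr;
    const float *b = nullptr;
    float *c = nullptr;
    std::size_t n = 0;
    void run() override {
        for (std::size_t i = 0; i < n; ++i) c[i] = a[i] + b[i];
    }
};
```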

The executors currently available:


#### Default executor
+ CPU executor: ompsimd is implemented; its supported operator list is at [ompsimd](doc/excuter/op-mem-ompsimd/list.md)

#### GPU executor
+ CUDA executor [implementation in progress]
+ CUDA executor; its supported operator list is at [cuda](doc/excuter/op-mem-cuda/list.md)

Everyone is welcome to contribute CUDA code

+ rocm

+ apple
+ other hardware accelerators

#### Tensor-computation frameworks or function-level executors

8 changes: 5 additions & 3 deletions doc/excuter/op-mem-cuda/list.md
@@ -40,6 +40,7 @@
| Operation | Author | Math Formula | IR Instruction |
|-----------|--------|--------------|----------------|
| normal | miaobyte | normal(mean,stddev,seed)->T1 | normal(var<any> mean, var<any> stddev, var<int32> seed)->(tensor<any> t) |
| dropout | miaobyte | dropout(p,seed)->A | dropout(var<float32> p, var<int32> seed)->(tensor<any> A) |
| uniform | miaobyte | uniform(low,high,seed)->T1 | uniform(var<any> low, var<any> high, var<int32> seed)->(tensor<any> t) |
| arange | miaobyte | arange(start,step)->T1 | arange(var<any> start, var<any> step)->(tensor<any> t) |
| constant | miaobyte | constant(value)->T1 | constant(var<any> value)->(tensor<any> t) |
@@ -50,18 +51,19 @@
|-----------|--------|--------------|----------------|
| switch | miaobyte | C=switch(tensors,cases) | switch(listtensor<any> tensors, tensor<int8> cases)->(tensor<any> result) |
| greaterscalar | miaobyte | mask=compare(T1, scalar) | greaterscalar(tensor<any> A, var<any> scalar)->(tensor<bool> mask) |
| equalscalar | miaobyte | mask=compare(T1, scalar) | equalscalar(tensor<any> A, var<any> scalar, var<float64> epsilon)->(tensor<bool> mask) |
| notequal | miaobyte | T1!=T2->mask | notequal(tensor<any> A, tensor<any> B, var<float32> epsilon)->(tensor<bool> mask) |
| equalscalar | miaobyte | T1==scalar->mask | equalscalar(tensor<any> A, var<any> scalar, var<float32> epsilon)->(tensor<bool> mask) |
| min | miaobyte | T3=min(T1, T2) | min(tensor<any> A, tensor<any> B)->(tensor<any> C) |
| maxscalar | miaobyte | T3=max(T1, scalar) | maxscalar(tensor<any> A, var<any> scalar)->(tensor<any> C) |
| tan | miaobyte | T3=tan(T1) | tan(tensor<float64|float32> A)->(tensor<float64|float32> C) |
| sin | miaobyte | T3=sin(T1) | sin(tensor<float64|float32|float16|bfloat16> A)->(tensor<float64|float32|float16|bfloat16> C) |
| dropout | miaobyte | dropout(p,seed)->A | dropout(var<float32> p, var<int32> seed)->(tensor<any> A) |
| divscalar | miaobyte | T3=scalar/T1 | divscalar(tensor<any> A, var<any> scalar)->(tensor<any> C) |
| log | miaobyte | T3=log(T1) | log(tensor<float64|float32|float16|bfloat16> A)->(tensor<float64|float32|float16|bfloat16> C) |
| addscalar | miaobyte | T3=T1+scalar | addscalar(tensor<any> A, var<any> b)->(tensor<any> C) |
| greater | miaobyte | mask=compare(T1, T2) | greater(tensor<any> A, tensor<any> B)->(tensor<bool> mask) |
| lessscalar | miaobyte | mask=compare(T1, scalar) | lessscalar(tensor<any> A, var<any> scalar)->(tensor<bool> mask) |
| cos | miaobyte | T3=cos(T1) | cos(tensor<float64|float32|float16|bfloat16> A)->(tensor<float64|float32|float16|bfloat16> C) |
| notequalscalar | miaobyte | T1!=scalar->mask | notequalscalar(tensor<any> A, var<any> scalar, var<float32> epsilon)->(tensor<bool> mask) |
| minscalar | miaobyte | T3=min(T1, scalar) | minscalar(tensor<any> A, var<any> scalar)->(tensor<any> C) |
| rpowscalar | miaobyte | T3=pow(scalar, T1) | rpowscalar(var<float64|int32> scalar, tensor<float64|float32> A)->(tensor<float64|float32> C) |
| rdivscalar | miaobyte | T3=scalar/T1 | rdivscalar(var<any> scalar, tensor<any> A)->(tensor<any> C) |
@@ -75,7 +77,7 @@
| subscalar | miaobyte | T3=T1-scalar | subscalar(tensor<any> A, var<any> b)->(tensor<any> C) |
| exp | miaobyte | T3=exp(T1) | exp(tensor<float64|float32|float16|bfloat16> A)->(tensor<float64|float32|float16|bfloat16> C) |
| mul | miaobyte | T3=T1*T2 | mul(tensor<any> A, tensor<any> B)->(tensor<any> C) |
| equal | miaobyte | mask=compare(T1, T2) | equal(tensor<any> A, tensor<any> B, var<float64> epsilon)->(tensor<bool> mask) |
| equal | miaobyte | T1==T2->mask | equal(tensor<any> A, tensor<any> B, var<float32> epsilon)->(tensor<bool> mask) |
| mulscalar | miaobyte | T3=T1*scalar | mulscalar(tensor<any> A, var<any> b)->(tensor<any> C) |
| div | miaobyte | T3=T1/T2 | div(tensor<any> A, tensor<any> B)->(tensor<any> C) |
| invert | miaobyte | T3=~T1 | invert(tensor<int64|int32|int16|int8> A)->(tensor<int64|int32|int16|int8> C) |
8 changes: 5 additions & 3 deletions doc/excuter/op-mem-ompsimd/list.md
@@ -41,6 +41,7 @@
| Operation | Author | Math Formula | IR Instruction |
|-----------|--------|--------------|----------------|
| normal | miaobyte | normal(mean,stddev,seed)->T1 | normal(var<any> mean, var<any> std, var<int32> seed)->(tensor<any> t) |
| dropout | miaobyte | dropout(p,seed)->A | dropout(var<float32> p, var<int32> seed)->(tensor<any> A) |
| uniform | miaobyte | uniform(low,high,seed)->T1 | uniform(var<any> low, var<any> high, var<int32> seed)->(tensor<any> t) |
| arange | miaobyte | arange(start,step)->T1 | arange(var<any> start, var<any> step)->(tensor<any> t) |
| constant | miaobyte | constant(value)->T1 | constant(var<any> value)->(tensor<any> t) |
@@ -51,15 +52,16 @@
|-----------|--------|--------------|----------------|
| switch | miaobyte | C=switch([tensors],case) | switch(listtensor<any> tensors, tensor<int8> cases)->(tensor<any> C) |
| greaterscalar | miaobyte | mask=greater(T1,scalar) | greaterscalar(tensor<any> A, var<any> scalar)->(tensor<bool> mask) |
| equalscalar | miaobyte | mask=equal(T1,scalar) | equalscalar(tensor<any> A, var<any> scalar)->(tensor<bool> mask) |
| notequal | miaobyte | notequal(T1,T2)->mask | notequal(tensor<any> A, tensor<any> B, var<float32> epsilon)->(tensor<bool> mask) |
| equalscalar | miaobyte | mask=equal(T1,scalar) | equalscalar(tensor<any> A, var<any> scalar, var<float32> epsilon)->(tensor<bool> mask) |
| min | miaobyte | T3=min(T1,T2) | min(tensor<any> A, tensor<any> B)->(tensor<any> C) |
| maxscalar | miaobyte | T3=max(T1,scalar) | maxscalar(tensor<any> A, var<any> scalar)->(tensor<any> C) |
| dropout | miaobyte | dropout(p,seed)->A | dropout(var<float32> p, var<int32> seed)->(tensor<any> A) |
| divscalar | miaobyte | T3=T1/scalar | divscalar(tensor<any> A, var<any> scalar)->(tensor<any> C) |
| log | miaobyte | T3=log(T1) | log(tensor<any> A)->(tensor<any> C) |
| addscalar | miaobyte | T3=T1+scalar | addscalar(tensor<any> a, var<any> scalar)->(tensor<any> c) |
| greater | miaobyte | mask=greater(T1,T2) | greater(tensor<any> A, tensor<any> B)->(tensor<bool> mask) |
| lessscalar | miaobyte | mask=less(T1,scalar) | lessscalar(tensor<any> A, var<any> scalar)->(tensor<bool> mask) |
| notequalscalar | miaobyte | mask=notequal(T1,scalar) | notequalscalar(tensor<any> A, var<any> scalar, var<float32> epsilon)->(tensor<bool> mask) |
| minscalar | miaobyte | T3=min(T1,scalar) | minscalar(tensor<any> A, var<any> scalar)->(tensor<any> C) |
| rpowscalar | miaobyte | T3=scalar^T1 | rpowscalar(var<any> scalar, tensor<any> A)->(tensor<any> C) |
| rdivscalar | miaobyte | T3=scalar/T1 | rdivscalar(var<any> scalar, tensor<any> A)->(tensor<any> C) |
@@ -73,7 +75,7 @@
| subscalar | miaobyte | T3=T1-scalar | subscalar(tensor<any> a, var<any> scalar)->(tensor<any> c) |
| exp | miaobyte | T3=exp(T1) | exp(tensor<any> A)->(tensor<any> C) |
| mul | miaobyte | T3=T1*T2 | mul(tensor<any> A, tensor<any> B)->(tensor<any> C) |
| equal | miaobyte | mask=equal(T1,T2) | equal(tensor<any> A, tensor<any> B)->(tensor<bool> mask) |
| equal | miaobyte | equal(T1,T2)->mask | equal(tensor<any> A, tensor<any> B, var<float32> epsilon)->(tensor<bool> mask) |
| mulscalar | miaobyte | T3=T1*scalar | mulscalar(tensor<any> A, var<any> b)->(tensor<any> C) |
| div | miaobyte | T3=T1/T2 | div(tensor<any> A, tensor<any> B)->(tensor<any> C) |
| invert | miaobyte | T3=~T1 | invert(tensor<int64|int32|int16|int8> A)->(tensor<int64|int32|int16|int8> C) |
25 changes: 25 additions & 0 deletions excuter/cpp-common/src/deepx/tensorfunc/elementwise.hpp
@@ -333,6 +333,31 @@ namespace deepx::tensorfunc
{
equalscalarDispatcher<Author, T, MaskT>::equalscalar(A, scalar, epsilon, mask);
}
//notequal(A,B)=>mask
template <typename Author, typename T, typename MaskT>
struct notequalDispatcher
{
static void notequal(const Tensor<T> &A, const Tensor<T> &B,const float epsilon, Tensor<MaskT> &mask) = delete;
};

template <typename Author, typename T, typename MaskT>
void notequal(const Tensor<T> &A, const Tensor<T> &B,const float epsilon, Tensor<MaskT> &mask)
{
notequalDispatcher<Author, T, MaskT>::notequal(A, B, epsilon, mask);
}

// notequal(A,scalar)=>mask
template <typename Author, typename T, typename MaskT>
struct notequalscalarDispatcher
{
static void notequalscalar(const Tensor<T> &A, const T scalar,const float epsilon, Tensor<MaskT> &mask) = delete;
};

template <typename Author, typename T, typename MaskT>
void notequalscalar(const Tensor<T> &A, const T scalar,const float epsilon, Tensor<MaskT> &mask)
{
notequalscalarDispatcher<Author, T, MaskT>::notequalscalar(A, scalar, epsilon, mask);
}

// less(A,B)=>mask
template <typename Author, typename T, typename MaskT>
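The `= delete`d primary templates in the hunk above are the hook points for backends: a concrete (Author, T, MaskT) specialization of the dispatcher supplies the actual kernel. Below is a minimal sketch of such a specialization with a hypothetical `toycpu` author tag; the include path, the `shape.size()` element count, and the flat `data` indexing are assumptions, and the real miaobyte implementations live in the executors, not here.

```cpp
#include <cmath>
#include <cstddef>
#include "deepx/tensorfunc/elementwise.hpp" // assumed include path for the header above

namespace deepx::tensorfunc
{
    struct toycpu; // hypothetical Author tag, not part of deepx

    // Marks mask[i] when |A[i] - B[i]| exceeds epsilon; the real Tensor accessors may differ.
    template <>
    struct notequalDispatcher<toycpu, float, bool>
    {
        static void notequal(const Tensor<float> &A, const Tensor<float> &B,
                             const float epsilon, Tensor<bool> &mask)
        {
            for (std::size_t i = 0; i < A.shape.size(); ++i)
                mask.data[i] = std::fabs(A.data[i] - B.data[i]) > epsilon;
        }
    };
}
```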
24 changes: 22 additions & 2 deletions excuter/op-mem-cuda/src/client/tfs.cpp
@@ -406,7 +406,7 @@ namespace deepx::tf
{
Param("A", DataCategory::Tensor, Precision::Any),
Param("B", DataCategory::Tensor, Precision::Any),
Param("epsilon", DataCategory::Var, Precision::Float64),
Param("epsilon", DataCategory::Var, Precision::Float32),
}),
vector<Param>(
{
Expand All @@ -416,7 +416,27 @@ namespace deepx::tf
{
Param("A", DataCategory::Tensor, Precision::Any),
Param("scalar", DataCategory::Var, Precision::Any),
Param("epsilon", DataCategory::Var, Precision::Float64),
Param("epsilon", DataCategory::Var, Precision::Float32),
}),
vector<Param>(
{
Param("mask", DataCategory::Tensor, Precision::Bool),
})));
tffactory.add_tf(std::make_shared<NotEqual<miaobyte>>(vector<Param>(
{
Param("A", DataCategory::Tensor, Precision::Any),
Param("B", DataCategory::Tensor, Precision::Any),
Param("epsilon", DataCategory::Var, Precision::Float32),
}),
vector<Param>(
{
Param("mask", DataCategory::Tensor, Precision::Bool),
})));
tffactory.add_tf(std::make_shared<NotEqualScalar<miaobyte>>(vector<Param>(
{
Param("A", DataCategory::Tensor, Precision::Any),
Param("scalar", DataCategory::Var, Precision::Any),
Param("epsilon", DataCategory::Var, Precision::Float32),
}),
vector<Param>(
{
7 changes: 6 additions & 1 deletion excuter/op-mem-cuda/src/deepx/mem/mem_cuda.hpp
@@ -107,7 +107,12 @@ namespace deepx::mem
result->data = ptr_tensor->data;
break;
}

case Precision::Bool:
{
auto ptr_tensor = std::static_pointer_cast<Tensor<bool>>(ptr);
result->data = ptr_tensor->data;
break;
}
default:
throw std::runtime_error("Unsupported dtype: " + precision_str(ptr->shape.dtype));
}