diff --git a/.github/ISSUE_TEMPLATE/operator.md b/.github/ISSUE_TEMPLATE/operator.md
new file mode 100644
index 00000000..3445460f
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/operator.md
@@ -0,0 +1,29 @@
+---
+name: New operator
+about: Request the implementation of a new operator
+title: '[operator] '
+labels: enhancement, operator
+assignees: ''
+---
+
+## New operator
+Mathematical definition of the operator:
+
+## Affected components
+
+### front
+1.
+2.
+
+### Engine
+1.
+2.
+
+## Additional notes
+
+
\ No newline at end of file
diff --git a/doc/excuter/op-mem-cuda/list.md b/doc/excuter/op-mem-cuda/list.md
index 9c4d4395..f11918d9 100644
--- a/doc/excuter/op-mem-cuda/list.md
+++ b/doc/excuter/op-mem-cuda/list.md
@@ -80,7 +80,7 @@
 | equal | miaobyte | T1==T2->mask | equal(tensor A, tensor B, var epsilon)->(tensor mask) |
 | mulscalar | miaobyte | T3=T1*scalar | mulscalar(tensor A, var b)->(tensor C) |
 | div | miaobyte | T3=T1/T2 | div(tensor A, tensor B)->(tensor C) |
-| invert | miaobyte | T3=~T1 | invert(tensor A)->(tensor C) |
+| invert | miaobyte | T3=~T1 | invert(tensor A)->(tensor C) |
 | max | miaobyte | T3=max(T1, T2) | max(tensor A, tensor B)->(tensor C) |
 | pow | miaobyte | T3=pow(T1, T2) | pow(tensor A, tensor B)->(tensor C) |
diff --git a/doc/excuter/op-mem-ompsimd/list.md b/doc/excuter/op-mem-ompsimd/list.md
index 03296c08..9e16aea4 100644
--- a/doc/excuter/op-mem-ompsimd/list.md
+++ b/doc/excuter/op-mem-ompsimd/list.md
@@ -56,11 +56,14 @@
 | equalscalar | miaobyte | mask=equal(T1,scalar) | equalscalar(tensor A, var scalar, var eposilon)->(tensor mask) |
 | min | miaobyte | T3=min(T1,T2) | min(tensor A, tensor B)->(tensor C) |
 | maxscalar | miaobyte | T3=max(T1,scalar) | maxscalar(tensor A, var scalar)->(tensor C) |
+| tan | miaobyte | T3=tan(T1) | tan(tensor A)->(tensor C) |
+| sin | miaobyte | T3=sin(T1) | sin(tensor A)->(tensor C) |
 | divscalar | miaobyte | T3=T1/scalar | divscalar(tensor A, var scalar)->(tensor C) |
 | log | miaobyte | T3=log(T1) | log(tensor A)->(tensor C) |
 | addscalar | miaobyte | T3=T1+scalar | addscalar(tensor a, var scalar)->(tensor c) |
 | greater | miaobyte | mask=greater(T1,T2) | greater(tensor A, tensor B)->(tensor mask) |
 | lessscalar | miaobyte | mask=less(T1,scalar) | lessscalar(tensor A, var scalar)->(tensor mask) |
+| cos | miaobyte | T3=cos(T1) | cos(tensor A)->(tensor C) |
 | notequalscalar | miaobyte | mask=notequal(T1,scalar) | notequalscalar(tensor A, var scalar, var epsilon)->(tensor mask) |
 | minscalar | miaobyte | T3=min(T1,scalar) | minscalar(tensor A, var scalar)->(tensor C) |
 | rpowscalar | miaobyte | T3=scalar^T1 | rpowscalar(var scalar, tensor A)->(tensor C) |
@@ -78,7 +81,7 @@
 | equal | miaobyte | equal(T1,T2)->mask | equal(tensor A, tensor B, var eposilon)->(tensor mask) |
 | mulscalar | miaobyte | T3=T1*scalar | mulscalar(tensor A, var b)->(tensor C) |
 | div | miaobyte | T3=T1/T2 | div(tensor A, tensor B)->(tensor C) |
-| invert | miaobyte | T3=~T1 | invert(tensor A)->(tensor C) |
+| invert | miaobyte | T3=~T1 | invert(tensor A)->(tensor C) |
 | max | miaobyte | T3=max(T1,T2) | max(tensor A, tensor B)->(tensor C) |
 | pow | miaobyte | T3=T1^T2 | pow(tensor A, tensor B)->(tensor C) |
diff --git a/excuter/op-mem-cuda/src/client/tfs.cpp b/excuter/op-mem-cuda/src/client/tfs.cpp
index 1c1156b3..1192bb74 100644
--- a/excuter/op-mem-cuda/src/client/tfs.cpp
+++ b/excuter/op-mem-cuda/src/client/tfs.cpp
@@ -280,11 +280,11 @@ namespace deepx::tf
         // invert
         tffactory.add_tf(std::make_shared>(vector(
             {
-                Param("A", DataCategory::Tensor, Precision::Int64 | Precision::Int32 | Precision::Int16 | Precision::Int8),
+                Param("A", DataCategory::Tensor, Precision::Int64 | Precision::Int32 | Precision::Int16 | Precision::Int8 | Precision::Bool),
             }),
             vector(
                 {
-                Param("C", DataCategory::Tensor, Precision::Int64 | Precision::Int32 | Precision::Int16 | Precision::Int8),
+                Param("C", DataCategory::Tensor, Precision::Int64 | Precision::Int32 | Precision::Int16 | Precision::Int8 | Precision::Bool),
                 })));
 
         tffactory.add_tf(std::make_shared>(vector(
diff --git a/excuter/op-mem-cuda/src/deepx/tensorfunc/cuda.hpp b/excuter/op-mem-cuda/src/deepx/tensorfunc/cuda.hpp
index d9c9f3c1..5f4b3cd4 100644
--- a/excuter/op-mem-cuda/src/deepx/tensorfunc/cuda.hpp
+++ b/excuter/op-mem-cuda/src/deepx/tensorfunc/cuda.hpp
@@ -80,6 +80,12 @@ namespace deepx::tensorfunc
         return {size, host_data};
     }
 
+    inline void throwcudaerror(const std::string &msg, cudaError_t err){
+        if (err != cudaSuccess)
+        {
+            throw std::runtime_error(msg + "\n" + std::string(cudaGetErrorString(err)));
+        }
+    }
 }
 #endif
diff --git a/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.cu b/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.cu
index 7ee88bb1..d0c996a1 100644
--- a/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.cu
+++ b/excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_miaobyte_basic.cu
@@ -406,6 +406,7 @@ namespace deepx::tensorfunc
     template void launch_invert(const int32_t *a, int32_t *c, const int size);
     template void launch_invert(const int16_t *a, int16_t *c, const int size);
     template void launch_invert(const int8_t *a, int8_t *c, const int size);
+    template void launch_invert(const bool *a, bool *c, const int size);
 }
diff --git a/excuter/op-mem-cuda/src/deepx/tensorfunc/tensorlife_miaobyte.hpp b/excuter/op-mem-cuda/src/deepx/tensorfunc/tensorlife_miaobyte.hpp
index 7334301a..70dd3f07 100644
--- a/excuter/op-mem-cuda/src/deepx/tensorfunc/tensorlife_miaobyte.hpp
+++ b/excuter/op-mem-cuda/src/deepx/tensorfunc/tensorlife_miaobyte.hpp
@@ -19,8 +19,8 @@ namespace deepx::tensorfunc
         T *data;
         cudaError_t err = cudaMalloc(&data, size * sizeof(T));
         if (err != cudaSuccess)
-        {
-            throw std::runtime_error("Failed to allocate Unified Memory");
+        {
+            throwcudaerror("Failed to cudaMalloc " + std::to_string(size) + " " + precision_str(precision()), err);
         }
         return data;
     }
diff --git a/excuter/op-mem-cuda/src/deepx/tf/elementwise_basic.hpp b/excuter/op-mem-cuda/src/deepx/tf/elementwise_basic.hpp
index e63cc74b..d3659af3 100644
--- a/excuter/op-mem-cuda/src/deepx/tf/elementwise_basic.hpp
+++ b/excuter/op-mem-cuda/src/deepx/tf/elementwise_basic.hpp
@@ -1026,6 +1026,9 @@ namespace deepx::tf
             case Precision::Int8:
                 tensorfunc::invert(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->returns[0].textvalue));
                 break;
+            case Precision::Bool:
+                tensorfunc::invert(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                break;
             default:
                 error = "Unsupported dtype: " + precision_str(a_type);
                 return 1;
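A quick aside on why `invert` gets a dedicated `Precision::Bool` case above instead of falling through to the int8 kernel: `~` is a bitwise complement on signed integers but a logical NOT on booleans. A NumPy illustration of the intended semantics (illustration only, not the executor's code path):

```python
import numpy as np

# On int8, ~x is the bitwise complement (-x - 1).
print(~np.int8(5))                 # -6

# On bool, ~x is logical NOT, which is what a mask tensor expects.
print(~np.array([True, False]))    # [False  True]
```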
" + precision_str(mask_type); return 1; @@ -769,7 +769,7 @@ namespace deepx::tf { Precision a_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; Precision mask_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype; - if (a_type != mask_type || mask_type != Precision::Bool) + if (mask_type != Precision::Bool) { error = "Type mismatch: " + precision_str(a_type) + " != " + precision_str(mask_type); return 1; @@ -916,7 +916,7 @@ namespace deepx::tf } else { - tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); + tensorfunc::Switch(mem->gettensors(this->getvector(0)), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); } break; case Precision::Bool: diff --git a/excuter/op-mem-ompsimd/src/client/tfs.cpp b/excuter/op-mem-ompsimd/src/client/tfs.cpp index 8ebb1d95..b76faac8 100644 --- a/excuter/op-mem-ompsimd/src/client/tfs.cpp +++ b/excuter/op-mem-ompsimd/src/client/tfs.cpp @@ -299,11 +299,11 @@ namespace deepx::tf // invert author=miaobyte tffactory.add_tf(std::make_shared>(vector( { - Param("A", DataCategory::Tensor, Precision::Int64 | Precision::Int32 | Precision::Int16 | Precision::Int8), + Param("A", DataCategory::Tensor, Precision::Int64 | Precision::Int32 | Precision::Int16 | Precision::Int8|Precision::Bool), }), vector( { - Param("C", DataCategory::Tensor, Precision::Int64 | Precision::Int32 | Precision::Int16 | Precision::Int8), + Param("C", DataCategory::Tensor, Precision::Int64 | Precision::Int32 | Precision::Int16 | Precision::Int8|Precision::Bool), }))); // sqrt author=miaobyte tffactory.add_tf(std::make_shared>(vector( @@ -364,6 +364,33 @@ namespace deepx::tf { Param("C", DataCategory::Tensor, Precision::Any), }))); + // sin author=miaobyte + tffactory.add_tf(std::make_shared>(vector( + { + Param("A", DataCategory::Tensor, Precision::Any), + }), + vector( + { + Param("C", DataCategory::Tensor, Precision::Any), + }))); + // cos author=miaobyte + tffactory.add_tf(std::make_shared>(vector( + { + Param("A", DataCategory::Tensor, Precision::Any), + }), + vector( + { + Param("C", DataCategory::Tensor, Precision::Any), + }))); + // tan author=miaobyte + tffactory.add_tf(std::make_shared>(vector( + { + Param("A", DataCategory::Tensor, Precision::Any), + }), + vector( + { + Param("C", DataCategory::Tensor, Precision::Any), + }))); // max author=miaobyte tffactory.add_tf(std::make_shared>(vector( { diff --git a/excuter/op-mem-ompsimd/src/deepx/tensorfunc/elementwise_miaobyte.hpp b/excuter/op-mem-ompsimd/src/deepx/tensorfunc/elementwise_miaobyte.hpp index f6f3831a..517cbd9e 100644 --- a/excuter/op-mem-ompsimd/src/deepx/tensorfunc/elementwise_miaobyte.hpp +++ b/excuter/op-mem-ompsimd/src/deepx/tensorfunc/elementwise_miaobyte.hpp @@ -491,30 +491,11 @@ namespace deepx::tensorfunc { output.shape.rangeElementwiseParallel([&input, &output](int i, int i_end) { - const ScalableTag tag; - const size_t lanes = Lanes(tag); - size_t j=0; - - // 1. 处理前置未对齐部分 - while (j < i_end && !IsAligned(tag,input.data + i + j)) { - output.data[i+j] = std::sin(input.data[i+j]); - ++j; - } - - // 2. 处理中间对齐部分 - size_t aligned_end=i_end-(i_end%lanes); - for (; j+lanes<=aligned_end; j += lanes ) - { - auto vec = Load(tag, input.data + i + j); - auto vec_result = Sin(vec); - Store(vec_result, tag, output.data + i + j); - } - - // 3. 处理尾部剩余元素 - for (;j tag; - const size_t lanes = Lanes(tag); - size_t j=0; - - // 1. 
diff --git a/excuter/op-mem-ompsimd/src/deepx/tensorfunc/elementwise_miaobyte.hpp b/excuter/op-mem-ompsimd/src/deepx/tensorfunc/elementwise_miaobyte.hpp
index f6f3831a..517cbd9e 100644
--- a/excuter/op-mem-ompsimd/src/deepx/tensorfunc/elementwise_miaobyte.hpp
+++ b/excuter/op-mem-ompsimd/src/deepx/tensorfunc/elementwise_miaobyte.hpp
@@ -491,30 +491,11 @@
         {
             output.shape.rangeElementwiseParallel([&input, &output](int i, int i_end)
                                                   {
-            const ScalableTag tag;
-            const size_t lanes = Lanes(tag);
-            size_t j=0;
-
-            // 1. 处理前置未对齐部分
-            while (j < i_end && !IsAligned(tag,input.data + i + j)) {
-                output.data[i+j] = std::sin(input.data[i+j]);
-                ++j;
-            }
-
-            // 2. 处理中间对齐部分
-            size_t aligned_end=i_end-(i_end%lanes);
-            for (; j+lanes<=aligned_end; j += lanes )
-            {
-                auto vec = Load(tag, input.data + i + j);
-                auto vec_result = Sin(vec);
-                Store(vec_result, tag, output.data + i + j);
-            }
-
-            // 3. 处理尾部剩余元素
-            for (;j<i_end;j++)
-                output.data[i+j] = std::sin(input.data[i+j]); });
-            const ScalableTag tag;
-            const size_t lanes = Lanes(tag);
-            size_t j=0;
-
-            // 1. 处理前置未对齐部分
-            while (j < i_end && !IsAligned(tag,input.data + i + j)) {
-                output.data[i+j] = std::cos(input.data[i+j]);
-                ++j;
-            }
-
-            // 2. 处理中间对齐部分
-            size_t aligned_end=i_end-(i_end%lanes);
-            for (; j+lanes<=aligned_end; j += lanes )
-            {
-                auto vec = Load(tag, input.data + i + j);
-                auto vec_result = Cos(vec);
-                Store(vec_result, tag, output.data + i + j);
-            }
-
-            // 3. 处理尾部剩余元素
-            for (;j<i_end;j++)
-                output.data[i+j] = std::cos(input.data[i+j]); });
-            const ScalableTag tag;
-            const size_t lanes = Lanes(tag);
-            size_t j=0;
-
-            // 1. 处理前置未对齐部分
-            while (j < i_end && !IsAligned(tag,input.data + i + j)) {
-                output.data[i+j] = std::tan(input.data[i+j]);
-                ++j;
-            }
-
-            // 2. 处理中间对齐部分
-            size_t aligned_end=i_end-(i_end%lanes);
-            for (; j+lanes<=aligned_end; j += lanes )
-            {
-                auto vec = Load(tag, input.data + i + j);
-                auto vec_result = Tan(vec);
-                Store(vec_result, tag, output.data + i + j);
-            }
-
-            // 3. 处理尾部剩余元素
-            for (;j<i_end;j++)
-                output.data[i+j] = std::tan(input.data[i+j]); });
 struct maxDispatcher
@@ -784,16 +728,17 @@
         {
             A.shape.rangeElementwiseParallel([&A, &B, &mask, epsilon](int i, int i_end)
                                              {
-            for (int j = 0; j < i_end; j++)
-            {
-                if (epsilon == 0)
-                {
-                    mask.data[i+j]=A.data[i+j]==B.data[i+j];
-                }
-                else{
-                    mask.data[i+j]=std::abs(A.data[i+j]-B.data[i+j])<=epsilon;
-                }
-            } });
+            for (int j = 0; j < i_end; j++)
+            {
+                if (epsilon == 0)
+                {
+                    mask.data[i + j] = A.data[i + j] == B.data[i + j];
+                }
+                else
+                {
+                    mask.data[i + j] = std::abs(A.data[i + j] - B.data[i + j]) <= epsilon;
+                }
+            } });
         }
         else
         {
@@ -995,7 +940,7 @@
             {
                 for (int j = 0; j < i_end; j++)
                 {
-                    int which_tensor=cases.data[i+j];
+                    casesT which_tensor=cases.data[i+j];
                     C.data[i+j]=tensors[which_tensor]->data[i+j];
                 } });
         }
diff --git a/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp b/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp
index cc3117b9..4cddae4d 100644
--- a/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp
+++ b/excuter/op-mem-ompsimd/src/deepx/tf/elementwise.hpp
@@ -550,6 +550,10 @@ namespace deepx::tf
             case Precision::Int8:
                 tensorfunc::mul(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
                 break;
+            case Precision::Bool:
+                // TODO: compute bool via int8 for now
+                tensorfunc::mul(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                break;
             default:
                 error = "Unsupported dtype: " + precision_str(a_type);
                 return 1;
@@ -843,6 +847,9 @@
             case Precision::Int8:
                 tensorfunc::invert(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->returns[0].textvalue));
                 break;
+            case Precision::Bool:
+                tensorfunc::invert(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->returns[0].textvalue));
+                break;
             default:
                 error = "Unsupported dtype: " + precision_str(a_type);
                 return 1;
@@ -1817,11 +1824,11 @@
         {
             Precision a_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype;
             Precision mask_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype;
-            if (a_type != mask_type)
+            if (mask_type != Precision::Bool)
             {
-                error = "Type mismatch: " + precision_str(a_type) + " != " + precision_str(mask_type);
+                error = "mask_type should be " + precision_str(Precision::Bool);
                 return 1;
-            }
+            }
             switch (a_type)
             {
             case Precision::Float64:
@@ -1934,9 +1941,9 @@
         {
             Precision a_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype;
             Precision mask_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype;
-            if (a_type != mask_type)
+            if (mask_type != Precision::Bool)
             {
-                error = "Type mismatch: " + precision_str(a_type) + " != " + precision_str(mask_type);
+                error = "mask_type should be " + precision_str(Precision::Bool);
                 return 1;
             }
             switch (a_type)
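The re-indented `equal` kernel above keeps its two-mode contract: with `epsilon == 0` it compares exactly, otherwise it tests `|A - B| <= epsilon`. A minimal NumPy sketch of that contract (illustration only; the function name and values are hypothetical):

```python
import numpy as np

def equal(a: np.ndarray, b: np.ndarray, epsilon: float = 0.0) -> np.ndarray:
    """Elementwise equality mask: exact when epsilon == 0, tolerant otherwise."""
    if epsilon == 0:
        return a == b
    return np.abs(a - b) <= epsilon

a = np.array([1.0, 2.0, 3.0])
b = np.array([1.0, 2.0000001, 4.0])
print(equal(a, b))          # [ True False False]
print(equal(a, b, 1e-6))    # [ True  True False]
```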
"Type mismatch: " + precision_str(a_type) + " != " + precision_str(mask_type); + error = "mask_type should be " +precision_str(Precision::Bool); return 1; } switch (a_type) diff --git a/front/py/deepx/nn/functional/authormap.py b/front/py/deepx/nn/functional/authormap.py index 834bda0a..9972f304 100644 --- a/front/py/deepx/nn/functional/authormap.py +++ b/front/py/deepx/nn/functional/authormap.py @@ -42,6 +42,10 @@ 'rpowscalar':'miaobyte', 'sqrt':'miaobyte', # + 'sin':'miaobyte', + 'cos':'miaobyte', + 'tan':'miaobyte', + # 'dropout':'miaobyte', #changeshape 'reshape':'miaobyte', diff --git a/front/py/deepx/nn/functional/leaffunc_changeshape.py b/front/py/deepx/nn/functional/leaffunc_changeshape.py index a830307a..a69f92d6 100644 --- a/front/py/deepx/nn/functional/leaffunc_changeshape.py +++ b/front/py/deepx/nn/functional/leaffunc_changeshape.py @@ -34,7 +34,7 @@ def permute(t:Tensor, outtensor=out if isinstance(out,str) or out is None: outshape = [t.shape[dim] for dim in dimorder] - outtensor=newtensor(outshape,dtype=t.dtype,name=out) + outtensor=newtensor(tuple(outshape),dtype=t.dtype,name=out) from .rtf_changeshape import rtf_transpose rtf_transpose(t,dimorder,outtensor,defaultauthor['transpose']) @@ -52,12 +52,11 @@ def concat(tensors:Union[list[Tensor],tuple[Tensor,...]],dim:int,out:Union[Tenso assert isinstance(tensors,list) or isinstance(tensors,tuple) for t in tensors: assert isinstance(t,Tensor) - + dim=dim%tensors[0].ndim outtensor=out if isinstance(out,str) or out is None: - outshape=list(tensors[0].shape) - outshape[dim]=sum(t.shape[dim] for t in tensors) - outtensor=newtensor(outshape,dtype=tensors[0].dtype,name=out) + outshape=Shape.concat(tuple(t.shape for t in tensors),dim) + outtensor=newtensor(tuple(outshape),dtype=tensors[0].dtype,name=out) from .rtf_changeshape import rtf_concat rtf_concat(tensors,dim,outtensor,defaultauthor['concat']) return outtensor diff --git a/front/py/deepx/scheduler/client/udpconn.py b/front/py/deepx/scheduler/client/udpconn.py index 6a12c26a..a25b0963 100644 --- a/front/py/deepx/scheduler/client/udpconn.py +++ b/front/py/deepx/scheduler/client/udpconn.py @@ -3,7 +3,7 @@ import select class UDPConn: - def __init__(self, endpoint: str = "localhost:9090"): + def __init__(self, endpoint: str = "localhost:8080"): # 解析endpoint self._host, port_str = endpoint.split(':') self._port = int(port_str) diff --git a/front/py/deepx/tensor/changeshape.py b/front/py/deepx/tensor/changeshape.py index e00a70d4..f359c2de 100644 --- a/front/py/deepx/tensor/changeshape.py +++ b/front/py/deepx/tensor/changeshape.py @@ -48,6 +48,12 @@ def broadcastTo(self,shape:tuple[int,...],out:Union[Tensor,str]='')->Tensor: result=broadcastTo_func(self,shape,out) return result +@tensor_method +def broadcast_to(self,shape:tuple[int,...],out:Union[Tensor,str]='')->Tensor: + from deepx.nn.functional import broadcastTo as broadcast_to_func + result=broadcast_to_func(self,shape,out) + return result + @tensor_method def indexselect(self,index:Tensor,axis:int=0,out:Union[Tensor,str]='')->Tensor: assert isinstance(index,Tensor) diff --git a/front/py/deepx/tensor/elementwise.py b/front/py/deepx/tensor/elementwise.py index 25de7456..acf587da 100644 --- a/front/py/deepx/tensor/elementwise.py +++ b/front/py/deepx/tensor/elementwise.py @@ -215,6 +215,10 @@ def switch_(self,cases:Union[Tensor,float,int]): # 类型转换 @tensor_method -def todtype(self,dest:Union[Tensor,float,int]): +def todtype(self,dtype:str): + from deepx.nn.functional import newtensor + dest=newtensor(self.shape,dtype=dtype) from 
diff --git a/front/py/deepx/tensor/shape.py b/front/py/deepx/tensor/shape.py
index 1098586d..cb0d3c8e 100644
--- a/front/py/deepx/tensor/shape.py
+++ b/front/py/deepx/tensor/shape.py
@@ -115,6 +115,19 @@ def transpose(cls,shape:tuple[int,...],dimorder:tuple[int,...]=None)->tuple[int,...]:
             dimorder=tuple(range(len(shape)))
         return Shape(tuple(shape[i] for i in dimorder))
 
+    @classmethod
+    def concat(cls,shapes:tuple,dim:int)->tuple[int,...]:
+        assert isinstance(shapes,tuple)
+        assert isinstance(dim,int)
+        dim=dim%len(shapes[0])
+        for shape in shapes:
+            assert isinstance(shape,tuple)
+            assert len(shape)==len(shapes[0])
+        outshape=list(shapes[0])
+        for i in range(1,len(shapes)):
+            outshape[dim]+=shapes[i][dim]
+        return tuple(outshape)
+
     @classmethod
     def matmul(cls,shape:tuple[int],other:tuple[int])->tuple[int]:
         if len(shape)<2 or len(other)<2:
diff --git a/front/py/deepx/tensor/tensor.py b/front/py/deepx/tensor/tensor.py
index 18bd0caa..ba40d678 100644
--- a/front/py/deepx/tensor/tensor.py
+++ b/front/py/deepx/tensor/tensor.py
@@ -119,7 +119,8 @@ def __radd__(self, other:Union[Number,'Tensor']):
     def __sub__(self, other:Union[Number,'Tensor']):
         return self.sub(other)
     def __rsub__(self, other:Union[Number,'Tensor']):
-        return self.sub(other)
+        x=self.mul(-1)
+        return x.add(other)
     def __mul__(self, other:Union[Number,'Tensor']):
         return self.mul(other)
     def __rmul__(self, other:Union[Number,'Tensor']):
@@ -156,7 +157,7 @@ def __matmul__(self, other:'Tensor'):
         return self.matmul(other)
     def __rmatmul__(self, other:'Tensor'):
         return other.matmul(self)
-    #gather
+
     def __getitem__(self, index:'Tensor'):
         return self.indexselect(index)
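The `__rsub__` fix above matters because Python evaluates `5 - t` as `t.__rsub__(5)`, which must compute `other - self`, not `self - other`. A self-contained stand-in mirroring the patched logic (illustration only, not the Tensor class itself):

```python
class T:
    """Minimal stand-in for the Tensor arithmetic protocol."""
    def __init__(self, v): self.v = v
    def sub(self, o): return T(self.v - o)
    def mul(self, o): return T(self.v * o)
    def add(self, o): return T(self.v + o)
    def __sub__(self, o): return self.sub(o)
    def __rsub__(self, o): return self.mul(-1).add(o)  # fixed: other - self = (-1)*self + other

t = T(2)
print((5 - t).v)  # 3; the old __rsub__ returned t - 5 = -3
```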
diff --git a/front/py/deepx/transformer/modeling_rope_utils.py b/front/py/deepx/transformer/modeling_rope_utils.py
index 3ee5255b..a0c7d323 100644
--- a/front/py/deepx/transformer/modeling_rope_utils.py
+++ b/front/py/deepx/transformer/modeling_rope_utils.py
@@ -27,25 +27,27 @@ def _compute_llama3_parameters(config:dict={
     # Gets the default RoPE parameters
     inv_freq, attention_factor = _compute_default_rope_parameters(config)
 
+    factor = config["rope_scaling"]["factor"]  # `8` in the original implementation
     low_freq_factor = config["rope_scaling"]["low_freq_factor"]  # `1` in the original implementation
     high_freq_factor = config["rope_scaling"]["high_freq_factor"]  # `4` in the original implementation
     old_context_len = config["rope_scaling"]["original_max_position_embeddings"]  # `8192` in the original implementation
-    low_freq_wavelen = old_context_len /low_freq_factor
-    high_freq_wavelen = old_context_len/ high_freq_factor
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
 
     wavelen = 2 * math.pi / inv_freq
-    wavelen.print()
+
     # wavelen < high_freq_wavelen: do nothing
     # wavelen > low_freq_wavelen: divide by factor
-    inv_freq_llama = where(wavelen > low_freq_wavelen, inv_freq / config.factor, inv_freq)
+    inv_freq_llama = where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
     # otherwise: interpolate between the two, using a smooth factor
-    smooth_factor = (config.old_context_len / wavelen - config.low_freq_factor) / (config.high_freq_factor - config.low_freq_factor)
-    smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / config.factor + smooth_factor * inv_freq_llama
+    smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+    smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
     is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
     inv_freq_llama = where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
 
     return inv_freq_llama, attention_factor
-
+
 ROPE_INIT_FUNCTIONS = {
     "default": _compute_default_rope_parameters,
     # "linear": _compute_linear_scaling_rope_parameters,
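A numeric sanity check of the llama3 wavelength partition above, runnable on its own, using the values the inline comments cite (factor=8, low_freq_factor=1, high_freq_factor=4, original_max_position_embeddings=8192):

```python
import math

factor, low_freq_factor, high_freq_factor, old_context_len = 8, 1, 4, 8192
low_freq_wavelen = old_context_len / low_freq_factor    # 8192.0
high_freq_wavelen = old_context_len / high_freq_factor  # 2048.0

for inv_freq in (1e-2, 1e-3, 1e-4):
    wavelen = 2 * math.pi / inv_freq
    if wavelen < high_freq_wavelen:    # high frequency: keep as-is
        out = inv_freq
    elif wavelen > low_freq_wavelen:   # low frequency: divide by factor
        out = inv_freq / factor
    else:                              # medium frequency: smooth interpolation
        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        out = (1 - smooth) * inv_freq / factor + smooth * inv_freq
    print(f"inv_freq={inv_freq:g} wavelen={wavelen:.1f} -> {out:g}")
```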
diff --git a/front/py/deepx/transformer/models/llama/embedding.py b/front/py/deepx/transformer/models/llama/embedding.py
index 62e00a57..fbccc0a2 100644
--- a/front/py/deepx/transformer/models/llama/embedding.py
+++ b/front/py/deepx/transformer/models/llama/embedding.py
@@ -15,10 +15,7 @@ def __init__(self,config:dict):
         # 旋转初始化函数
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
         # 旋转初始化函数
-        inv_freq, self.attention_scaling = self.rope_init_fn(config)
-        # 注册缓存
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        # 原始旋转频率
+        self.inv_freq, self.attention_scaling = self.rope_init_fn(config)
         self.original_inv_freq = self.inv_freq
 
     # def _dynamic_frequency_update(self, position_ids, device):
@@ -42,14 +39,14 @@ def __init__(self,config:dict):
 
     def forward(self, x, position_ids):
         # 扩展旋转频率
-        inv_freq_expanded = self.inv_freq.unsqueeze(dim=0).unsqueeze(dim=2).float()
+        inv_freq_expanded = self.inv_freq.unsqueeze(dim=0).unsqueeze(dim=2).todtype('float32')
         broadcast_shape=(position_ids.shape[0], self.inv_freq.shape[0], 1)
-        inv_freq_expanded = inv_freq_expanded.broadcast_to(broadcast_shape)
+        inv_freq_expanded = inv_freq_expanded.reshape(broadcast_shape)
         # 使用torch.unsqueeze和type转换替代索引操作
-        position_ids_expanded = position_ids.unsqueeze(dim=1).to(dtype=x.dtype)
+        position_ids_expanded = position_ids.unsqueeze(dim=1).todtype(x.dtype)
         # 计算频率
-        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
+        freqs = (inv_freq_expanded @ position_ids_expanded).T
         # 拼接频率
         emb = concat((freqs, freqs), dim=-1)
         # 计算余弦和正弦
@@ -59,4 +56,4 @@
         cos = cos * self.attention_scaling
         sin = sin * self.attention_scaling
 
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+        return cos.todtype(x.dtype), sin.todtype(x.dtype)
diff --git a/front/py/examples/1_tensor/2_newbig.py b/front/py/examples/1_tensor/2_newbig.py
new file mode 100644
index 00000000..2d0b4bbf
--- /dev/null
+++ b/front/py/examples/1_tensor/2_newbig.py
@@ -0,0 +1,18 @@
+import sys
+sys.path.append('/home/lipeng/code/git.array2d.com/ai/deepx/front/py')  # add the project root to the Python path
+
+def newtensor(dtype):
+    from deepx.nn.functional import newtensor
+    for i in range(0,20):
+        t=newtensor((1,20,4096),dtype=dtype)
+        # t.print()
+
+
+if __name__ == "__main__":
+    args=sys.argv[1:]
+    if len(args)==0:
+        newtensor('float32')
+    elif len(args)==1:
+        newtensor(args[0])
+    else:
+        print("Usage: python 2_newbig.py [dtype]")
\ No newline at end of file
diff --git a/front/py/examples/4_transformer/llama/llama_rope.py b/front/py/examples/4_transformer/llama/llama_rope.py
new file mode 100644
index 00000000..53c8995a
--- /dev/null
+++ b/front/py/examples/4_transformer/llama/llama_rope.py
@@ -0,0 +1,30 @@
+from .llama_rope_torch import dir,config
+
+############-------DEEPX-------################
+from deepx.nn.modules import Embedding,Module
+from deepx import load,arange
+from deepx.transformer.models.llama import LlamaRotaryEmbedding
+
+input=load(dir+'input')
+
+embed_tokens_weight=load(dir+'weight')
+
+class NetDeepx(Module):
+    def __init__(self,configdict:dict):
+        super().__init__()
+        self.embed_tokens = Embedding(configdict["vocab_size"], configdict["hidden_size"],weight=embed_tokens_weight)
+        self.rotary_emb = LlamaRotaryEmbedding(config=configdict)
+
+    def forward(self,x):
+        inputs_embeds = self.embed_tokens(x)
+        hidden_states = inputs_embeds
+        position_ids = arange(start=0,end=hidden_states.shape[1]).unsqueeze(0)
+        return self.rotary_emb(hidden_states, position_ids)
+
+if __name__ == "__main__":
+    net = NetDeepx(configdict=config.to_dict())
+    out=net.forward(input)
+    out[0].print()
+    out[1].print()
+
+
diff --git a/front/py/examples/4_transformer/llama/1_llama_rope.py b/front/py/examples/4_transformer/llama/llama_rope_torch.py
similarity index 52%
rename from front/py/examples/4_transformer/llama/1_llama_rope.py
rename to front/py/examples/4_transformer/llama/llama_rope_torch.py
index 2738a41f..4e9301cc 100644
--- a/front/py/examples/4_transformer/llama/1_llama_rope.py
+++ b/front/py/examples/4_transformer/llama/llama_rope_torch.py
@@ -1,94 +1,76 @@
 hidden_size = 8
 eps = 1e-6
-dir='/home/lipeng/model/deepxmodel/llama/'
-model_path="/home/lipeng/model/deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
+dir = '/home/lipeng/model/deepxmodel/llama/'
+model_path = "/home/lipeng/model/deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
 
 print()
-from transformers import AutoTokenizer,AutoConfig
+from transformers import AutoTokenizer, AutoConfig
+
+
 def init_tokenizer(model_path):
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     tokenizer.pad_token = tokenizer.eos_token
     return tokenizer
+
+
 tokenizer = init_tokenizer(model_path)
-config=AutoConfig.from_pretrained(model_path)
+config = AutoConfig.from_pretrained(model_path)
+
+
 def tokenize_text(text, tokenizer):
     tokens = tokenizer(text, return_tensors="pt").input_ids
     import torch
     # 处理超出词汇表范围的token
     if torch.any(tokens >= tokenizer.vocab_size):
         # 获取UNK token ID,如果没有则使用0
-        unk_token_id = tokenizer.unk_token_id if hasattr(tokenizer, 'unk_token_id') and tokenizer.unk_token_id is not None else 0
+        unk_token_id = tokenizer.unk_token_id if hasattr(tokenizer,
+                                                         'unk_token_id') and tokenizer.unk_token_id is not None else 0
         # 替换所有超出范围的token为UNK
         tokens = torch.where(tokens < tokenizer.vocab_size, tokens, torch.tensor(unk_token_id, device=tokens.device))
     return tokens
-
+
+
 ############-------PyTorch-------################
-import torch
+import torch
 
 # 创建输入
 text = "这是一个测试文本,用于演示嵌入层的使用。"
 torch_input = tokenize_text(text, tokenizer)
 from deepxutil.torch import save_torch
-save_torch(torch_input,dir+'input')
+
+save_torch(torch_input, dir + 'input')
+
 # 创建网络
 class NetTorch(torch.nn.Module):
     from transformers.models.llama.modeling_llama import LlamaConfig
-    def __init__(self,config:LlamaConfig):
+    def __init__(self, config: LlamaConfig):
         super().__init__()
         self.padding_idx = config.pad_token_id
         self.config = config
         self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
         self.rotary_emb = LlamaRotaryEmbedding(config=config)
-
-    def forward(self,x):
+        print("rotary_emb.inv_freq")
+        print(self.rotary_emb.inv_freq)
+    def forward(self, x):
         inputs_embeds = self.embed_tokens(x)
+        print(inputs_embeds)
         hidden_states = inputs_embeds
         # create position embeddings to be shared across the decoder layers
         position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
         return self.rotary_emb(hidden_states, position_ids)
-
-
-torch_net = NetTorch(config)
-save_torch(torch_net.embed_tokens.weight,dir+'weight')
-# 前向传播
-torch_output = torch_net(torch_input)
-torch_sin, torch_cos = torch_output
-
-print("sin shape:",torch_sin.shape)
-print("sin:", torch_sin)
-
-print("cos shape:", torch_cos.shape)
-print("cos:", torch_cos)
-
-
-############-------DEEPX-------################
-from deepx.nn.modules import Embedding,Module
-from deepx import load
-from deepx.transformer.models.llama import LlamaRotaryEmbedding
-
-input=load(dir+'input')
-
-embed_tokens_weight=load(dir+'weight')
-
-class NetDeepx(Module):
-    def __init__(self,configdict:dict):
-        super().__init__()
-        self.embed_tokens = Embedding(configdict["vocab_size"], configdict["hidden_size"],weight=embed_tokens_weight)
-        self.rotary_emb = LlamaRotaryEmbedding(config=configdict)
-
-    def forward(self,x):
-        inputs_embeds = self.embed_tokens(x)
-        hidden_states = inputs_embeds
-        position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
-        return self.rotary_emb(hidden_states, position_ids)
-
-net = NetDeepx(configdict=config.to_dict())
-out=net.forward(input)
-out.print()
+if __name__ == "__main__":
+    torch_net = NetTorch(config)
+    save_torch(torch_net.embed_tokens.weight, dir + 'weight')
+    # forward pass
+    torch_output = torch_net(torch_input)
+    torch_sin, torch_cos = torch_output
+    print("sin shape:", torch_sin.shape)
+    print("sin:", torch_sin)
+    print("cos shape:", torch_cos.shape)
+    print("cos:", torch_cos)
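For reference, a NumPy walkthrough of the shapes flowing through the rotary `forward` in `embedding.py` above (stand-in values: batch B=1, sequence length S=4, head dim d=8, so d/2=4 frequencies):

```python
import numpy as np

B, S, d2 = 1, 4, 4
inv_freq = np.random.rand(d2)                       # (d/2,)
position_ids = np.arange(S)[None, :]                # (B, S)

inv_freq_expanded = inv_freq[None, :, None]         # (B, d/2, 1) after the two unsqueezes
position_ids_expanded = position_ids[:, None, :]    # (B, 1, S) after unsqueeze(dim=1)
freqs = inv_freq_expanded @ position_ids_expanded   # (B, d/2, S)
freqs = freqs.transpose(0, 2, 1)                    # (B, S, d/2); the deepx forward uses .T,
                                                    # which matches this swap only when B == 1
emb = np.concatenate((freqs, freqs), axis=-1)       # (B, S, d)
print(emb.shape)                                    # (1, 4, 8)
```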