From 65a2c8bbdf786fe7ad3cb41338a191d69370fb27 Mon Sep 17 00:00:00 2001
From: lipeng <734991033@qq.com>
Date: Mon, 16 Jun 2025 23:06:25 +0800
Subject: [PATCH 1/5] repeat: pause development of repeat_interleave
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../src/deepx/tensorfunc/changeshape.hpp      | 18 +++++++++---------
 .../deepx/tensorfunc/changeshape_miaobyte.cuh |  8 ++++++++
 .../deepx/tensorfunc/changeshape_miaobyte.hpp |  1 +
 3 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/excuter/cpp-common/src/deepx/tensorfunc/changeshape.hpp b/excuter/cpp-common/src/deepx/tensorfunc/changeshape.hpp
index 11c4b2b..c0eb430 100644
--- a/excuter/cpp-common/src/deepx/tensorfunc/changeshape.hpp
+++ b/excuter/cpp-common/src/deepx/tensorfunc/changeshape.hpp
@@ -92,19 +92,19 @@ namespace deepx::tensorfunc
     template
     struct repeat_interleaveDispatcher
     {
-        static void repeat_interleave(const Tensor &A, const int repeats, Tensor &B) = delete;
-        static void repeat_interleave(const Tensor &A, const Tensor &repeats, Tensor &B) = delete;
+        static void repeat_interleave(const Tensor &A, const int repeats,const int dim, Tensor &B) = delete;
+        // static void repeat_interleave(const Tensor &A, const Tensor &repeats, Tensor &B) = delete;
     };
     template
-    void repeat_interleave(const Tensor &A, const int repeats, Tensor &B)
+    void repeat_interleave(const Tensor &A, const int repeats,const int dim, Tensor &B)
     {
-        repeat_interleaveDispatcher::repeat_interleave(A, repeats, B);
-    }
-    template
-    void repeat_interleave(const Tensor &A, const Tensor &repeats, Tensor &B)
-    {
-        repeat_interleaveDispatcher::repeat_interleave(A, repeats, B);
+        repeat_interleaveDispatcher::repeat_interleave(A, repeats,dim, B);
     }
+    // template
+    // void repeat_interleave(const Tensor &A, const Tensor &repeats, Tensor &B)
+    // {
+    //     repeat_interleaveDispatcher::repeat_interleave(A, repeats, B);
+    // }
diff --git a/excuter/op-mem-cuda/src/deepx/tensorfunc/changeshape_miaobyte.cuh b/excuter/op-mem-cuda/src/deepx/tensorfunc/changeshape_miaobyte.cuh
index d3845ee..5a88380 100644
--- a/excuter/op-mem-cuda/src/deepx/tensorfunc/changeshape_miaobyte.cuh
+++ b/excuter/op-mem-cuda/src/deepx/tensorfunc/changeshape_miaobyte.cuh
@@ -81,5 +81,13 @@ namespace deepx::tensorfunc
         const int *repeats,
         T *output, const int *outputStrides, const int outputlen,
         const int dim);
+
+    // repeat_interleave
+    template
+    __global__ void repeat_interleave_kernel(
+        const T *input, const int *inputStrides,
+        const int *repeats,
+        T *output, const int *outputStrides, const int outputlen,
+        const int dim);
 };
 #endif // DEEPX_TENSORFUNC_CHANGESHAPE_MIAOBYTE_CUH
\ No newline at end of file
diff --git a/excuter/op-mem-cuda/src/deepx/tensorfunc/changeshape_miaobyte.hpp b/excuter/op-mem-cuda/src/deepx/tensorfunc/changeshape_miaobyte.hpp
index 818e4ba..14f4a1c 100644
--- a/excuter/op-mem-cuda/src/deepx/tensorfunc/changeshape_miaobyte.hpp
+++ b/excuter/op-mem-cuda/src/deepx/tensorfunc/changeshape_miaobyte.hpp
@@ -152,5 +152,6 @@ namespace deepx::tensorfunc
             B.data, B.shape.strides.data(),B.shape.size, B.shape.dim());
         }
     };
+
 }
 #endif // DEEPX_TENSORFUNC_CHANGESHAPE_MIAOBYTE_HPP
\ No newline at end of file

From 7b4522c986f927ecfa9487dae4dff5704cf32207 Mon Sep 17 00:00:00 2001
From: lipeng <734991033@qq.com>
Date: Wed, 18 Jun 2025 22:22:30 +0800
Subject: [PATCH 2/5] llama_rope: split out the shared code
MIME-Version: 1.0
Content-Type:
text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- front/py/examples/4_transformer/llama/llama_ | 0 .../4_transformer/llama/llama_attention.py | 31 ++++++++++++++ .../llama/llama_attention_torch.py | 35 ++++++++++++++++ .../4_transformer/llama/llama_rope.py | 2 +- .../4_transformer/llama/llama_rope_torch.py | 42 +------------------ .../4_transformer/llama/token_text.py | 42 +++++++++++++++++++ 6 files changed, 110 insertions(+), 42 deletions(-) delete mode 100644 front/py/examples/4_transformer/llama/llama_ create mode 100644 front/py/examples/4_transformer/llama/llama_attention.py create mode 100644 front/py/examples/4_transformer/llama/llama_attention_torch.py create mode 100644 front/py/examples/4_transformer/llama/token_text.py diff --git a/front/py/examples/4_transformer/llama/llama_ b/front/py/examples/4_transformer/llama/llama_ deleted file mode 100644 index e69de29..0000000 diff --git a/front/py/examples/4_transformer/llama/llama_attention.py b/front/py/examples/4_transformer/llama/llama_attention.py new file mode 100644 index 0000000..4f40b39 --- /dev/null +++ b/front/py/examples/4_transformer/llama/llama_attention.py @@ -0,0 +1,31 @@ +from token_text import dir,config + +############-------DEEPX-------################ +from deepx.nn.modules import Embedding,Module +from deepx import load,arange +from deepx.transformer.models.llama import LlamaRotaryEmbedding + +input=load(dir+'input') + +embed_tokens_weight=load(dir+'weight') + +class NetDeepx(Module): + def __init__(self,configdict:dict): + super().__init__() + self.embed_tokens = Embedding(configdict["vocab_size"], configdict["hidden_size"],weight=embed_tokens_weight) + self.rotary_emb = LlamaRotaryEmbedding(config=configdict) + print("rotary_emb.inv_freq") + self.rotary_emb.inv_freq.print() + def forward(self,x): + inputs_embeds = self.embed_tokens(x) + hidden_states = inputs_embeds + position_ids = arange(start=0,end=hidden_states.shape[1]).unsqueeze(0) + return self.rotary_emb(hidden_states, position_ids) + +if __name__ == "__main__": + net = NetDeepx(configdict=config.to_dict()) + out=net.forward(input) + out[0].print() + out[1].print() + + diff --git a/front/py/examples/4_transformer/llama/llama_attention_torch.py b/front/py/examples/4_transformer/llama/llama_attention_torch.py new file mode 100644 index 0000000..dc0369f --- /dev/null +++ b/front/py/examples/4_transformer/llama/llama_attention_torch.py @@ -0,0 +1,35 @@ +############-------PyTorch-------################ +import torch +from token_text import torch_input,config + +# 创建网络 + +class NetTorch(torch.nn.Module): + from transformers.models.llama.modeling_llama import LlamaConfig + def __init__(self, config: LlamaConfig): + super().__init__() + self.padding_idx = config.pad_token_id + self.config = config + self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding + self.rotary_emb = LlamaRotaryEmbedding(config=config) + print("rotary_emb.inv_freq") + print(self.rotary_emb.inv_freq) + def forward(self, x): + inputs_embeds = self.embed_tokens(x) + hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0) + return self.rotary_emb(hidden_states, position_ids) + +if __name__ == "__main__": + torch_net = NetTorch(config) + # 前向传播 + torch_output = torch_net(torch_input) + torch_sin, torch_cos = torch_output + + 
print("sin shape:", torch_sin.shape) + print("sin:", torch_sin) + + print("cos shape:", torch_cos.shape) + print("cos:", torch_cos) diff --git a/front/py/examples/4_transformer/llama/llama_rope.py b/front/py/examples/4_transformer/llama/llama_rope.py index e0f0598..4f40b39 100644 --- a/front/py/examples/4_transformer/llama/llama_rope.py +++ b/front/py/examples/4_transformer/llama/llama_rope.py @@ -1,4 +1,4 @@ -from llama_rope_torch import dir,config +from token_text import dir,config ############-------DEEPX-------################ from deepx.nn.modules import Embedding,Module diff --git a/front/py/examples/4_transformer/llama/llama_rope_torch.py b/front/py/examples/4_transformer/llama/llama_rope_torch.py index 3f894e3..dc0369f 100644 --- a/front/py/examples/4_transformer/llama/llama_rope_torch.py +++ b/front/py/examples/4_transformer/llama/llama_rope_torch.py @@ -1,45 +1,6 @@ -hidden_size = 8 -eps = 1e-6 -dir = '/home/lipeng/model/deepxmodel/llama/' -model_path = "/home/lipeng/model/deepseek-ai/DeepSeek-R1-Distill-Llama-8B" -print() - -from transformers import AutoTokenizer, AutoConfig - - -def init_tokenizer(model_path): - tokenizer = AutoTokenizer.from_pretrained(model_path) - tokenizer.pad_token = tokenizer.eos_token - return tokenizer - - -tokenizer = init_tokenizer(model_path) -config = AutoConfig.from_pretrained(model_path) - - -def tokenize_text(text, tokenizer): - tokens = tokenizer(text, return_tensors="pt").input_ids - import torch - # 处理超出词汇表范围的token - if torch.any(tokens >= tokenizer.vocab_size): - # 获取UNK token ID,如果没有则使用0 - unk_token_id = tokenizer.unk_token_id if hasattr(tokenizer, - 'unk_token_id') and tokenizer.unk_token_id is not None else 0 - # 替换所有超出范围的token为UNK - tokens = torch.where(tokens < tokenizer.vocab_size, tokens, torch.tensor(unk_token_id, device=tokens.device)) - return tokens - - ############-------PyTorch-------################ import torch - -# 创建输入 -text = "这是一个测试文本,用于演示嵌入层的使用。" -torch_input = tokenize_text(text, tokenizer) -from deepxutil.torch import save_torch - -save_torch(torch_input, dir + 'input') - +from token_text import torch_input,config # 创建网络 @@ -63,7 +24,6 @@ def forward(self, x): if __name__ == "__main__": torch_net = NetTorch(config) - save_torch(torch_net.embed_tokens.weight, dir + 'weight') # 前向传播 torch_output = torch_net(torch_input) torch_sin, torch_cos = torch_output diff --git a/front/py/examples/4_transformer/llama/token_text.py b/front/py/examples/4_transformer/llama/token_text.py new file mode 100644 index 0000000..84a4a59 --- /dev/null +++ b/front/py/examples/4_transformer/llama/token_text.py @@ -0,0 +1,42 @@ +hidden_size = 8 +eps = 1e-6 +dir = '/home/lipeng/model/deepxmodel/llama/' +model_path = "/home/lipeng/model/deepseek-ai/DeepSeek-R1-Distill-Llama-8B" +print() + +from transformers import AutoTokenizer, AutoConfig + + +def init_tokenizer(model_path): + tokenizer = AutoTokenizer.from_pretrained(model_path) + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + +tokenizer = init_tokenizer(model_path) +config = AutoConfig.from_pretrained(model_path) + + +def tokenize_text(text, tokenizer): + tokens = tokenizer(text, return_tensors="pt").input_ids + import torch + # 处理超出词汇表范围的token + if torch.any(tokens >= tokenizer.vocab_size): + # 获取UNK token ID,如果没有则使用0 + unk_token_id = tokenizer.unk_token_id if hasattr(tokenizer, + 'unk_token_id') and tokenizer.unk_token_id is not None else 0 + # 替换所有超出范围的token为UNK + tokens = torch.where(tokens < tokenizer.vocab_size, tokens, torch.tensor(unk_token_id, device=tokens.device)) + 
return tokens
+
+
+############-------PyTorch-------################
+import torch
+
+# 创建输入
+text = "这是一个测试文本,用于演示嵌入层的使用。"
+torch_input = tokenize_text(text, tokenizer)
+from deepxutil.torch import save_torch
+
+save_torch(torch_input, dir + 'input')
+

From cac35fdbcb40c03a14c619300c914c7813caf7e5 Mon Sep 17 00:00:00 2001
From: lipeng <734991033@qq.com>
Date: Wed, 18 Jun 2025 23:24:05 +0800
Subject: [PATCH 3/5] cuda: fix incorrect calls for the int64 type
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../src/deepx/tf/elementwise_basic.hpp | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/excuter/op-mem-cuda/src/deepx/tf/elementwise_basic.hpp b/excuter/op-mem-cuda/src/deepx/tf/elementwise_basic.hpp
index fca5298..678810b 100644
--- a/excuter/op-mem-cuda/src/deepx/tf/elementwise_basic.hpp
+++ b/excuter/op-mem-cuda/src/deepx/tf/elementwise_basic.hpp
@@ -411,7 +411,7 @@ namespace deepx::tf
             tensorfunc::add(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
             break;
         case Precision::Int64:
-            tensorfunc::add(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+            tensorfunc::add(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
             break;
         case Precision::Int32:
             tensorfunc::add(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
@@ -479,7 +479,7 @@ namespace deepx::tf
             tensorfunc::addscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue));
             break;
         case Precision::Int64:
-            tensorfunc::addscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue));
+            tensorfunc::addscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue));
             break;
         case Precision::Int32:
             tensorfunc::addscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue));
@@ -548,7 +548,7 @@ namespace deepx::tf
             tensorfunc::sub(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
             break;
         case Precision::Int64:
-            tensorfunc::sub(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+            tensorfunc::sub(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
             break;
         case Precision::Int32:
             tensorfunc::sub(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
@@ -616,7 +616,7 @@ namespace deepx::tf
             tensorfunc::subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue));
             break;
         case Precision::Int64:
-            tensorfunc::subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue));
+            tensorfunc::subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue));
             break;
         case Precision::Int32:
tensorfunc::subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); @@ -685,7 +685,7 @@ namespace deepx::tf tensorfunc::rsubscalar(this->getvar(1, mem), *mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; case Precision::Int64: - tensorfunc::rsubscalar(this->getvar(1, mem), *mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->returns[0].textvalue)); + tensorfunc::rsubscalar(this->getvar(1, mem), *mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; case Precision::Int32: tensorfunc::rsubscalar(this->getvar(1, mem), *mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->returns[0].textvalue)); @@ -754,7 +754,7 @@ namespace deepx::tf tensorfunc::mul(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; case Precision::Int64: - tensorfunc::mul(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); + tensorfunc::mul(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; case Precision::Int32: tensorfunc::mul(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); @@ -822,7 +822,7 @@ namespace deepx::tf tensorfunc::mulscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); break; case Precision::Int64: - tensorfunc::mulscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); + tensorfunc::mulscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); break; case Precision::Int32: tensorfunc::mulscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); @@ -891,7 +891,7 @@ namespace deepx::tf tensorfunc::div(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; case Precision::Int64: - tensorfunc::div(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); + tensorfunc::div(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; case Precision::Int32: tensorfunc::div(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); @@ -959,7 +959,7 @@ namespace deepx::tf tensorfunc::divscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); break; case Precision::Int64: - tensorfunc::divscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); + tensorfunc::divscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); break; case Precision::Int32: tensorfunc::divscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1, mem), *mem->gettensor(this->returns[0].textvalue)); @@ -1027,7 +1027,7 @@ namespace deepx::tf tensorfunc::rdivscalar(this->getvar(0, mem), *mem->gettensor(this->args[1].textvalue), 
*mem->gettensor(this->returns[0].textvalue));
             break;
         case Precision::Int64:
-            tensorfunc::rdivscalar(this->getvar(0, mem), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
+            tensorfunc::rdivscalar(this->getvar(0, mem), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));
             break;
         case Precision::Int32:
             tensorfunc::rdivscalar(this->getvar(0, mem), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue));

From b8345ab7bd83a069f033785259aba1e82b167844 Mon Sep 17 00:00:00 2001
From: lipeng <734991033@qq.com>
Date: Wed, 18 Jun 2025 23:24:22 +0800
Subject: [PATCH 4/5] py: rotate_half verification complete
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 front/py/deepx/tensor/changeshape.py | 4 +--
 front/py/deepx/tensor/tensor.py | 2 ++
 .../transformer/models/llama/__init__.py | 5 +++-
 .../transformer/models/llama/attention.py | 6 ++---
 .../4_transformer/llama/llama_attention.py | 27 +++----------------
 5 files changed, 15 insertions(+), 29 deletions(-)

diff --git a/front/py/deepx/tensor/changeshape.py b/front/py/deepx/tensor/changeshape.py
index 39b415a..6b0b309 100644
--- a/front/py/deepx/tensor/changeshape.py
+++ b/front/py/deepx/tensor/changeshape.py
@@ -55,9 +55,9 @@ def broadcast_to(self,shape:tuple[int,...],out:Union[Tensor,str]='')->Tensor:
     return result
 
 @tensor_method
-def indexselect(self,index:Tensor,axis:int=0,out:Union[Tensor,str]='')->Tensor:
+def indexselect(self,index:Tensor,gatheraxis:int=0,out:Union[Tensor,str]='')->Tensor:
     assert isinstance(index,Tensor)
-    gatheraxis=axis%self.ndim
+    gatheraxis=gatheraxis%self.ndim
     from deepx.nn.functional import indexselect as indexselect_func
     result=indexselect_func(self,index,gatheraxis,out)
     return result
diff --git a/front/py/deepx/tensor/tensor.py b/front/py/deepx/tensor/tensor.py
index d7dbd78..ac63d9a 100644
--- a/front/py/deepx/tensor/tensor.py
+++ b/front/py/deepx/tensor/tensor.py
@@ -124,6 +124,8 @@ def __mul__(self, other:Union[Number,'Tensor']):
         return self.mul(other)
     def __rmul__(self, other:Union[Number,'Tensor']):
         return self.mul(other)
+    def __neg__(self):
+        return self.mul(-1.0)
     def __truediv__(self, other:Union[Number,'Tensor']):
         return self.div(other)
     def __rtruediv__(self, other:Union[Number,'Tensor']):
diff --git a/front/py/deepx/transformer/models/llama/__init__.py b/front/py/deepx/transformer/models/llama/__init__.py
index d77def3..96a73bf 100644
--- a/front/py/deepx/transformer/models/llama/__init__.py
+++ b/front/py/deepx/transformer/models/llama/__init__.py
@@ -1,4 +1,7 @@
 from .embedding import *
+from .attention import *
+
 __all__ = [
-    "LlamaRotaryEmbedding"
+    "LlamaRotaryEmbedding",
+    "rotate_half"
 ]
\ No newline at end of file
diff --git a/front/py/deepx/transformer/models/llama/attention.py b/front/py/deepx/transformer/models/llama/attention.py
index 325d6a2..eb37f73 100644
--- a/front/py/deepx/transformer/models/llama/attention.py
+++ b/front/py/deepx/transformer/models/llama/attention.py
@@ -7,9 +7,9 @@ def rotate_half(x:Tensor):
     index_front=arange(0,x.shape[-1]//2,dtype="int32")
     index_back=arange(x.shape[-1]//2,x.shape[-1],dtype="int32")
-    x1 = x.index_select(dim=-1,index=index_front)
-    x2 = x.index_select(dim=-1,index=index_back)
-    return concat((-x2, x1), dim=-1)
+    x1 = x.indexselect(gatheraxis=-1,index=index_front)
+    x2 = x.indexselect(gatheraxis=-1,index=index_back)
+    return concat((-x2, x1,), dim=-1)
 
 def
apply_rotary_pos_emb(q:Tensor, k:Tensor, cos:Tensor, sin:Tensor, unsqueeze_dim:int=1): cos = cos.unsqueeze(unsqueeze_dim) diff --git a/front/py/examples/4_transformer/llama/llama_attention.py b/front/py/examples/4_transformer/llama/llama_attention.py index 4f40b39..2ecec0b 100644 --- a/front/py/examples/4_transformer/llama/llama_attention.py +++ b/front/py/examples/4_transformer/llama/llama_attention.py @@ -3,29 +3,10 @@ ############-------DEEPX-------################ from deepx.nn.modules import Embedding,Module from deepx import load,arange -from deepx.transformer.models.llama import LlamaRotaryEmbedding +from deepx.transformer.models.llama import rotate_half input=load(dir+'input') - -embed_tokens_weight=load(dir+'weight') - -class NetDeepx(Module): - def __init__(self,configdict:dict): - super().__init__() - self.embed_tokens = Embedding(configdict["vocab_size"], configdict["hidden_size"],weight=embed_tokens_weight) - self.rotary_emb = LlamaRotaryEmbedding(config=configdict) - print("rotary_emb.inv_freq") - self.rotary_emb.inv_freq.print() - def forward(self,x): - inputs_embeds = self.embed_tokens(x) - hidden_states = inputs_embeds - position_ids = arange(start=0,end=hidden_states.shape[1]).unsqueeze(0) - return self.rotary_emb(hidden_states, position_ids) - -if __name__ == "__main__": - net = NetDeepx(configdict=config.to_dict()) - out=net.forward(input) - out[0].print() - out[1].print() - +input.print() +r=rotate_half(input) +r.print() From 7e4533f0f9b1f2593d308c1749b2d0178cbe2aa9 Mon Sep 17 00:00:00 2001 From: lipeng <734991033@qq.com> Date: Wed, 18 Jun 2025 23:37:53 +0800 Subject: [PATCH 5/5] attention:rotatehalf --- .../4_transformer/llama/llama_attention.py | 5 +-- .../llama/llama_attention_torch.py | 43 ++++--------------- 2 files changed, 10 insertions(+), 38 deletions(-) diff --git a/front/py/examples/4_transformer/llama/llama_attention.py b/front/py/examples/4_transformer/llama/llama_attention.py index 2ecec0b..c029a74 100644 --- a/front/py/examples/4_transformer/llama/llama_attention.py +++ b/front/py/examples/4_transformer/llama/llama_attention.py @@ -1,8 +1,7 @@ -from token_text import dir,config +from token_text import dir ############-------DEEPX-------################ -from deepx.nn.modules import Embedding,Module -from deepx import load,arange +from deepx import load from deepx.transformer.models.llama import rotate_half input=load(dir+'input') diff --git a/front/py/examples/4_transformer/llama/llama_attention_torch.py b/front/py/examples/4_transformer/llama/llama_attention_torch.py index dc0369f..3cb8aca 100644 --- a/front/py/examples/4_transformer/llama/llama_attention_torch.py +++ b/front/py/examples/4_transformer/llama/llama_attention_torch.py @@ -1,35 +1,8 @@ -############-------PyTorch-------################ -import torch -from token_text import torch_input,config - -# 创建网络 - -class NetTorch(torch.nn.Module): - from transformers.models.llama.modeling_llama import LlamaConfig - def __init__(self, config: LlamaConfig): - super().__init__() - self.padding_idx = config.pad_token_id - self.config = config - self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding - self.rotary_emb = LlamaRotaryEmbedding(config=config) - print("rotary_emb.inv_freq") - print(self.rotary_emb.inv_freq) - def forward(self, x): - inputs_embeds = self.embed_tokens(x) - hidden_states = inputs_embeds - # create position embeddings to be shared across the decoder layers - 
position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0) - return self.rotary_emb(hidden_states, position_ids) - -if __name__ == "__main__": - torch_net = NetTorch(config) - # 前向传播 - torch_output = torch_net(torch_input) - torch_sin, torch_cos = torch_output - - print("sin shape:", torch_sin.shape) - print("sin:", torch_sin) - - print("cos shape:", torch_cos.shape) - print("cos:", torch_cos) +from token_text import torch_input +print() +############-------TORCH-------################ +from transformers.models.llama.modeling_llama import rotate_half + +print(torch_input) +r=rotate_half(torch_input) +print(r)
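
A minimal sanity check for the rotate_half path exercised in patches 4 and 5 (a sketch under assumptions: torch and transformers are installed; the helper name rotate_half_ref is hypothetical and not part of either codebase). It re-implements the half-rotation with plain slicing and compares it against the transformers implementation that llama_attention_torch.py imports above:

import torch
from transformers.models.llama.modeling_llama import rotate_half

def rotate_half_ref(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1)
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((-x2, x1), dim=-1)

x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
assert torch.equal(rotate_half_ref(x), rotate_half(x))
print(rotate_half_ref(x))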