From 9cb394c15d81ebcb67e65280781040830c60a65d Mon Sep 17 00:00:00 2001 From: lipeng <734991033@qq.com> Date: Fri, 20 Jun 2025 00:29:00 +0800 Subject: [PATCH 1/2] =?UTF-8?q?tensor:=5F=5Fgetitem=5F=5F=20=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0=E4=B8=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- front/py/deepx/nn/functional/leaffunc_life.py | 2 +- front/py/deepx/tensor/tensor.py | 31 ++++++++++- front/py/examples/1_tensor/getitem.py | 14 ----- front/py/examples/1_tensor/getitem/getitem.py | 10 ++++ .../1_tensor/getitem/getitem_torch.py | 5 ++ front/py/examples/1_tensor/getitem/tensor.py | 53 +++++++++++++++++++ 6 files changed, 98 insertions(+), 17 deletions(-) delete mode 100644 front/py/examples/1_tensor/getitem.py create mode 100644 front/py/examples/1_tensor/getitem/getitem.py create mode 100644 front/py/examples/1_tensor/getitem/getitem_torch.py create mode 100644 front/py/examples/1_tensor/getitem/tensor.py diff --git a/front/py/deepx/nn/functional/leaffunc_life.py b/front/py/deepx/nn/functional/leaffunc_life.py index 78d438f1..72965d9c 100644 --- a/front/py/deepx/nn/functional/leaffunc_life.py +++ b/front/py/deepx/nn/functional/leaffunc_life.py @@ -1,7 +1,7 @@ from deepx.tensor import Tensor from typing import Union -def newtensor(shape:tuple[int,...],dtype:str='float32',name:str=None): +def newtensor(shape:tuple[int,...],dtype:str='float32',name:str=None)->Tensor: assert isinstance(shape,tuple) for i in shape: assert isinstance(i,int) diff --git a/front/py/deepx/tensor/tensor.py b/front/py/deepx/tensor/tensor.py index ac63d9ac..bad202a2 100644 --- a/front/py/deepx/tensor/tensor.py +++ b/front/py/deepx/tensor/tensor.py @@ -159,8 +159,35 @@ def __matmul__(self, other:'Tensor'): def __rmatmul__(self, other:'Tensor'): return other.matmul(self) - def __getitem__(self, index:'Tensor'): - return self.indexselect(index) + def __getitem__(self, idx): + if isinstance(idx,Tensor): + return self.indexselect(idx) + if isinstance(idx, int): + from deepx.tensor import newtensor + index=newtensor((1,),dtype='int32') + index.full_(idx) + return self.indexselect(index) + elif isinstance(idx, tuple): + indices=list(idx) + else: + raise TypeError("Index must be an integer or a slice") + + if Ellipsis in indices: + ellipsis_idx = indices.index(Ellipsis) + num_ellipsis = self.ndim - (len(indices) - 1) + indices[ellipsis_idx:ellipsis_idx + 1] = [slice(None)] * num_ellipsis + + print(indices) + need_reshape=False + need_indexselect=False + for i, ix in enumerate(indices): + if ix is None: + need_reshape = True + elif isinstance(ix, slice): + if ix != slice(None, None, None): + need_indexselect = True + print(need_reshape,need_indexselect) + return self #shape操作 @property diff --git a/front/py/examples/1_tensor/getitem.py b/front/py/examples/1_tensor/getitem.py deleted file mode 100644 index c0746762..00000000 --- a/front/py/examples/1_tensor/getitem.py +++ /dev/null @@ -1,14 +0,0 @@ - -def deepx_getitem(): - from deepx import newtensor - t=newtensor((2,3,4)).full_(1) - t2=t[None, :, None] - t2.print() -def torch_getitem(): - import torch - t=torch.full((2,3,4),1) - t2=t[None, :, None] - print(t2) -if __name__ == "__main__": - deepx_getitem() - torch_getitem() \ No newline at end of file diff --git a/front/py/examples/1_tensor/getitem/getitem.py b/front/py/examples/1_tensor/getitem/getitem.py new file mode 100644 index 00000000..74ba95c2 --- /dev/null +++ b/front/py/examples/1_tensor/getitem/getitem.py @@ -0,0 +1,10 @@ +from deepx import newtensor,Tensor +t = newtensor((2, 3, 13)) +t.full_(1) +print() +t2 = t[None, :, None] +t2.print() +t3=t[:,None,:] +t3.print() +t4=t[..., : t.shape[-1] // 2] +t4.print() diff --git a/front/py/examples/1_tensor/getitem/getitem_torch.py b/front/py/examples/1_tensor/getitem/getitem_torch.py new file mode 100644 index 00000000..48108cd4 --- /dev/null +++ b/front/py/examples/1_tensor/getitem/getitem_torch.py @@ -0,0 +1,5 @@ +import torch + +t = torch.full((2, 3, 4), 1) +t2 = t[None, :, None] +print(t2) \ No newline at end of file diff --git a/front/py/examples/1_tensor/getitem/tensor.py b/front/py/examples/1_tensor/getitem/tensor.py new file mode 100644 index 00000000..e2806603 --- /dev/null +++ b/front/py/examples/1_tensor/getitem/tensor.py @@ -0,0 +1,53 @@ +class Tensor: + def __init__(self, data): + # 假设 data 是嵌套 list 或 numpy ndarray + import numpy as np + self.data = np.array(data) + + def __getitem__(self, idx): + # 支持 None, int, slice, tuple 等 + # 重点:遇到 None 时,插入新轴 + import numpy as np + + if not isinstance(idx, tuple): + idx = (idx,) + + # 统计原始索引和 None 的位置,组装成新的索引 + new_idx = [] + expand_axes = [] + for i, ix in enumerate(idx): + if ix is None: + expand_axes.append(len(new_idx)) + else: + new_idx.append(ix) + + # 先索引 + result = self.data[tuple(new_idx)] + # 再插入新轴 + for ax in expand_axes: + result = np.expand_dims(result, axis=ax) + + # 返回新 Tensor + ret = Tensor(result) + return ret + + @property + def shape(self): + return self.data.shape + + def __repr__(self): + return f"Tensor(shape={self.shape}, data=\n{self.data})" + + +# 测试代码 +t = Tensor([[1, 2, 3], [4, 5, 6]]) +print("原始shape:", t.shape) # (2, 3) + +t2 = t[None] +print("t[None].shape:", t2.shape) # (1, 2, 3) + +t3 = t[:, None] +print("t[:, None].shape:", t3.shape) # (2, 1, 3) + +t4 = t[None, :, None] +print("t[None, :, None].shape:", t4.shape) # (1, 2, 1, 3) \ No newline at end of file From 3ec568473ea16ee87088a65a20e40a0eaa5a20b9 Mon Sep 17 00:00:00 2001 From: lipeng <734991033@qq.com> Date: Sun, 22 Jun 2025 20:13:39 +0800 Subject: [PATCH 2/2] py:tensor __getitem__ --- front/py/deepx/nn/functional/__init__.py | 2 +- front/py/deepx/nn/functional/changeshape.py | 20 +++++- .../nn/functional/leaffunc_changeshape.py | 9 ++- front/py/deepx/tensor/changeshape.py | 8 +++ front/py/deepx/tensor/tensor.py | 61 +++++++++++-------- .../transformer/models/llama/attention.py | 10 ++- front/py/examples/1_tensor/getitem/getitem.py | 11 ++-- .../1_tensor/getitem/getitem_torch.py | 10 ++- 8 files changed, 86 insertions(+), 45 deletions(-) diff --git a/front/py/deepx/nn/functional/__init__.py b/front/py/deepx/nn/functional/__init__.py index 3f5d33cf..deeafa5d 100644 --- a/front/py/deepx/nn/functional/__init__.py +++ b/front/py/deepx/nn/functional/__init__.py @@ -37,7 +37,7 @@ "mean", "rsqrt", "softmax", - "squeeze","unsqueeze", + "squeeze","unsqueeze","sliceselect","cat", #other "calculate_fan_in_and_fan_out", diff --git a/front/py/deepx/nn/functional/changeshape.py b/front/py/deepx/nn/functional/changeshape.py index db963f85..080c6ecc 100644 --- a/front/py/deepx/nn/functional/changeshape.py +++ b/front/py/deepx/nn/functional/changeshape.py @@ -1,6 +1,7 @@ +from typing import Union from deepx import Tensor -from .leaffunc_changeshape import reshape - +from .leaffunc_changeshape import reshape,indexselect, concat +from .leaffunc_init import newtensor,arange def squeeze(t:Tensor,dim:int)->Tensor: assert isinstance(dim,int) assert isinstance(t,Tensor) @@ -15,4 +16,17 @@ def unsqueeze(t:Tensor,dim:int)->Tensor: dim=dim%t.ndim newshape=list(t.shape) newshape.insert(dim,1) - return reshape(t,tuple(newshape)) \ No newline at end of file + return reshape(t,tuple(newshape)) + +def sliceselect(t:Tensor,sliceobj:slice,dim:int=-1,out:Union[Tensor,str]='')->Tensor: + assert isinstance(dim,int) + assert isinstance(sliceobj,slice) + assert isinstance(t,Tensor) + dim=dim%t.ndim + start=start = 0 if sliceobj.start is None else sliceobj.start % t.shape[dim] + stop= t.shape[dim] if sliceobj.stop is None else sliceobj.stop % t.shape[dim] + + index=arange(start,stop,dtype='int32') + return indexselect(t,index,dim=dim,out=out) + +cat= concat \ No newline at end of file diff --git a/front/py/deepx/nn/functional/leaffunc_changeshape.py b/front/py/deepx/nn/functional/leaffunc_changeshape.py index 2da83802..5f60bdad 100644 --- a/front/py/deepx/nn/functional/leaffunc_changeshape.py +++ b/front/py/deepx/nn/functional/leaffunc_changeshape.py @@ -80,17 +80,16 @@ def broadcastTo(t:Tensor,new_shape:tuple[int,...],out:Union[Tensor,str]='',requi return outtensor broadcast_to = broadcastTo -def indexselect(input:Tensor,indices:Tensor,gatheraxis:int,out:Union[Tensor,str]='')->Tensor: - assert gatheraxis>=0 and gatheraxisTensor: + assert dim>=0 and dimTe result=indexselect_func(self,index,gatheraxis,out) return result +@tensor_method +def sliceselect(self,index:slice,dim:int=0,out:Union[Tensor,str]='')->Tensor: + assert isinstance(index,slice) + gatheraxis=dim%self.ndim + from deepx.nn.functional import sliceselect as sliceselect_func + result=sliceselect_func(self,index,gatheraxis,out) + return result + @tensor_method def squeeze(self,dim:int)->Tensor: from deepx.nn.functional import squeeze as squeeze_func diff --git a/front/py/deepx/tensor/tensor.py b/front/py/deepx/tensor/tensor.py index bad202a2..1566e21f 100644 --- a/front/py/deepx/tensor/tensor.py +++ b/front/py/deepx/tensor/tensor.py @@ -160,34 +160,47 @@ def __rmatmul__(self, other:'Tensor'): return other.matmul(self) def __getitem__(self, idx): + # 简单操作 if isinstance(idx,Tensor): return self.indexselect(idx) - if isinstance(idx, int): - from deepx.tensor import newtensor - index=newtensor((1,),dtype='int32') - index.full_(idx) - return self.indexselect(index) + if isinstance(idx, int): + return self.sliceselect(slice(idx,idx+1)).squeeze(dim=0) + + ## 阶段1, + if isinstance(idx, slice): + indices = [idx] elif isinstance(idx, tuple): - indices=list(idx) + indices = list(idx) else: - raise TypeError("Index must be an integer or a slice") - - if Ellipsis in indices: - ellipsis_idx = indices.index(Ellipsis) - num_ellipsis = self.ndim - (len(indices) - 1) - indices[ellipsis_idx:ellipsis_idx + 1] = [slice(None)] * num_ellipsis - - print(indices) - need_reshape=False - need_indexselect=False - for i, ix in enumerate(indices): - if ix is None: - need_reshape = True - elif isinstance(ix, slice): - if ix != slice(None, None, None): - need_indexselect = True - print(need_reshape,need_indexselect) - return self + raise TypeError(f"Index must be an integer, slice, tuple, or Tensor, not {type(idx).__name__}") + # 阶段2 + result = self + new_axis_positions = [] + dim_cursor = 0 + + for item in indices: + if item is None: + # 如果是 None,则表示在该位置添加一个新的维度 + new_axis_positions.append(dim_cursor) + continue + if item == Ellipsis: + num_ellipsis = self.ndim - len(indices) + 1 + dim_cursor += num_ellipsis + continue + # 如果是完整的切片 (e.g., ':'),则无需操作,直接进入下一维度 + if item == slice(None, None, None): + dim_cursor += 1 + continue + result=result.sliceselect(item,dim=dim_cursor) + dim_cursor += 1 + + # 2. 在指定位置添加新维度(由 None 产生) + i=0 + for pos in sorted(new_axis_positions): + result = result.unsqueeze(pos+i) + i += 1 + + return result #shape操作 @property diff --git a/front/py/deepx/transformer/models/llama/attention.py b/front/py/deepx/transformer/models/llama/attention.py index eb37f731..6100c828 100644 --- a/front/py/deepx/transformer/models/llama/attention.py +++ b/front/py/deepx/transformer/models/llama/attention.py @@ -1,15 +1,13 @@ from typing import Optional,Tuple from deepx.nn.modules import Module,Linear -from deepx import Tensor,matmul,softmax,concat,arange,dropout as dropout_func +from deepx import Tensor,matmul,softmax,cat,dropout as dropout_func def rotate_half(x:Tensor): - index_front=arange(0,x.shape[-1]//2,dtype="int32") - index_back=arange(x.shape[-1]//2,x.shape[-1],dtype="int32") - x1 = x.indexselect(gatheraxis=-1,index=index_front) - x2 = x.indexselect(gatheraxis=-1,index=index_back) - return concat((-x2, x1,), dim=-1) + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return cat((-x2, x1,), dim=-1) def apply_rotary_pos_emb(q:Tensor, k:Tensor, cos:Tensor, sin:Tensor, unsqueeze_dim:int=1): cos = cos.unsqueeze(unsqueeze_dim) diff --git a/front/py/examples/1_tensor/getitem/getitem.py b/front/py/examples/1_tensor/getitem/getitem.py index 74ba95c2..a903321a 100644 --- a/front/py/examples/1_tensor/getitem/getitem.py +++ b/front/py/examples/1_tensor/getitem/getitem.py @@ -1,10 +1,13 @@ -from deepx import newtensor,Tensor +from deepx import newtensor,arange t = newtensor((2, 3, 13)) -t.full_(1) +t.arange_() print() t2 = t[None, :, None] t2.print() t3=t[:,None,:] t3.print() -t4=t[..., : t.shape[-1] // 2] -t4.print() +x=t +x1 = x[..., : x.shape[-1] // 2] +x2 = x[..., x.shape[-1] // 2 :] +x1.print() +x2.print() diff --git a/front/py/examples/1_tensor/getitem/getitem_torch.py b/front/py/examples/1_tensor/getitem/getitem_torch.py index 48108cd4..f2394827 100644 --- a/front/py/examples/1_tensor/getitem/getitem_torch.py +++ b/front/py/examples/1_tensor/getitem/getitem_torch.py @@ -1,5 +1,11 @@ import torch -t = torch.full((2, 3, 4), 1) +t = torch.full((2, 3, 13), 1) t2 = t[None, :, None] -print(t2) \ No newline at end of file +print(t2.shape) +print(t2) +x=t +x1 = x[..., : x.shape[-1] // 2] +x2 = x[..., x.shape[-1] // 2 :] +print(x1) +print(x2)