diff --git a/front/py/deepx/nn/functional/__init__.py b/front/py/deepx/nn/functional/__init__.py index 3f5d33c..deeafa5 100644 --- a/front/py/deepx/nn/functional/__init__.py +++ b/front/py/deepx/nn/functional/__init__.py @@ -37,7 +37,7 @@ "mean", "rsqrt", "softmax", - "squeeze","unsqueeze", + "squeeze","unsqueeze","sliceselect","cat", #other "calculate_fan_in_and_fan_out", diff --git a/front/py/deepx/nn/functional/changeshape.py b/front/py/deepx/nn/functional/changeshape.py index db963f8..080c6ec 100644 --- a/front/py/deepx/nn/functional/changeshape.py +++ b/front/py/deepx/nn/functional/changeshape.py @@ -1,6 +1,7 @@ +from typing import Union from deepx import Tensor -from .leaffunc_changeshape import reshape - +from .leaffunc_changeshape import reshape,indexselect, concat +from .leaffunc_init import newtensor,arange def squeeze(t:Tensor,dim:int)->Tensor: assert isinstance(dim,int) assert isinstance(t,Tensor) @@ -15,4 +16,17 @@ def unsqueeze(t:Tensor,dim:int)->Tensor: dim=dim%t.ndim newshape=list(t.shape) newshape.insert(dim,1) - return reshape(t,tuple(newshape)) \ No newline at end of file + return reshape(t,tuple(newshape)) + +def sliceselect(t:Tensor,sliceobj:slice,dim:int=-1,out:Union[Tensor,str]='')->Tensor: + assert isinstance(dim,int) + assert isinstance(sliceobj,slice) + assert isinstance(t,Tensor) + dim=dim%t.ndim + start=start = 0 if sliceobj.start is None else sliceobj.start % t.shape[dim] + stop= t.shape[dim] if sliceobj.stop is None else sliceobj.stop % t.shape[dim] + + index=arange(start,stop,dtype='int32') + return indexselect(t,index,dim=dim,out=out) + +cat= concat \ No newline at end of file diff --git a/front/py/deepx/nn/functional/leaffunc_changeshape.py b/front/py/deepx/nn/functional/leaffunc_changeshape.py index 2da8380..5f60bda 100644 --- a/front/py/deepx/nn/functional/leaffunc_changeshape.py +++ b/front/py/deepx/nn/functional/leaffunc_changeshape.py @@ -80,17 +80,16 @@ def broadcastTo(t:Tensor,new_shape:tuple[int,...],out:Union[Tensor,str]='',requi return outtensor broadcast_to = broadcastTo -def indexselect(input:Tensor,indices:Tensor,gatheraxis:int,out:Union[Tensor,str]='')->Tensor: - assert gatheraxis>=0 and gatheraxisTensor: + assert dim>=0 and dimTensor: assert isinstance(shape,tuple) for i in shape: assert isinstance(i,int) diff --git a/front/py/deepx/tensor/changeshape.py b/front/py/deepx/tensor/changeshape.py index 6b0b309..b29d858 100644 --- a/front/py/deepx/tensor/changeshape.py +++ b/front/py/deepx/tensor/changeshape.py @@ -62,6 +62,14 @@ def indexselect(self,index:Tensor,gatheraxis:int=0,out:Union[Tensor,str]='')->Te result=indexselect_func(self,index,gatheraxis,out) return result +@tensor_method +def sliceselect(self,index:slice,dim:int=0,out:Union[Tensor,str]='')->Tensor: + assert isinstance(index,slice) + gatheraxis=dim%self.ndim + from deepx.nn.functional import sliceselect as sliceselect_func + result=sliceselect_func(self,index,gatheraxis,out) + return result + @tensor_method def squeeze(self,dim:int)->Tensor: from deepx.nn.functional import squeeze as squeeze_func diff --git a/front/py/deepx/tensor/tensor.py b/front/py/deepx/tensor/tensor.py index ac63d9a..1566e21 100644 --- a/front/py/deepx/tensor/tensor.py +++ b/front/py/deepx/tensor/tensor.py @@ -159,8 +159,48 @@ def __matmul__(self, other:'Tensor'): def __rmatmul__(self, other:'Tensor'): return other.matmul(self) - def __getitem__(self, index:'Tensor'): - return self.indexselect(index) + def __getitem__(self, idx): + # 简单操作 + if isinstance(idx,Tensor): + return self.indexselect(idx) + if isinstance(idx, int): + return self.sliceselect(slice(idx,idx+1)).squeeze(dim=0) + + ## 阶段1, + if isinstance(idx, slice): + indices = [idx] + elif isinstance(idx, tuple): + indices = list(idx) + else: + raise TypeError(f"Index must be an integer, slice, tuple, or Tensor, not {type(idx).__name__}") + # 阶段2 + result = self + new_axis_positions = [] + dim_cursor = 0 + + for item in indices: + if item is None: + # 如果是 None,则表示在该位置添加一个新的维度 + new_axis_positions.append(dim_cursor) + continue + if item == Ellipsis: + num_ellipsis = self.ndim - len(indices) + 1 + dim_cursor += num_ellipsis + continue + # 如果是完整的切片 (e.g., ':'),则无需操作,直接进入下一维度 + if item == slice(None, None, None): + dim_cursor += 1 + continue + result=result.sliceselect(item,dim=dim_cursor) + dim_cursor += 1 + + # 2. 在指定位置添加新维度(由 None 产生) + i=0 + for pos in sorted(new_axis_positions): + result = result.unsqueeze(pos+i) + i += 1 + + return result #shape操作 @property diff --git a/front/py/deepx/transformer/models/llama/attention.py b/front/py/deepx/transformer/models/llama/attention.py index eb37f73..6100c82 100644 --- a/front/py/deepx/transformer/models/llama/attention.py +++ b/front/py/deepx/transformer/models/llama/attention.py @@ -1,15 +1,13 @@ from typing import Optional,Tuple from deepx.nn.modules import Module,Linear -from deepx import Tensor,matmul,softmax,concat,arange,dropout as dropout_func +from deepx import Tensor,matmul,softmax,cat,dropout as dropout_func def rotate_half(x:Tensor): - index_front=arange(0,x.shape[-1]//2,dtype="int32") - index_back=arange(x.shape[-1]//2,x.shape[-1],dtype="int32") - x1 = x.indexselect(gatheraxis=-1,index=index_front) - x2 = x.indexselect(gatheraxis=-1,index=index_back) - return concat((-x2, x1,), dim=-1) + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return cat((-x2, x1,), dim=-1) def apply_rotary_pos_emb(q:Tensor, k:Tensor, cos:Tensor, sin:Tensor, unsqueeze_dim:int=1): cos = cos.unsqueeze(unsqueeze_dim) diff --git a/front/py/examples/1_tensor/getitem.py b/front/py/examples/1_tensor/getitem.py deleted file mode 100644 index c074676..0000000 --- a/front/py/examples/1_tensor/getitem.py +++ /dev/null @@ -1,14 +0,0 @@ - -def deepx_getitem(): - from deepx import newtensor - t=newtensor((2,3,4)).full_(1) - t2=t[None, :, None] - t2.print() -def torch_getitem(): - import torch - t=torch.full((2,3,4),1) - t2=t[None, :, None] - print(t2) -if __name__ == "__main__": - deepx_getitem() - torch_getitem() \ No newline at end of file diff --git a/front/py/examples/1_tensor/getitem/getitem.py b/front/py/examples/1_tensor/getitem/getitem.py new file mode 100644 index 0000000..a903321 --- /dev/null +++ b/front/py/examples/1_tensor/getitem/getitem.py @@ -0,0 +1,13 @@ +from deepx import newtensor,arange +t = newtensor((2, 3, 13)) +t.arange_() +print() +t2 = t[None, :, None] +t2.print() +t3=t[:,None,:] +t3.print() +x=t +x1 = x[..., : x.shape[-1] // 2] +x2 = x[..., x.shape[-1] // 2 :] +x1.print() +x2.print() diff --git a/front/py/examples/1_tensor/getitem/getitem_torch.py b/front/py/examples/1_tensor/getitem/getitem_torch.py new file mode 100644 index 0000000..f239482 --- /dev/null +++ b/front/py/examples/1_tensor/getitem/getitem_torch.py @@ -0,0 +1,11 @@ +import torch + +t = torch.full((2, 3, 13), 1) +t2 = t[None, :, None] +print(t2.shape) +print(t2) +x=t +x1 = x[..., : x.shape[-1] // 2] +x2 = x[..., x.shape[-1] // 2 :] +print(x1) +print(x2) diff --git a/front/py/examples/1_tensor/getitem/tensor.py b/front/py/examples/1_tensor/getitem/tensor.py new file mode 100644 index 0000000..e280660 --- /dev/null +++ b/front/py/examples/1_tensor/getitem/tensor.py @@ -0,0 +1,53 @@ +class Tensor: + def __init__(self, data): + # 假设 data 是嵌套 list 或 numpy ndarray + import numpy as np + self.data = np.array(data) + + def __getitem__(self, idx): + # 支持 None, int, slice, tuple 等 + # 重点:遇到 None 时,插入新轴 + import numpy as np + + if not isinstance(idx, tuple): + idx = (idx,) + + # 统计原始索引和 None 的位置,组装成新的索引 + new_idx = [] + expand_axes = [] + for i, ix in enumerate(idx): + if ix is None: + expand_axes.append(len(new_idx)) + else: + new_idx.append(ix) + + # 先索引 + result = self.data[tuple(new_idx)] + # 再插入新轴 + for ax in expand_axes: + result = np.expand_dims(result, axis=ax) + + # 返回新 Tensor + ret = Tensor(result) + return ret + + @property + def shape(self): + return self.data.shape + + def __repr__(self): + return f"Tensor(shape={self.shape}, data=\n{self.data})" + + +# 测试代码 +t = Tensor([[1, 2, 3], [4, 5, 6]]) +print("原始shape:", t.shape) # (2, 3) + +t2 = t[None] +print("t[None].shape:", t2.shape) # (1, 2, 3) + +t3 = t[:, None] +print("t[:, None].shape:", t3.shape) # (2, 1, 3) + +t4 = t[None, :, None] +print("t[None, :, None].shape:", t4.shape) # (1, 2, 1, 3) \ No newline at end of file