Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion front/py/deepx/nn/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
"mean",
"rsqrt",
"softmax",
"squeeze","unsqueeze",
"squeeze","unsqueeze","sliceselect","cat",

#other
"calculate_fan_in_and_fan_out",
Expand Down
20 changes: 17 additions & 3 deletions front/py/deepx/nn/functional/changeshape.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Union
from deepx import Tensor
from .leaffunc_changeshape import reshape

from .leaffunc_changeshape import reshape,indexselect, concat
from .leaffunc_init import newtensor,arange
def squeeze(t:Tensor,dim:int)->Tensor:
assert isinstance(dim,int)
assert isinstance(t,Tensor)
Expand All @@ -15,4 +16,17 @@ def unsqueeze(t:Tensor,dim:int)->Tensor:
dim=dim%t.ndim
newshape=list(t.shape)
newshape.insert(dim,1)
return reshape(t,tuple(newshape))
return reshape(t,tuple(newshape))

def sliceselect(t:Tensor,sliceobj:slice,dim:int=-1,out:Union[Tensor,str]='')->Tensor:
assert isinstance(dim,int)
assert isinstance(sliceobj,slice)
assert isinstance(t,Tensor)
dim=dim%t.ndim
start=start = 0 if sliceobj.start is None else sliceobj.start % t.shape[dim]
stop= t.shape[dim] if sliceobj.stop is None else sliceobj.stop % t.shape[dim]

index=arange(start,stop,dtype='int32')
return indexselect(t,index,dim=dim,out=out)

cat= concat
9 changes: 4 additions & 5 deletions front/py/deepx/nn/functional/leaffunc_changeshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,16 @@ def broadcastTo(t:Tensor,new_shape:tuple[int,...],out:Union[Tensor,str]='',requi
return outtensor
broadcast_to = broadcastTo

def indexselect(input:Tensor,indices:Tensor,gatheraxis:int,out:Union[Tensor,str]='')->Tensor:
assert gatheraxis>=0 and gatheraxis<input.ndim

def indexselect(input:Tensor,indices:Tensor,dim:int,out:Union[Tensor,str]='')->Tensor:
assert dim>=0 and dim<input.ndim
outtensor=out
if isinstance(out,str) or out is None:
outshape=Shape.indexselectshape(input.shape,indices.shape,gatheraxis)
outshape=Shape.indexselectshape(input.shape,indices.shape,dim)
outtensor=newtensor(outshape,dtype=input.dtype,name=out)
assert outtensor.shape==outshape

from .rtf_changeshape import rtf_indexselect
rtf_indexselect(input,indices,gatheraxis,outtensor,defaultauthor['indexselect'])
rtf_indexselect(input,indices,dim,outtensor,defaultauthor['indexselect'])
return outtensor

def repeat(input:Tensor,repeats:tuple[int,...],out:Union[Tensor,str]=''):
Expand Down
2 changes: 1 addition & 1 deletion front/py/deepx/nn/functional/leaffunc_life.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from deepx.tensor import Tensor
from typing import Union

def newtensor(shape:tuple[int,...],dtype:str='float32',name:str=None):
def newtensor(shape:tuple[int,...],dtype:str='float32',name:str=None)->Tensor:
assert isinstance(shape,tuple)
for i in shape:
assert isinstance(i,int)
Expand Down
8 changes: 8 additions & 0 deletions front/py/deepx/tensor/changeshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ def indexselect(self,index:Tensor,gatheraxis:int=0,out:Union[Tensor,str]='')->Te
result=indexselect_func(self,index,gatheraxis,out)
return result

@tensor_method
def sliceselect(self,index:slice,dim:int=0,out:Union[Tensor,str]='')->Tensor:
assert isinstance(index,slice)
gatheraxis=dim%self.ndim
from deepx.nn.functional import sliceselect as sliceselect_func
result=sliceselect_func(self,index,gatheraxis,out)
return result

@tensor_method
def squeeze(self,dim:int)->Tensor:
from deepx.nn.functional import squeeze as squeeze_func
Expand Down
44 changes: 42 additions & 2 deletions front/py/deepx/tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,48 @@ def __matmul__(self, other:'Tensor'):
def __rmatmul__(self, other:'Tensor'):
return other.matmul(self)

def __getitem__(self, index:'Tensor'):
return self.indexselect(index)
def __getitem__(self, idx):
# 简单操作
if isinstance(idx,Tensor):
return self.indexselect(idx)
if isinstance(idx, int):
return self.sliceselect(slice(idx,idx+1)).squeeze(dim=0)

## 阶段1,
if isinstance(idx, slice):
indices = [idx]
elif isinstance(idx, tuple):
indices = list(idx)
else:
raise TypeError(f"Index must be an integer, slice, tuple, or Tensor, not {type(idx).__name__}")
# 阶段2
result = self
new_axis_positions = []
dim_cursor = 0

for item in indices:
if item is None:
# 如果是 None,则表示在该位置添加一个新的维度
new_axis_positions.append(dim_cursor)
continue
if item == Ellipsis:
num_ellipsis = self.ndim - len(indices) + 1
dim_cursor += num_ellipsis
continue
# 如果是完整的切片 (e.g., ':'),则无需操作,直接进入下一维度
if item == slice(None, None, None):
dim_cursor += 1
continue
result=result.sliceselect(item,dim=dim_cursor)
dim_cursor += 1

# 2. 在指定位置添加新维度(由 None 产生)
i=0
for pos in sorted(new_axis_positions):
result = result.unsqueeze(pos+i)
i += 1

return result

#shape操作
@property
Expand Down
10 changes: 4 additions & 6 deletions front/py/deepx/transformer/models/llama/attention.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from typing import Optional,Tuple
from deepx.nn.modules import Module,Linear
from deepx import Tensor,matmul,softmax,concat,arange,dropout as dropout_func
from deepx import Tensor,matmul,softmax,cat,dropout as dropout_func



def rotate_half(x:Tensor):
index_front=arange(0,x.shape[-1]//2,dtype="int32")
index_back=arange(x.shape[-1]//2,x.shape[-1],dtype="int32")
x1 = x.indexselect(gatheraxis=-1,index=index_front)
x2 = x.indexselect(gatheraxis=-1,index=index_back)
return concat((-x2, x1,), dim=-1)
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return cat((-x2, x1,), dim=-1)

def apply_rotary_pos_emb(q:Tensor, k:Tensor, cos:Tensor, sin:Tensor, unsqueeze_dim:int=1):
cos = cos.unsqueeze(unsqueeze_dim)
Expand Down
14 changes: 0 additions & 14 deletions front/py/examples/1_tensor/getitem.py

This file was deleted.

13 changes: 13 additions & 0 deletions front/py/examples/1_tensor/getitem/getitem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from deepx import newtensor,arange
t = newtensor((2, 3, 13))
t.arange_()
print()
t2 = t[None, :, None]
t2.print()
t3=t[:,None,:]
t3.print()
x=t
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
x1.print()
x2.print()
11 changes: 11 additions & 0 deletions front/py/examples/1_tensor/getitem/getitem_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import torch

t = torch.full((2, 3, 13), 1)
t2 = t[None, :, None]
print(t2.shape)
print(t2)
x=t
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
print(x1)
print(x2)
53 changes: 53 additions & 0 deletions front/py/examples/1_tensor/getitem/tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
class Tensor:
def __init__(self, data):
# 假设 data 是嵌套 list 或 numpy ndarray
import numpy as np
self.data = np.array(data)

def __getitem__(self, idx):
# 支持 None, int, slice, tuple 等
# 重点:遇到 None 时,插入新轴
import numpy as np

if not isinstance(idx, tuple):
idx = (idx,)

# 统计原始索引和 None 的位置,组装成新的索引
new_idx = []
expand_axes = []
for i, ix in enumerate(idx):
if ix is None:
expand_axes.append(len(new_idx))
else:
new_idx.append(ix)

# 先索引
result = self.data[tuple(new_idx)]
# 再插入新轴
for ax in expand_axes:
result = np.expand_dims(result, axis=ax)

# 返回新 Tensor
ret = Tensor(result)
return ret

@property
def shape(self):
return self.data.shape

def __repr__(self):
return f"Tensor(shape={self.shape}, data=\n{self.data})"


# 测试代码
t = Tensor([[1, 2, 3], [4, 5, 6]])
print("原始shape:", t.shape) # (2, 3)

t2 = t[None]
print("t[None].shape:", t2.shape) # (1, 2, 3)

t3 = t[:, None]
print("t[:, None].shape:", t3.shape) # (2, 1, 3)

t4 = t[None, :, None]
print("t[None, :, None].shape:", t4.shape) # (1, 2, 1, 3)
Loading