Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions front/py/deepx/nn/functional/changeshape.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Union
from deepx import Tensor
from .leaffunc_changeshape import reshape,indexselect, concat
from .leaffunc_changeshape import reshape,indexselect, concat,broadcastTo
from .leaffunc_init import newtensor,arange
def squeeze(t:Tensor,dim:int)->Tensor:
assert isinstance(dim,int)
Expand All @@ -13,7 +13,7 @@ def squeeze(t:Tensor,dim:int)->Tensor:
def unsqueeze(t:Tensor,dim:int)->Tensor:
assert isinstance(dim,int)
assert isinstance(t,Tensor)
dim=dim%t.ndim
dim = dim % (t.ndim + 1)
newshape=list(t.shape)
newshape.insert(dim,1)
return reshape(t,tuple(newshape))
Expand All @@ -29,4 +29,7 @@ def sliceselect(t:Tensor,sliceobj:slice,dim:int=-1,out:Union[Tensor,str]='')->Te
index=arange(start,stop,dtype='int32')
return indexselect(t,index,dim=dim,out=out)

cat= concat
cat= concat
# 参考 PyTorch 文档,broadcastTo和expand是作用一样
# https://docs.pytorch.org/docs/stable/generated/torch.broadcast_to.html
expand = broadcastTo
25 changes: 18 additions & 7 deletions front/py/deepx/nn/functional/leaffunc_changeshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,25 +61,36 @@ def concat(tensors:Union[list[Tensor],tuple[Tensor,...]],dim:int,out:Union[Tenso
rtf_concat(tensors,dim,outtensor,defaultauthor['concat'])
return outtensor

def broadcastTo(t:Tensor,new_shape:tuple[int,...],out:Union[Tensor,str]='',requires_grad:bool=False,author='miaobyte')->Tensor:
assert isinstance(new_shape,tuple)
for i in new_shape:
assert isinstance(i,int) and i>0

def broadcastTo(t:Tensor,newshape:tuple[int,...],out:Union[Tensor,str]='')->Tensor:
assert isinstance(newshape,tuple)
assert len(newshape)==t.ndim
new_shape=[]
for idx,i in enumerate(newshape):
assert isinstance(i,int)
if i==-1:
new_shape.append(t.shape[idx])
elif i>0:
new_shape.append(i)
else:
raise ValueError(f"新形状参数不合法,维度 {idx} 的值{i}")
new_shape=tuple(new_shape)
if t.shape==new_shape:
return t

bshape=Shape.broadcast_shape(t.shape,new_shape)
if bshape!=tuple(new_shape):
raise ValueError(f"广播失败:{t.shape} 无法广播为 {new_shape} ")
if bshape!=new_shape:
raise ValueError(f"广播失败:{t.shape} 无法广播为 {new_shape},请先reshape")
outtensor=out
if isinstance(out,str) or out is None:
outshape=new_shape
outtensor=newtensor(outshape,dtype=t.dtype,name=out)
from .rtf_changeshape import rtf_broadcastTo
rtf_broadcastTo(t,new_shape,outtensor,defaultauthor['broadcastTo'])
return outtensor

broadcast_to = broadcastTo


def indexselect(input:Tensor,indices:Tensor,dim:int,out:Union[Tensor,str]='')->Tensor:
assert dim>=0 and dim<input.ndim
outtensor=out
Expand Down
13 changes: 8 additions & 5 deletions front/py/deepx/tensor/changeshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def transpose_(self):
transpose_func(self,self)
return self

# broadcast_to==broadcastTo==expand
# https://docs.pytorch.org/docs/stable/generated/torch.broadcast_to.html
@tensor_method
def broadcastTo(self,shape:tuple[int,...],out:Union[Tensor,str]='')->Tensor:
from deepx.nn.functional import broadcastTo as broadcastTo_func
Expand All @@ -54,6 +56,12 @@ def broadcast_to(self,shape:tuple[int,...],out:Union[Tensor,str]='')->Tensor:
result=broadcast_to_func(self,shape,out)
return result

@tensor_method
def expand(self,shape:tuple[int,...],out:Union[Tensor,str]='')->Tensor:
from deepx.nn.functional import broadcastTo as expand_func
result=expand_func(self,shape,out)
return result

@tensor_method
def indexselect(self,index:Tensor,gatheraxis:int=0,out:Union[Tensor,str]='')->Tensor:
assert isinstance(index,Tensor)
Expand Down Expand Up @@ -88,8 +96,3 @@ def repeat(self,repeats:tuple[int,...])->Tensor:
result=repeat_func(self,repeats)
return result

# @tensor_method
# def expand(self,shape:tuple)->Tensor:
# from deepx.nn.functional import expand as expand_func
# result=expand_func(self,shape,False)
# return result
38 changes: 37 additions & 1 deletion front/py/deepx/tensor/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,45 @@ def switch_(self,cases:Union[Tensor,float,int]):
# 类型转换
@tensor_method
def todtype(self,dtype:str):
assert isinstance(dtype,str)
if str== self.dtype:
return self
from deepx.nn.functional import newtensor
dest=newtensor(self.shape,dtype=dtype)
from deepx.nn.functional import todtype as todtype_func
todtype_func(self,dest)
return dest


@tensor_method
def double(self)->Tensor:
"""将张量转换为float64类型"""
return self.todtype('float64')

@tensor_method
def float(self)->Tensor:
"""将张量转换为float32类型"""
return self.todtype('float32')

@tensor_method
def half(self)->Tensor:
"""将张量转换为float16类型"""
return self.todtype('float16')

@tensor_method
def long(self)->Tensor:
"""将张量转换为int64类型"""
return self.todtype('int64')


@tensor_method
def int(self)->Tensor:
"""将张量转换为int32类型"""
return self.todtype('int32')

@tensor_method
def bool(self)->Tensor:
"""将张量转换为bool类型"""
return self.todtype('bool')



31 changes: 7 additions & 24 deletions front/py/deepx/tensor/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,34 +142,17 @@ def matmul(cls,shape:tuple[int],other:tuple[int])->tuple[int]:

@classmethod
def broadcast_shape(cls,shape_a: tuple[int,...], shape_b: tuple[int,...]) -> tuple[int,...]:
"""计算两个形状的广播后形状"""
# 获取形状的长度
len_a, len_b = len(shape_a), len(shape_b)

# 创建结果形状
assert isinstance(shape_a,tuple) and isinstance(shape_b,tuple)
assert len(shape_b)==len(shape_a)
"""计算两个形状的广播后形状(长度必须一致)"""
result_shape = []

# 从右往左对齐并计算每个维度
for i in range(1, min(len_a, len_b) + 1):
dim_a = shape_a[-i]
dim_b = shape_b[-i]

for dim_a, dim_b in zip(shape_a, shape_b):
if dim_a == 1 or dim_b == 1:
# 广播规则:如果一个维度为1,取另一个维度的值
result_shape.insert(0, max(dim_a, dim_b))
result_shape.append(max(dim_a, dim_b))
elif dim_a == dim_b:
# 维度相同,保持不变
result_shape.insert(0, dim_a)
result_shape.append(dim_a)
else:
# 维度不同且都不为1,无法广播
raise ValueError(f"无法广播的形状:{shape_a} 和 {shape_b}")

# 添加较长形状中多出的前导维度
if len_a > len_b:
result_shape = list(shape_a[:len_a - len_b]) + result_shape
elif len_b > len_a:
result_shape = list(shape_b[:len_b - len_a]) + result_shape

raise ValueError(f"无法广播的形状:{shape_a} 和 {shape_b},请先reshape")
return tuple(result_shape)


Expand Down
2 changes: 0 additions & 2 deletions front/py/deepx/tensor/tensor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Optional,Union,TypeAlias

from triton.language.semantic import equal

from .shape import Shape


Expand Down
10 changes: 5 additions & 5 deletions front/py/deepx/transformer/models/llama/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ def __init__(self,config:dict):

def forward(self, x, position_ids):
# 扩展旋转频率
inv_freq_expanded = self.inv_freq.unsqueeze(dim=0).unsqueeze(dim=2).todtype('float32')
broadcast_shape=(position_ids.shape[0], self.inv_freq.shape[0], 1)
inv_freq_expanded = inv_freq_expanded.reshape(broadcast_shape)

inv_freq_expanded = self.inv_freq[None, :, None].todtype('float32').expand((position_ids.shape[0], -1, 1))

# 使用torch.unsqueeze和type转换替代索引操作
position_ids_expanded = position_ids.unsqueeze(dim=1).todtype(x.dtype)
position_ids_expanded = position_ids[:, None, :].float()


# 计算频率
freqs = (inv_freq_expanded @ position_ids_expanded).T
# 拼接频率
Expand Down
6 changes: 6 additions & 0 deletions front/py/examples/1_tensor/getitem/getitem_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from deepx import newtensor,arange
t = newtensor((64,))
t.arange_()
print()
t2 = t[None, :, None]
t2.print()
4 changes: 2 additions & 2 deletions front/py/examples/2_ir/4_changeshape_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@


########====DEEPX====########
from deepx import Tensor,arange,broadcastTo
from deepx import arange

a=arange(start=0,end=4*2*3,name="a").reshape_((4,2,3))
b=arange(start=0,end=2,name='b').reshape((2,1))
bb=b.broadcastTo( a.shape,out="b.broadcasted")
bb=b.unsqueeze(0).broadcastTo(a.shape,out="b.broadcasted")
bb.print()

c=a[None:,]
Expand Down
15 changes: 0 additions & 15 deletions front/py/examples/2_ir/4_changeshape_broadcast_add.py

This file was deleted.

8 changes: 4 additions & 4 deletions front/py/examples/2_ir/4_changeshape_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from deepx import Tensor,zeros, ones, concat


t1 = ones([3,4,5],dtype='float32',name='t1')
t2=ones([3,4,5],dtype='float32',name='t2')
t3=ones([3,4,5],dtype='float32',name='t3')
t1 = ones( (3,4,5),dtype='float32',name='t1')
t2=ones((3,4,5),dtype='float32',name='t2')
t3=ones((3,4,5),dtype='float32',name='t3')

t=concat([t1,t2,t3],dim=1,out='t')
t=concat((t1,t2,t3),dim=1,out='t')
t.print()
2 changes: 1 addition & 1 deletion front/py/examples/3_functional/changeshape_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#######-----------------deepx-----------------#######
from deepx import Tensor,broadcast_to,arange
deepx_x = arange(0,6).reshape_((1,2,3)) # shape=(2,3)
deepx_y = broadcast_to(deepx_x, (3,2,3)) # 需要原维度为1
deepx_y = deepx_x.broadcast_to((3,2,3)) # 需要原维度为1
deepx_y.print()


Expand Down
2 changes: 1 addition & 1 deletion front/py/examples/3_module/1_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from deepx import ones

net = Linear(64, 4)
input=ones(1,64,name='input')
input=ones((1,64),name='input')
out=net.forward(input)
out.print()

Loading