10 changes: 9 additions & 1 deletion front/py/deepx/nn/functional/changeshape.py
@@ -1,7 +1,15 @@
from typing import Union
from deepx import Tensor
from .leaffunc_changeshape import reshape,indexselect, concat,broadcastTo
from .leaffunc_changeshape import reshape,indexselect, concat,broadcastTo,permute
from .leaffunc_init import newtensor,arange


def transpose(t:Tensor,dim0:int,dim1:int,out:Union[Tensor,str]='')->Tensor:
dimorder = list(range(t.ndim))
dimorder[dim0],dimorder[dim1]=dimorder[dim1],dimorder[dim0]
return permute(t,tuple(dimorder),out)


def squeeze(t:Tensor,dim:int)->Tensor:
assert isinstance(dim,int)
assert isinstance(t,Tensor)
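A quick pure-Python sketch of what the relocated `transpose(t, dim0, dim1)` helper does before delegating to `permute`: it builds the identity dim order and swaps the two requested axes. The `ndim`, `dim0`, and `dim1` values below are made up for illustration.

```python
# Sketch of the dim-order computation inside transpose(t, dim0, dim1);
# permute(t, dimorder) then reorders the tensor's axes accordingly.
ndim = 4                     # hypothetical t.ndim
dim0, dim1 = 1, 2
dimorder = list(range(ndim))                                   # [0, 1, 2, 3]
dimorder[dim0], dimorder[dim1] = dimorder[dim1], dimorder[dim0]
print(tuple(dimorder))                                         # (0, 2, 1, 3)
```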
6 changes: 0 additions & 6 deletions front/py/deepx/nn/functional/leaffunc_changeshape.py
@@ -9,7 +9,6 @@ def reshape(t:Tensor,shape:tuple[int,...],out:Union[Tensor,str]='')->Tensor:
for i in shape:
assert isinstance(i,int) and i>0

outtensor=out
if isinstance(out,str) or out is None:
outshape=shape
outtensor=newtensor(outshape,dtype=t.dtype,name=out)
@@ -40,11 +39,6 @@ def permute(t:Tensor,
rtf_transpose(t,dimorder,outtensor,defaultauthor['transpose'])
return outtensor

def transpose(t:Tensor,out:Union[Tensor,str]='')->Tensor:
dimorder = list(range(t.ndim))
dimorder[-1],dimorder[-2]=dimorder[-2],dimorder[-1]
return permute(t,tuple(dimorder),out)



def concat(tensors:Union[list[Tensor],tuple[Tensor,...]],dim:int,out:Union[Tensor,str]='')->Tensor:
2 changes: 1 addition & 1 deletion front/py/deepx/nn/functional/rtf.py
@@ -37,7 +37,7 @@ def A_op_C(op:str,a:Tensor,out:Tensor,author='miaobyte'):
ir=DeepxIR(op, args, returns,author)
send(ir)

def A_b1_b2_op_C(op:str,a:Tensor,b1:tuple[int],b2:bool,out:Tensor,author='miaobyte'):
def A_b1_b2_op_C(op:str,a:Tensor,b1:tuple[int,...],b2:bool,out:Tensor,author='miaobyte'):
args=[Param.tensor(a),Param.vector(b1,'int32'),Param.varbool(b2)]
returns=[Param.tensor(out)]
ir=DeepxIR(op, args, returns,author)
6 changes: 3 additions & 3 deletions front/py/deepx/nn/functional/rtf_changeshape.py
@@ -2,14 +2,14 @@
from deepx.nn.deepxir import DeepxIR,Param
from deepx.scheduler import send

def rtf_reshape(t:Tensor,shape:tuple[int],out:Tensor,author='miaobyte'):
def rtf_reshape(t:Tensor,shape:tuple[int,...],out:Tensor,author='miaobyte'):
args=[Param.tensor(t),Param.vector(shape,'int32')]
returns=[Param.tensor(out)]
ir=DeepxIR("reshape", args, returns,author)
send(ir)


def rtf_transpose(t:Tensor,dimorder:tuple[int],out:Tensor,author='miaobyte'):
def rtf_transpose(t:Tensor,dimorder:tuple[int,...],out:Tensor,author='miaobyte'):
args=[Param.tensor(t),Param.vector(dimorder,'int32')]
returns=[Param.tensor(out)]
ir=DeepxIR("transpose", args, returns,author)
@@ -22,7 +22,7 @@ def rtf_concat(tensors:tuple[Tensor],dim:int,out:Tensor,author='miaobyte'):
send(ir)


def rtf_broadcastTo(t:Tensor,new_shape:tuple[int],out:Tensor,author='miaobyte'):
def rtf_broadcastTo(t:Tensor,new_shape:tuple[int,...],out:Tensor,author='miaobyte'):
args=[Param.tensor(t),Param.vector(new_shape,'int32')]
returns=[Param.tensor(out)]
ir=DeepxIR("broadcastTo", args, returns,author)
8 changes: 4 additions & 4 deletions front/py/deepx/nn/functional/rtf_reduce.py
@@ -1,18 +1,18 @@
from deepx.tensor import Tensor
from .rtf import A_b1_b2_op_C

def rtf_sum(a:Tensor,dim:tuple[int],keepdim:bool,out: Tensor, author:str='miaobyte')->Tensor:
def rtf_sum(a:Tensor,dim:tuple[int,...],keepdim:bool,out: Tensor, author:str='miaobyte')->Tensor:
A_b1_b2_op_C("sum",a,dim,keepdim,out,author)


def rtf_prod(a:Tensor,dim:tuple[int],keepdim:bool,out:Tensor, author:str='miaobyte')->Tensor:
def rtf_prod(a:Tensor,dim:tuple[int,...],keepdim:bool,out:Tensor, author:str='miaobyte')->Tensor:
A_b1_b2_op_C("prod",a,dim,keepdim,out,author)


def rtf_reducemax(a:Tensor,dim:tuple[int],keepdim:bool,out:Tensor, author:str='miaobyte')->Tensor:
def rtf_reducemax(a:Tensor,dim:tuple[int,...],keepdim:bool,out:Tensor, author:str='miaobyte')->Tensor:
A_b1_b2_op_C("reducemax",a,dim,keepdim,out,author)


def rtf_reducemin(a:Tensor,dim:tuple[int],keepdim:bool,out:Tensor, author:str='miaobyte')->Tensor:
def rtf_reducemin(a:Tensor,dim:tuple[int,...],keepdim:bool,out:Tensor, author:str='miaobyte')->Tensor:
A_b1_b2_op_C("reducemin",a,dim,keepdim,out,author)
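The recurring annotation change in rtf.py, rtf_changeshape.py, and rtf_reduce.py (and in `Shape.matmul` further down) is the difference between a fixed-length and a variable-length tuple type; a minimal illustration:

```python
# tuple[int] means "a 1-tuple holding exactly one int";
# tuple[int, ...] means "a tuple of any length whose items are all int",
# which is what dim/shape/dimorder arguments actually are.
dims: tuple[int, ...] = (0, 2, 3)   # fine under the corrected annotations
single: tuple[int] = (1,)           # the only shape the old annotation described
```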

2 changes: 1 addition & 1 deletion front/py/deepx/nn/modules/linear.py
@@ -41,7 +41,7 @@ def reset_parameters(self) -> None:

def forward(self, input: Tensor) -> Tensor:
#`y = xA^T + b`
y=input @ self.weight.T
y=input @ self.weight.mT
oldshape=y.shape
if self.bias is not None:
y.reshape_(tuple(y.shape[1:]))
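A NumPy analogue (not the deepx API) of the updated `Linear.forward` matmul, just to make the shapes behind `input @ self.weight.mT` concrete; the sizes are arbitrary.

```python
import numpy as np

# y = x @ W.mT + b, where .mT stands for swapping the last two axes
x = np.ones((2, 3))                   # (batch, in_features)
W = np.ones((4, 3))                   # (out_features, in_features)
b = np.zeros(4)
y = x @ np.swapaxes(W, -1, -2) + b    # -> (batch, out_features) == (2, 4)
print(y.shape)
```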
@@ -5,7 +5,7 @@
"silu":swish_fn,
}

class LlamaMLP(Module):
class MLP(Module):
def __init__(self, config:dict):
super().__init__()
# input layer size
@@ -2,6 +2,7 @@
from .attention import *

__all__ = [
"scaled_dot_product_attention",
"LlamaRotaryEmbedding",
"rotate_half"
]
34 changes: 34 additions & 0 deletions front/py/deepx/nn/modules/transformer/attention.py
@@ -0,0 +1,34 @@
from typing import Optional,Tuple
from deepx import Tensor,matmul,softmax,dropout

def scaled_dot_product_attention(
query: Tensor,
key: Tensor,
value: Tensor,
attention_mask: Optional[Tensor] = None,
scaling_factor: float = 1.0,
dropout_prob: float = 0.0
) -> Tuple[Tensor, Tensor]:

# Reference: https://arxiv.org/abs/1706.03762 (Attention Is All You Need)
#1 compute the attention scores
attn_scores = (query @ key.mT) * scaling_factor

#2 apply the attention mask
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_scores = attn_scores + causal_mask


#3 softmax normalization
attn_weights = softmax(attn_scores, dim=-1)


#4 optional dropout
if dropout_prob > 0.0:
attn_weights = dropout(attn_weights, p=dropout_prob)

#5 apply the attention weights to the values
attn_output = matmul(attn_weights, value)

return attn_output, attn_weights
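For reference, a NumPy analogue of `scaled_dot_product_attention` with no mask and no dropout, using `scaling_factor = 1 / sqrt(head_dim)`; the shapes are arbitrary and this is only a shape/semantics sketch, not the deepx implementation.

```python
import numpy as np

B, H, L, D = 1, 2, 4, 8                             # batch, heads, seq, head_dim
q, k, v = (np.random.randn(B, H, L, D) for _ in range(3))
scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(D)    # (B, H, L, L)
weights = np.exp(scores - scores.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)           # softmax over the last axis
out = weights @ v                                   # (B, H, L, D)
print(out.shape, weights.shape)
```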
@@ -1,7 +1,8 @@
from deepx.nn.modules import Module
from deepx import cat
from deepx.transformer.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
from deepx.utils import Config

# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
class LlamaRotaryEmbedding(Module):
def __init__(self,config:Config):
@@ -46,7 +47,7 @@ def forward(self, x, position_ids):


# compute the frequencies
freqs = (inv_freq_expanded @ position_ids_expanded).T
freqs = (inv_freq_expanded @ position_ids_expanded).mT
# concatenate the frequencies
emb = cat((freqs, freqs), dim=-1)
# compute cosine and sine
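A NumPy analogue of the `freqs` computation in `forward` above, assuming the usual HF-Llama shapes of `(batch, head_dim//2, 1)` for `inv_freq_expanded` and `(batch, 1, seq_len)` for `position_ids_expanded` (the diff does not show them, so treat these shapes as assumptions):

```python
import numpy as np

inv_freq_expanded = np.random.rand(1, 4, 1)               # (batch, head_dim//2, 1)
position_ids_expanded = np.arange(6.0).reshape(1, 1, 6)   # (batch, 1, seq_len)
# (batch, head_dim//2, seq_len) -> .mT -> (batch, seq_len, head_dim//2)
freqs = np.swapaxes(inv_freq_expanded @ position_ids_expanded, -1, -2)
emb = np.concatenate([freqs, freqs], axis=-1)             # (batch, seq_len, head_dim)
print(freqs.shape, emb.shape)                             # (1, 6, 4) (1, 6, 8)
```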
34 changes: 34 additions & 0 deletions front/py/deepx/nn/modules/transformer/grouped_query_attention.py
@@ -0,0 +1,34 @@
from typing import Optional
from deepx import Tensor, Module
from .attention import scaled_dot_product_attention

def repeat_kv(hidden_states: Tensor, n_rep: int) -> Tensor:
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
# Simplified: the distributed setup and the attention-backend configuration have been removed; the IR layer substitutes flash attention automatically, and downstream components handle the rest


def grouped_query_attention(
module: Module,
query: Tensor,
key: Tensor,
value: Tensor,
attention_mask: Optional[Tensor],
scaling_factor: float,
dropout_prob: float = 0.0,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)

return scaled_dot_product_attention(
query, key_states, value_states,
attention_mask=attention_mask,
scaling_factor=scaling_factor,
dropout_prob=dropout_prob
)
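A NumPy analogue of `repeat_kv`, showing how each key/value head is duplicated `n_rep` times so the key/value tensor lines up with the larger number of query heads (the sizes below are arbitrary):

```python
import numpy as np

batch, kv_heads, slen, head_dim, n_rep = 1, 2, 4, 8, 3
kv = np.random.randn(batch, kv_heads, slen, head_dim)
# insert a broadcast axis, expand it to n_rep, then fold it into the head axis
expanded = np.broadcast_to(kv[:, :, None, :, :],
                           (batch, kv_heads, n_rep, slen, head_dim))
repeated = expanded.reshape(batch, kv_heads * n_rep, slen, head_dim)
print(repeated.shape)   # (1, 6, 4, 8)
```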
@@ -1,12 +1,82 @@
from typing import Optional,Tuple
from deepx.nn.modules import Module,Linear,Embedding
from deepx import Tensor
from deepx.transformer.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from deepx.transformer.models.llama.attention import LlamaAttention
from deepx.transformer.models.llama.mlp import LlamaMLP
from deepx.transformer.models.llama.normalization import LlamaRMSNorm
from deepx.transformer.models.llama.embedding import LlamaRotaryEmbedding

from deepx import Tensor,cat
from deepx.nn.modules.transformer.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from deepx.utils import Config
from deepx.nn.modules.transformer.grouped_query_attention import grouped_query_attention as GQA
from deepx.nn.modules.mlp import LlamaMLP
from deepx.nn.modules.norm import LlamaRMSNorm
from deepx.nn.modules.transformer import LlamaRotaryEmbedding

def rotate_half(x:Tensor):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return cat((-x2, x1,), dim=-1)

def apply_rotary_pos_emb(q:Tensor, k:Tensor, cos:Tensor, sin:Tensor, unsqueeze_dim:int=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed



class LlamaAttention(Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(self, config: Config, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True

self.q_proj = Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)


def forward(
self,
hidden_states: Tensor,
position_embeddings: Tuple[Tensor, Tensor],
attention_mask: Optional[Tensor]
) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)

query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)


attn_output, attn_weights =GQA(
self,
query_states,
key_states,
value_states,
attention_mask,
scaling_factor=self.scaling,
dropout_prob=0.0 if not self.training else self.attention_dropout
)

attn_output = attn_output.reshape(*input_shape, -1)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights

class LlamaDecoderLayer(Module):
def __init__(self, config:dict, layer_idx: int):
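A NumPy analogue of the `rotate_half` helper above: the last dimension is split in half and the halves are swapped with a sign flip, which is the building block `apply_rotary_pos_emb` combines with the cos/sin tables (the vector below is arbitrary):

```python
import numpy as np

x = np.arange(8.0)                          # stand-in for the last dim of q or k
half = x.shape[-1] // 2
x1, x2 = x[..., :half], x[..., half:]
print(np.concatenate([-x2, x1], axis=-1))   # [-4. -5. -6. -7.  0.  1.  2.  3.]
```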
11 changes: 7 additions & 4 deletions front/py/deepx/tensor/changeshape.py
@@ -30,16 +30,19 @@ def permute_(self,dimorder:tuple[int,...])->Tensor:
return self

@tensor_method
def transpose(self,out:Union[Tensor,str]=''):
def transpose(self,dim0:int,dim1:int,out:Union[Tensor,str]=''):
assert isinstance(out,str) or isinstance(out,Tensor)
assert isinstance(dim0,int) and isinstance(dim1,int)
from deepx.nn.functional import transpose as transpose_func
result=transpose_func(self,out)
result=transpose_func(self,dim0,dim1,out)
return result

@tensor_method
def transpose_(self):
def transpose_(self,dim0:int,dim1:int):
assert isinstance(dim0,int) and isinstance(dim1,int)
from deepx.nn.functional import transpose as transpose_func
transpose_func(self,self)
transpose_func(self,dim0,dim1,self)
return self

# broadcast_to==broadcastTo==expand
2 changes: 1 addition & 1 deletion front/py/deepx/tensor/shape.py
@@ -129,7 +129,7 @@ def concat(cls,shapes:tuple,dim:int)->tuple[int,...]:
return tuple(outshape)

@classmethod
def matmul(cls,shape:tuple[int],other:tuple[int])->tuple[int]:
def matmul(cls,shape:tuple[int,...],other:tuple[int,...])->tuple[int,...]:
if len(shape)<2 or len(other)<2:
raise ValueError(f"matmul: self.ndimension()<2 or other.ndimension()<2")
if len(shape)!=len(other):
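The visible part of `Shape.matmul` only shows the rank guards; as a hedged sketch, the batched-matmul shape rule it enforces looks like this (leading batch dims shared, trailing dims follow `(m, k) @ (k, n) -> (m, n)`):

```python
a_shape = (2, 3, 4)                       # (..., m, k)
b_shape = (2, 4, 5)                       # (..., k, n)
assert len(a_shape) >= 2 and len(b_shape) >= 2
assert len(a_shape) == len(b_shape)
assert a_shape[-1] == b_shape[-2]         # inner dims must agree
out_shape = (*a_shape[:-1], b_shape[-1])
print(out_shape)                          # (2, 3, 5)
```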
4 changes: 2 additions & 2 deletions front/py/deepx/tensor/tensor.py
@@ -202,8 +202,8 @@ def __getitem__(self, idx):

#shape operations
@property
def T(self) -> str:
return self.transpose()
def mT(self) -> "Tensor":
return self.transpose(-1,-2)

# printing
@staticmethod
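A NumPy analogue of the renamed `.mT` property: only the last two axes are swapped (a per-matrix transpose over any leading batch axes), matching the `transpose(-1, -2)` call it now delegates to:

```python
import numpy as np

x = np.zeros((2, 3, 4, 5))
print(np.swapaxes(x, -1, -2).shape)   # (2, 3, 5, 4)
```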