diff --git a/front/py/deepx/nn/modules/norm/__init__.py b/front/py/deepx/nn/modules/norm/__init__.py
new file mode 100644
index 00000000..f4e64e56
--- /dev/null
+++ b/front/py/deepx/nn/modules/norm/__init__.py
@@ -0,0 +1,6 @@
+from .t5layernorm import *
+
+__all__ = [
+    "T5LayerNorm","RMSNorm",
+]
+
diff --git a/front/py/deepx/nn/modules/norm/normalization.py b/front/py/deepx/nn/modules/norm/normalization.py
deleted file mode 100644
index 867f3db4..00000000
--- a/front/py/deepx/nn/modules/norm/normalization.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from deepx.nn.modules import Module
-from deepx import Tensor,ones,rsqrt
-# RMSNorm
-# copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
-# Mathematical formula
-class LlamaRMSNorm(Module):
-    def __init__(self, hidden_size:int, eps:float=1e-6):
-        """
-        LlamaRMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.weight=ones((hidden_size,))
-        self.register_parameter("weight",self.weight)
-        self.variance_epsilon = eps
-    def forward(self, hidden_states:Tensor):
-        variance = hidden_states.pow(2).mean((-1,), keepdim=True)
-        hidden_states = hidden_states * rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
\ No newline at end of file
diff --git a/front/py/deepx/nn/modules/norm/t5layernorm.py b/front/py/deepx/nn/modules/norm/t5layernorm.py
new file mode 100644
index 00000000..085621c4
--- /dev/null
+++ b/front/py/deepx/nn/modules/norm/t5layernorm.py
@@ -0,0 +1,28 @@
+
+from deepx.nn.modules import Module
+from deepx import Tensor,ones,rsqrt
+
+# Paper: https://arxiv.org/abs/1910.07467
+# From https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
+class T5LayerNorm(Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
+        """
+        super().__init__()
+        self.weight = ones((hidden_size,))
+        self.register_parameter("weight", self.weight)
+        self.variance_epsilon = eps
+
+    def forward(self, x:Tensor):
+        xtype = x.dtype
+        # layer norm should always be calculated in float32
+        variance = x.float().pow(2).mean(-1, keepdim=True)
+        x = x*rsqrt(variance + self.variance_epsilon)
+        return (self.weight * x).todtype(xtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+RMSNorm = T5LayerNorm
\ No newline at end of file
diff --git a/front/py/deepx/nn/modules/transformer/llama/modeling_llama.py b/front/py/deepx/nn/modules/transformer/llama/modeling_llama.py
index ef30ba5c..81350051 100644
--- a/front/py/deepx/nn/modules/transformer/llama/modeling_llama.py
+++ b/front/py/deepx/nn/modules/transformer/llama/modeling_llama.py
@@ -1,8 +1,8 @@
 from typing import Optional,Tuple,Union
 from deepx.nn.modules import Module,ModuleList,Linear,Embedding
 from deepx import Tensor,cat,arange
-from front.py.deepx.nn.modules.transformer.modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from deepx.nn.modules.mlp import MLP
+from deepx.nn.modules.transformer.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from deepx.nn.modules.mlp import GatedMLP
 from deepx.nn.modules.norm import RMSNorm
 from deepx.nn.modules.transformer import LlamaRotaryEmbedding,apply_rotary_pos_emb,grouped_query_attention as GQA
 from deepx.utils.config import Config
@@ -75,7 +75,7 @@ def __init__(self, config:dict, layer_idx: int):
 
         self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
 
-        self.mlp = MLP(config)
+        self.mlp = GatedMLP(config)
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 