6 changes: 6 additions & 0 deletions front/py/deepx/nn/modules/norm/__init__.py
@@ -0,0 +1,6 @@
from .t5layernorm import *

__all__ = [
"T5LayerNorm","RMSNorm",
]

22 changes: 0 additions & 22 deletions front/py/deepx/nn/modules/norm/normalization.py

This file was deleted.

28 changes: 28 additions & 0 deletions front/py/deepx/nn/modules/norm/t5layernorm.py
@@ -0,0 +1,28 @@

from deepx.nn.modules import Module
from deepx import Tensor, ones, rsqrt

# Paper: https://arxiv.org/abs/1910.07467
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
class T5LayerNorm(Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the T5 style: no bias and no subtraction of the mean.
        """
        super().__init__()
        self.weight = ones((hidden_size,))
        self.register_parameter("weight", self.weight)
        self.variance_epsilon = eps

    def forward(self, x: Tensor):
        xtype = x.dtype
        # compute the variance in float32 for numerical stability
        variance = x.float().pow(2).mean(-1, keepdim=True)
        x = x * rsqrt(variance + self.variance_epsilon)
        return (self.weight * x).todtype(xtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


RMSNorm = T5LayerNorm
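
For reference (not part of this diff): RMS layer norm rescales each hidden vector by the reciprocal of its root mean square, y = weight * x / sqrt(mean(x^2) + eps), with no mean subtraction and no bias term. A minimal usage sketch follows; it assumes that deepx Modules can be called directly (dispatching to forward, PyTorch-style) and uses only the ones constructor already imported above. Both of these are assumptions about the deepx API, not guarantees.

# Minimal usage sketch (assumption: Module.__call__ dispatches to forward).
from deepx import ones
from deepx.nn.modules.norm import RMSNorm

norm = RMSNorm(hidden_size=4096, eps=1e-6)   # weight initialised to ones((4096,))
x = ones((2, 16, 4096))                      # (batch, seq_len, hidden_size)
y = norm(x)                                  # same shape; each vector scaled by 1/rms(x)
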
6 changes: 3 additions & 3 deletions front/py/deepx/nn/modules/transformer/llama/modeling_llama.py
@@ -1,8 +1,8 @@
 from typing import Optional,Tuple,Union
 from deepx.nn.modules import Module,ModuleList,Linear,Embedding
 from deepx import Tensor,cat,arange
-from front.py.deepx.nn.modules.transformer.modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from deepx.nn.modules.mlp import MLP
+from deepx.nn.modules.transformer.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from deepx.nn.modules.mlp import GatedMLP
 from deepx.nn.modules.norm import RMSNorm
 from deepx.nn.modules.transformer import LlamaRotaryEmbedding,apply_rotary_pos_emb,grouped_query_attention as GQA
 from deepx.utils.config import Config
@@ -75,7 +75,7 @@ def __init__(self, config:dict, layer_idx: int):
 
         self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
 
-        self.mlp = MLP(config)
+        self.mlp = GatedMLP(config)
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
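
For context (not part of this diff): the two RMSNorm instances follow the usual Llama pre-norm layout, one normalizing the hidden states before self-attention and the other before the gated MLP, each wrapped in a residual connection. The sketch below shows that pattern under the assumption that LlamaDecoderLayer.forward composes the modules configured above in the standard way; the attribute names come from the diff, while the call convention and return shapes are assumptions.

# Sketch of the standard pre-norm residual pattern (assumed, not copied from this repo).
def decoder_layer_forward(layer, hidden_states, **attn_kwargs):
    residual = hidden_states
    hidden_states = layer.input_layernorm(hidden_states)            # RMSNorm before attention
    hidden_states = layer.self_attn(hidden_states, **attn_kwargs)   # real self_attn may also return attention weights
    hidden_states = residual + hidden_states                        # first residual connection

    residual = hidden_states
    hidden_states = layer.post_attention_layernorm(hidden_states)   # RMSNorm before the MLP
    hidden_states = layer.mlp(hidden_states)                        # the GatedMLP configured above
    return residual + hidden_states                                 # second residual connection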