From f7abc4d1f3066db8c7ada7b1deda9fb08c73016f Mon Sep 17 00:00:00 2001
From: lipeng <734991033@qq.com>
Date: Wed, 2 Jul 2025 23:48:51 +0800
Subject: [PATCH] py: 1.mlp 2.container:ModuleList,Sequential 3.transformer

---
 front/py/deepx/nn/modules/__init__.py         |   8 +-
 front/py/deepx/nn/modules/container.py        | 313 ++++++++++++++++++
 front/py/deepx/nn/modules/mlp/__init__.py     |   5 +
 front/py/deepx/nn/modules/mlp/actfn.py        |   5 +
 front/py/deepx/nn/modules/mlp/gatedmlp.py     |  26 ++
 front/py/deepx/nn/modules/mlp/mlp.py          |  34 +-
 .../deepx/nn/modules/transformer/__init__.py  |   8 +-
 .../transformer/llama/modeling_llama.py       |  80 ++---
 .../transformer/modeling_rope_utils.py        |  12 +-
 .../{embedding.py => rotary_embedding.py}     |  14 +-
 10 files changed, 424 insertions(+), 81 deletions(-)
 create mode 100644 front/py/deepx/nn/modules/container.py
 create mode 100644 front/py/deepx/nn/modules/mlp/__init__.py
 create mode 100644 front/py/deepx/nn/modules/mlp/actfn.py
 create mode 100644 front/py/deepx/nn/modules/mlp/gatedmlp.py
 rename front/py/deepx/nn/modules/transformer/{embedding.py => rotary_embedding.py} (86%)

diff --git a/front/py/deepx/nn/modules/__init__.py b/front/py/deepx/nn/modules/__init__.py
index bf433622..a9fee1b6 100644
--- a/front/py/deepx/nn/modules/__init__.py
+++ b/front/py/deepx/nn/modules/__init__.py
@@ -1,9 +1,11 @@
-from .module import Module, Sequential
+from .module import Module
+from .container import Sequential, ModuleList
 from .linear import Linear
 from .sparse import Embedding
+
 __all__ = [
     "Module",
-    "Linear",
-    "Sequential",
+    "Sequential","ModuleList",
     "Embedding",
+    "Linear",
 ]
diff --git a/front/py/deepx/nn/modules/container.py b/front/py/deepx/nn/modules/container.py
new file mode 100644
index 00000000..627f3b02
--- /dev/null
+++ b/front/py/deepx/nn/modules/container.py
@@ -0,0 +1,313 @@
+# This code is copied directly from PyTorch: https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/container.py
+# It has not been reviewed in detail here; if anything is unclear, consult the PyTorch documentation.
+
+from __future__ import annotations
+
+import operator
+from collections import abc as container_abcs, OrderedDict
+from itertools import chain, islice
+from typing import Any, Optional, overload, TYPE_CHECKING, TypeVar, Union
+from typing_extensions import Self
+
+if TYPE_CHECKING:
+    from collections.abc import Iterable, Iterator, Mapping
+
+from .module import Module
+
+__all__ = [
+    "Sequential",
+    "ModuleList",
+]
+
+T = TypeVar("T", bound=Module)
+_V = TypeVar("_V")
+
+
+def _addindent(s_, numSpaces):
+    s = s_.split("\n")
+    # don't do anything for single-line stuff
+    if len(s) == 1:
+        return s_
+    first = s.pop(0)
+    s = [(numSpaces * " ") + line for line in s]
+    s = "\n".join(s)
+    s = first + "\n" + s
+    return s
+
+class Sequential(Module):
+    # Reference: https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/container.py
+    _modules: dict[str, Module]  # type: ignore[assignment]
+
+    @overload
+    def __init__(self, *args: Module) -> None: ...
+
+    @overload
+    def __init__(self, arg: OrderedDict[str, Module]) -> None: ...
+ + def __init__(self, *args): + super().__init__() + if len(args) == 1 and isinstance(args[0], OrderedDict): + for key, module in args[0].items(): + self.add_module(key, module) + else: + for idx, module in enumerate(args): + self.add_module(str(idx), module) + + def _get_item_by_idx(self, iterator: Iterable[_V], idx: int) -> _V: + """Get the idx-th item of the iterator.""" + size = len(self) + idx = operator.index(idx) + if not -size <= idx < size: + raise IndexError(f"index {idx} is out of range") + idx %= size + return next(islice(iterator, idx, None)) + + def __getitem__(self, idx: Union[slice, int]) -> Union[Sequential, Module]: + if isinstance(idx, slice): + return self.__class__(OrderedDict(list(self._modules.items())[idx])) + else: + return self._get_item_by_idx(self._modules.values(), idx) + + def __setitem__(self, idx: int, module: Module) -> None: + key: str = self._get_item_by_idx(self._modules.keys(), idx) + return setattr(self, key, module) + + def __delitem__(self, idx: Union[slice, int]) -> None: + if isinstance(idx, slice): + for key in list(self._modules.keys())[idx]: + delattr(self, key) + else: + key = self._get_item_by_idx(self._modules.keys(), idx) + delattr(self, key) + # To preserve numbering + str_indices = [str(i) for i in range(len(self._modules))] + self._modules = OrderedDict(list(zip(str_indices, self._modules.values()))) + + + def __len__(self) -> int: + return len(self._modules) + + def __add__(self, other) -> Sequential: + if isinstance(other, Sequential): + ret = Sequential() + for layer in self: + ret.append(layer) + for layer in other: + ret.append(layer) + return ret + else: + raise ValueError( + "add operator supports only objects " + f"of Sequential class, but {str(type(other))} is given." + ) + + def pop(self, key: Union[int, slice]) -> Module: + v = self[key] + del self[key] + return v + + def __iadd__(self, other) -> Self: + if isinstance(other, Sequential): + offset = len(self) + for i, module in enumerate(other): + self.add_module(str(i + offset), module) + return self + else: + raise ValueError( + "add operator supports only objects " + f"of Sequential class, but {str(type(other))} is given." 
+ ) + + def __mul__(self, other: int) -> Sequential: + if not isinstance(other, int): + raise TypeError( + f"unsupported operand type(s) for *: {type(self)} and {type(other)}" + ) + elif other <= 0: + raise ValueError( + f"Non-positive multiplication factor {other} for {type(self)}" + ) + else: + combined = Sequential() + offset = 0 + for _ in range(other): + for module in self: + combined.add_module(str(offset), module) + offset += 1 + return combined + + def __rmul__(self, other: int) -> Sequential: + return self.__mul__(other) + + def __imul__(self, other: int) -> Self: + if not isinstance(other, int): + raise TypeError( + f"unsupported operand type(s) for *: {type(self)} and {type(other)}" + ) + elif other <= 0: + raise ValueError( + f"Non-positive multiplication factor {other} for {type(self)}" + ) + else: + len_original = len(self) + offset = len(self) + for _ in range(other - 1): + for i in range(len_original): + self.add_module(str(i + offset), self._modules[str(i)]) + offset += len_original + return self + + def __dir__(self) -> list[str]: + keys = super().__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + def __iter__(self) -> Iterator[Module]: + return iter(self._modules.values()) + + def forward(self, input): + for module in self: + input = module(input) + return input + + def append(self, module: Module) -> Self: + self.add_module(str(len(self)), module) + return self + + def insert(self, index: int, module: Module) -> Self: + if not isinstance(module, Module): + raise AssertionError(f"module should be of type: {Module}") + n = len(self._modules) + if not (-n <= index <= n): + raise IndexError(f"Index out of range: {index}") + if index < 0: + index += n + for i in range(n, index, -1): + self._modules[str(i)] = self._modules[str(i - 1)] + self._modules[str(index)] = module + return self + + def extend(self, sequential: Iterable[Module]) -> Self: + for layer in sequential: + self.append(layer) + return self + +class ModuleList(Module): + _modules: dict[str, Module] # type: ignore[assignment] + + def __init__(self, modules: Optional[Iterable[Module]] = None) -> None: + super().__init__() + if modules is not None: + self += modules + + def _get_abs_string_index(self, idx): + """Get the absolute index for the list of modules.""" + idx = operator.index(idx) + if not (-len(self) <= idx < len(self)): + raise IndexError(f"index {idx} is out of range") + if idx < 0: + idx += len(self) + return str(idx) + + @overload + def __getitem__(self, idx: slice) -> ModuleList: ... + + @overload + def __getitem__(self, idx: int) -> Module: ... 
+ + def __getitem__(self, idx: Union[int, slice]) -> Union[Module, ModuleList]: + if isinstance(idx, slice): + return self.__class__(list(self._modules.values())[idx]) + else: + return self._modules[self._get_abs_string_index(idx)] + + def __setitem__(self, idx: int, module: Module) -> None: + idx = self._get_abs_string_index(idx) + return setattr(self, str(idx), module) + + def __delitem__(self, idx: Union[int, slice]) -> None: + if isinstance(idx, slice): + for k in range(len(self._modules))[idx]: + delattr(self, str(k)) + else: + delattr(self, self._get_abs_string_index(idx)) + # To preserve numbering, self._modules is being reconstructed with modules after deletion + str_indices = [str(i) for i in range(len(self._modules))] + self._modules = OrderedDict(list(zip(str_indices, self._modules.values()))) + + def __len__(self) -> int: + return len(self._modules) + + def __iter__(self) -> Iterator[Module]: + return iter(self._modules.values()) + + def __iadd__(self, modules: Iterable[Module]) -> Self: + return self.extend(modules) + + def __add__(self, other: Iterable[Module]) -> ModuleList: + combined = ModuleList() + for i, module in enumerate(chain(self, other)): + combined.add_module(str(i), module) + return combined + + def __repr__(self) -> str: + """Return a custom repr for ModuleList that compresses repeated module representations.""" + list_of_reprs = [repr(item) for item in self] + if len(list_of_reprs) == 0: + return self._get_name() + "()" + + start_end_indices = [[0, 0]] + repeated_blocks = [list_of_reprs[0]] + for i, r in enumerate(list_of_reprs[1:], 1): + if r == repeated_blocks[-1]: + start_end_indices[-1][1] += 1 + continue + + start_end_indices.append([i, i]) + repeated_blocks.append(r) + + lines = [] + main_str = self._get_name() + "(" + for (start_id, end_id), b in zip(start_end_indices, repeated_blocks): + local_repr = f"({start_id}): {b}" # default repr + + if start_id != end_id: + n = end_id - start_id + 1 + local_repr = f"({start_id}-{end_id}): {n} x {b}" + + local_repr = _addindent(local_repr, 2) + lines.append(local_repr) + + main_str += "\n " + "\n ".join(lines) + "\n" + main_str += ")" + return main_str + + def __dir__(self) -> list[str]: + keys = super().__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + def insert(self, index: int, module: Module) -> None: + for i in range(len(self._modules), index, -1): + self._modules[str(i)] = self._modules[str(i - 1)] + self._modules[str(index)] = module + + def append(self, module: Module) -> Self: + self.add_module(str(len(self)), module) + return self + + def pop(self, key: Union[int, slice]) -> Module: + v = self[key] + del self[key] + return v + + def extend(self, modules: Iterable[Module]) -> Self: + if not isinstance(modules, container_abcs.Iterable): + raise TypeError( + "ModuleList.extend should be called with an " + "iterable, but got " + type(modules).__name__ + ) + offset = len(self) + for i, module in enumerate(modules): + self.add_module(str(offset + i), module) + return self \ No newline at end of file diff --git a/front/py/deepx/nn/modules/mlp/__init__.py b/front/py/deepx/nn/modules/mlp/__init__.py new file mode 100644 index 00000000..5aa435af --- /dev/null +++ b/front/py/deepx/nn/modules/mlp/__init__.py @@ -0,0 +1,5 @@ +from .gatedmlp import * + +__all__ = [ + "GatedMLP", +] \ No newline at end of file diff --git a/front/py/deepx/nn/modules/mlp/actfn.py b/front/py/deepx/nn/modules/mlp/actfn.py new file mode 100644 index 00000000..ab9aae81 --- /dev/null +++ 
b/front/py/deepx/nn/modules/mlp/actfn.py
@@ -0,0 +1,5 @@
+from deepx.nn.functional import swish as swish_fn
+
+ACT2FN={
+    "silu":swish_fn,
+}
diff --git a/front/py/deepx/nn/modules/mlp/gatedmlp.py b/front/py/deepx/nn/modules/mlp/gatedmlp.py
new file mode 100644
index 00000000..5ef1fc5c
--- /dev/null
+++ b/front/py/deepx/nn/modules/mlp/gatedmlp.py
@@ -0,0 +1,26 @@
+from deepx.nn.modules import Module,Linear
+from .actfn import ACT2FN
+
+class GatedMLP(Module):
+    def __init__(self, config:dict):
+        super().__init__()
+        # input (hidden) size
+        self.hidden_size = config.hidden_size
+        # intermediate size
+        self.intermediate_size = config.intermediate_size
+        # gate projection
+        self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+        # up projection
+        self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+        # down projection
+        self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+        # activation function
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        gate = self.gate_proj(x)
+        up = self.up_proj(x)
+        act = self.act_fn(gate)
+        out = act * up
+        out = self.down_proj(out)
+        return out
\ No newline at end of file
diff --git a/front/py/deepx/nn/modules/mlp/mlp.py b/front/py/deepx/nn/modules/mlp/mlp.py
index eefb5004..439dedfc 100644
--- a/front/py/deepx/nn/modules/mlp/mlp.py
+++ b/front/py/deepx/nn/modules/mlp/mlp.py
@@ -1,26 +1,22 @@
-from deepx.nn.functional import swish as swish_fn
-from deepx.nn.modules import Module,Linear
+from deepx.nn.modules import Module, Linear
+from .actfn import ACT2FN
 
-ACT2FN={
-    "silu":swish_fn,
-}
-
-class MLP(Module):
-    def __init__(self, config:dict):
+class StandardMLP(Module):
+    def __init__(self, config: dict):
         super().__init__()
         # 输入层大小
-        self.hidden_size = config.hidden_size
+        self.hidden_size = config.hidden_size
         # 中间层大小
-        self.intermediate_size = config["intermediate_size"]
-        #门控投影层
-        self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
-        #上投影层
-        self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
-        #下投影层
-        self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
-        #激活函数
+        self.intermediate_size = config.intermediate_size
+        # first linear layer
+        self.fc1 = Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+        # second linear layer
+        self.fc2 = Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+        # activation function
         self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, x):
-        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-        return down_proj
\ No newline at end of file
+        x = self.fc1(x)
+        x = self.act_fn(x)
+        x = self.fc2(x)
+        return x
\ No newline at end of file
diff --git a/front/py/deepx/nn/modules/transformer/__init__.py b/front/py/deepx/nn/modules/transformer/__init__.py
index bfab8802..f51d5e5d 100644
--- a/front/py/deepx/nn/modules/transformer/__init__.py
+++ b/front/py/deepx/nn/modules/transformer/__init__.py
@@ -1,8 +1,10 @@
-from .embedding import *
+from .rotary_embedding import *
 from .attention import *
+from .grouped_query_attention import *
 
 __all__ = [
-    "scaled_dot_product_attention",
-    "LlamaRotaryEmbedding",
+    "scaled_dot_product_attention",#attention.py
+    "grouped_query_attention","repeat_kv",#grouped_query_attention.py
+    "apply_rotary_pos_emb","LlamaRotaryEmbedding",#rotary_embedding.py
     "rotate_half"
 ]
\ No newline at end of file
diff --git a/front/py/deepx/nn/modules/transformer/llama/modeling_llama.py b/front/py/deepx/nn/modules/transformer/llama/modeling_llama.py
index c8aace79..ef30ba5c 100644
--- a/front/py/deepx/nn/modules/transformer/llama/modeling_llama.py
+++ b/front/py/deepx/nn/modules/transformer/llama/modeling_llama.py
@@ -1,22 +1,12 @@
-from typing import Optional,Tuple
-from deepx.nn.modules import Module,Linear,Embedding
-from deepx import Tensor,cat
+from typing import Optional,Tuple,Union
+from deepx.nn.modules import Module,ModuleList,Linear,Embedding
+from deepx import Tensor,cat,arange
 from front.py.deepx.nn.modules.transformer.modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from deepx.nn.modules.mlp import LlamaMLP
-from deepx.nn.modules.norm import LlamaRMSNorm
-from deepx.nn.modules.transformer import LlamaRotaryEmbedding
+from deepx.nn.modules.mlp import GatedMLP
+from deepx.nn.modules.norm import RMSNorm
+from deepx.nn.modules.transformer import LlamaRotaryEmbedding,apply_rotary_pos_emb,grouped_query_attention as GQA
+from deepx.utils.config import Config
 
-def rotate_half(x:Tensor):
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return cat((-x2, x1,), dim=-1)
-
-def apply_rotary_pos_emb(q:Tensor, k:Tensor, cos:Tensor, sin:Tensor, unsqueeze_dim:int=1):
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
 
 
@@ -85,9 +75,9 @@ def __init__(self, config:dict, layer_idx: int):
 
         self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
 
-        self.mlp = LlamaMLP(config)
-        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.mlp = GatedMLP(config)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -142,7 +132,7 @@ def __init__(self, config:dict):
         self.layers = ModuleList(
             [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
-        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.rotary_emb = LlamaRotaryEmbedding(config=config)
         self.gradient_checkpointing = False
 
@@ -184,13 +174,10 @@ def forward(
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
-
-        if use_cache and past_key_values is None:
-            past_key_values = DynamicCache()
-
+        
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position = torch.arange(
+            cache_position = arange(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )
 
@@ -244,17 +231,13 @@ def forward(
 
     def _update_causal_mask(
         self,
-        attention_mask: torch.Tensor,
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
+        attention_mask: Tensor,
+        input_tensor: Tensor,
+        cache_position: Tensor,
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and (attention_mask == 0.0).any():
-                return attention_mask
-            return None
-
+        
         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
         # to infer the attention mask.
@@ -278,7 +261,7 @@ def _update_causal_mask( else: target_length = ( attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) + if isinstance(attention_mask, Tensor) else past_seen_tokens + sequence_length + 1 ) @@ -302,19 +285,18 @@ def _update_causal_mask( # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min + min_dtype = finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @staticmethod def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, + attention_mask: Tensor, sequence_length: int, target_length: int, - dtype: torch.dtype, - device: torch.device, - cache_position: torch.Tensor, + dtype: dtype, + cache_position: Tensor, batch_size: int, **kwargs, ): @@ -323,7 +305,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. Args: - attention_mask (`torch.Tensor`): + attention_mask (` Tensor`): A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): @@ -331,26 +313,26 @@ def _prepare_4d_causal_attention_mask_with_cache_position( target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): + dtype (` dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): + device (` device`): The device to plcae the 4D attention mask on. - cache_position (`torch.Tensor`): + cache_position (` Tensor`): Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): + batch_size (` Tensor`): Batch size. """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
causal_mask = attention_mask else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( + min_dtype = finfo(dtype).min + causal_mask = full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = triu(causal_mask, diagonal=1) + causal_mask *= arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/front/py/deepx/nn/modules/transformer/modeling_rope_utils.py b/front/py/deepx/nn/modules/transformer/modeling_rope_utils.py index af693b91..1c9185ea 100644 --- a/front/py/deepx/nn/modules/transformer/modeling_rope_utils.py +++ b/front/py/deepx/nn/modules/transformer/modeling_rope_utils.py @@ -5,8 +5,8 @@ def _compute_default_rope_parameters(config:Config=None,seq_len: Optional[int] = None, **rope_kwargs) -> Tuple[Tensor, float]: if len(rope_kwargs) > 0: - base = rope_kwargs["base"] - dim = rope_kwargs["dim"] + base = rope_kwargs.base + dim = rope_kwargs.dim elif config is not None: base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 @@ -23,10 +23,10 @@ def _compute_llama3_parameters(config:Config,seq_len: Optional[int] = None,**rop # Gets the default RoPE parameters inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len, **rope_kwargs) - factor = config.rope_scaling["factor"] # `8` in the original implementation - low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation - high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation - old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation + factor = config.rope_scaling.factor # `8` in the original implementation + low_freq_factor = config.rope_scaling.low_freq_factor # `1` in the original implementation + high_freq_factor = config.rope_scaling.high_freq_factor # `4` in the original implementation + old_context_len = config.rope_scaling.original_max_position_embeddings # `8192` in the original implementation low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor diff --git a/front/py/deepx/nn/modules/transformer/embedding.py b/front/py/deepx/nn/modules/transformer/rotary_embedding.py similarity index 86% rename from front/py/deepx/nn/modules/transformer/embedding.py rename to front/py/deepx/nn/modules/transformer/rotary_embedding.py index 3f3d3db8..275b557b 100644 --- a/front/py/deepx/nn/modules/transformer/embedding.py +++ b/front/py/deepx/nn/modules/transformer/rotary_embedding.py @@ -1,5 +1,5 @@ from deepx.nn.modules import Module -from deepx import cat +from deepx import cat,Tensor from .modeling_rope_utils import ROPE_INIT_FUNCTIONS from deepx.utils import Config @@ -58,3 +58,15 @@ def forward(self, x, position_ids): sin = sin * self.attention_scaling return cos.todtype(x.dtype), sin.todtype(x.dtype) + +def rotate_half(x:Tensor): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return cat((-x2, x1,), dim=-1) + +def apply_rotary_pos_emb(q:Tensor, k:Tensor, cos:Tensor, sin:Tensor, unsqueeze_dim:int=1): + cos = 
cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed \ No newline at end of file
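
Usage sketch for the new container classes (not part of the diff): a minimal example of how Sequential and ModuleList are expected to behave, assuming the deepx front-end exposes Linear(in_features, out_features, bias=...) as this patch uses it elsewhere; the layer sizes here are made up for illustration.

    from collections import OrderedDict

    from deepx.nn.modules import Linear, ModuleList, Sequential

    # Sequential registers its children under "0", "1", ... and calls them in order in forward().
    ffn = Sequential(
        Linear(128, 256, bias=True),
        Linear(256, 128, bias=True),
    )

    # Passing an OrderedDict gives the children explicit names instead of numeric keys.
    named = Sequential(OrderedDict([
        ("up", Linear(128, 256, bias=True)),
        ("down", Linear(256, 128, bias=True)),
    ]))

    # ModuleList only registers children; iteration and call order are left to the caller,
    # which is exactly how LlamaModel uses it for its list of LlamaDecoderLayer blocks.
    blocks = ModuleList([Linear(128, 128, bias=False) for _ in range(4)])
    blocks.append(Linear(128, 128, bias=False))

    print(len(ffn), len(blocks))  # 2 5
    print(repr(blocks))           # consecutive identical children are collapsed as "(0-4): 5 x ..."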
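
Construction sketch for the two MLP variants. The config object below is a stand-in (SimpleNamespace): although the type hints say dict, both classes read attributes, and the field names follow this patch. The StandardMLP import path assumes it stays in mlp/mlp.py, since the package __init__ only re-exports GatedMLP.

    from types import SimpleNamespace

    from deepx.nn.modules.mlp import GatedMLP
    from deepx.nn.modules.mlp.mlp import StandardMLP

    # Minimal config: only the attributes the MLP classes actually read.
    cfg = SimpleNamespace(
        hidden_size=512,
        intermediate_size=2048,
        mlp_bias=False,
        hidden_act="silu",  # must be a key of ACT2FN; only "silu" (swish) is registered
    )

    # LLaMA-style gated MLP: down_proj(act(gate_proj(x)) * up_proj(x)).
    gated = GatedMLP(cfg)

    # Plain two-layer feed-forward block: fc2(act(fc1(x))).
    plain = StandardMLP(cfg)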
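
For reference, the rotation that rotate_half/apply_rotary_pos_emb implement, written against NumPy so it can be checked independently of the deepx tensor API. It assumes cos and sin are broadcastable to q and k and repeat each frequency across both halves of the head dimension, which is how LlamaRotaryEmbedding produces them.

    import numpy as np

    def rotate_half_np(x):
        # Swap the two halves of the last dimension and negate the (new) first half.
        half = x.shape[-1] // 2
        return np.concatenate((-x[..., half:], x[..., :half]), axis=-1)

    def apply_rotary_pos_emb_np(q, k, cos, sin):
        # Element-wise 2D rotation of each (x_i, x_{i+d/2}) pair: x*cos + rotate_half(x)*sin.
        return q * cos + rotate_half_np(q) * sin, k * cos + rotate_half_np(k) * sin

    rng = np.random.default_rng(0)
    q = rng.standard_normal((1, 1, 4))   # (batch, seq, head_dim)
    k = rng.standard_normal((1, 1, 4))
    theta = 0.3
    cos = np.full(4, np.cos(theta))      # same frequency repeated across both halves
    sin = np.full(4, np.sin(theta))

    q_rot, k_rot = apply_rotary_pos_emb_np(q, k, cos, sin)
    # A pure rotation preserves vector norms.
    assert np.isclose(np.linalg.norm(q_rot), np.linalg.norm(q))
    assert np.isclose(np.linalg.norm(k_rot), np.linalg.norm(k))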
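
The RoPE helpers in modeling_rope_utils.py compute per-dimension inverse frequencies and, for Llama 3, rescale the low-frequency band. The NumPy sketch below mirrors that standard scaling rule with the default constants quoted in the patch's comments (factor 8, low/high frequency factors 1 and 4, original context length 8192); base and dim are illustrative values, not read from any config.

    import numpy as np

    def default_inv_freq(base, dim):
        # inv_freq[j] = base**(-2j/dim), one frequency per pair of head dimensions.
        return 1.0 / (base ** (np.arange(0, dim, 2, dtype=np.float64) / dim))

    def llama3_scale_inv_freq(inv_freq, factor=8.0, low_freq_factor=1.0,
                              high_freq_factor=4.0, old_context_len=8192):
        wavelen = 2 * np.pi / inv_freq
        low_freq_wavelen = old_context_len / low_freq_factor    # wavelengths above this are scaled down
        high_freq_wavelen = old_context_len / high_freq_factor  # wavelengths below this are kept as-is
        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        scaled = np.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
        medium = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
        # Smoothly interpolate between the scaled and unscaled frequency inside the medium band.
        return np.where(medium, (1 - smooth) * inv_freq / factor + smooth * inv_freq, scaled)

    inv_freq = default_inv_freq(base=500000.0, dim=128)
    print(llama3_scale_inv_freq(inv_freq)[:4])  # highest-frequency components stay unchanged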
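
The core of _prepare_4d_causal_attention_mask_with_cache_position, before the 4D batch expansion and the merge with any padding mask, reduces to the following NumPy reference: a key position is masked (set to the dtype minimum) exactly when its absolute index lies beyond the query token's cache position.

    import numpy as np

    def causal_mask_np(sequence_length, target_length, cache_position, min_value=-1e9):
        # Start fully masked, keep the mask only strictly above the diagonal of the new block...
        mask = np.full((sequence_length, target_length), min_value, dtype=np.float32)
        if sequence_length != 1:
            mask = np.triu(mask, k=1)
        # ...and only for key positions that come after each query's absolute cache position.
        mask = mask * (np.arange(target_length) > cache_position.reshape(-1, 1))
        return mask

    # Decoding 3 new tokens with 5 tokens already in the cache (8 key positions in total):
    # row i may attend to absolute positions 0..5+i and is masked beyond that.
    print(causal_mask_np(3, 8, np.arange(5, 8)))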