Skip to content
Merged

py: #82

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions front/py/deepx/nn/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Public surface of the ``deepx.nn.modules`` package.
# The container classes live in ``.container``; ``.module`` only provides the
# ``Module`` base class.
from .module import Module
from .container import Sequential, ModuleList
from .linear import Linear
from .sparse import Embedding

__all__ = [
    "Module",
    "Sequential",
    "ModuleList",
    "Embedding",
    "Linear",
]
313 changes: 313 additions & 0 deletions front/py/deepx/nn/modules/container.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
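# This code was copied directly from PyTorch: https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/container.py
# The author has not studied the implementation in detail; if something looks wrong, ask an AI assistant or consult the PyTorch documentation.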

from __future__ import annotations

import operator
from collections import abc as container_abcs, OrderedDict
from itertools import chain, islice
from typing import Any, Optional, overload, TYPE_CHECKING, TypeVar, Union
from typing_extensions import Self

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping

from .module import Module

# Names exported by ``from .container import *``.
__all__ = [
"Sequential",
"ModuleList",
]

# Type variable restricted to ``Module`` subclasses.
T = TypeVar("T", bound=Module)
# Unconstrained type variable; used as the element type of the iterable in
# ``Sequential._get_item_by_idx``.
_V = TypeVar("_V")


def _addindent(s_, numSpaces):
s = s_.split("\n")
# don't do anything for single-line stuff
if len(s) == 1:
return s_
first = s.pop(0)
s = [(numSpaces * " ") + line for line in s]
s = "\n".join(s)
s = first + "\n" + s
return s

class Sequential(Module):
    """Ordered container that passes its input through each child in turn.

    Adapted from ``torch.nn.Sequential``; see
    https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/container.py

    Children are kept in ``self._modules``. Positional construction names
    them ``"0"``, ``"1"``, ...; construction from an ``OrderedDict`` keeps the
    caller's keys.
    """

    _modules: dict[str, Module]  # type: ignore[assignment]

    @overload
    def __init__(self, *args: Module) -> None: ...

    @overload
    def __init__(self, arg: OrderedDict[str, Module]) -> None: ...

    def __init__(self, *args):
        """Build the container from positional modules or one ``OrderedDict``."""
        super().__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def _get_item_by_idx(self, iterator: Iterable[_V], idx: int) -> _V:
        """Return the ``idx``-th item of ``iterator`` (negative ``idx`` allowed).

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        idx = operator.index(idx)
        if not -size <= idx < size:
            raise IndexError(f"index {idx} is out of range")
        idx %= size
        return next(islice(iterator, idx, None))

    def __getitem__(self, idx: Union[slice, int]) -> Union[Sequential, Module]:
        """Return one child (int) or a new container of the same class (slice)."""
        if isinstance(idx, slice):
            return self.__class__(OrderedDict(list(self._modules.items())[idx]))
        return self._get_item_by_idx(self._modules.values(), idx)

    def __setitem__(self, idx: int, module: Module) -> None:
        """Replace the child at position ``idx``, keeping its existing key."""
        key: str = self._get_item_by_idx(self._modules.keys(), idx)
        return setattr(self, key, module)

    def __delitem__(self, idx: Union[slice, int]) -> None:
        """Delete the child(ren) at ``idx`` and renumber the remaining ones."""
        if isinstance(idx, slice):
            for key in list(self._modules.keys())[idx]:
                delattr(self, key)
        else:
            key = self._get_item_by_idx(self._modules.keys(), idx)
            delattr(self, key)
        # To preserve numbering, rebuild the mapping with keys "0".."n-1".
        # NOTE: this also replaces any custom (named) keys by positions.
        str_indices = [str(i) for i in range(len(self._modules))]
        self._modules = OrderedDict(zip(str_indices, self._modules.values()))

    def __len__(self) -> int:
        """Return the number of children."""
        return len(self._modules)

    def __add__(self, other) -> Sequential:
        """Return a new ``Sequential`` holding the children of both operands.

        Raises:
            ValueError: if ``other`` is not a ``Sequential``.
        """
        if not isinstance(other, Sequential):
            raise ValueError(
                "add operator supports only objects "
                f"of Sequential class, but {str(type(other))} is given."
            )
        ret = Sequential()
        for layer in self:
            ret.append(layer)
        for layer in other:
            ret.append(layer)
        return ret

    def pop(self, key: Union[int, slice]) -> Module:
        """Remove and return the child (or sub-container for a slice) at ``key``."""
        v = self[key]
        del self[key]
        return v

    def __iadd__(self, other) -> Self:
        """Append the children of ``other`` to this container in place.

        Raises:
            ValueError: if ``other`` is not a ``Sequential``.
        """
        if not isinstance(other, Sequential):
            raise ValueError(
                "add operator supports only objects "
                f"of Sequential class, but {str(type(other))} is given."
            )
        # Snapshot first: with ``seq += seq`` the source is our own live
        # ``_modules`` view, and adding children while iterating it would
        # fail part-way through and leave a half-extended container.
        incoming = list(other)
        for i, module in enumerate(incoming, start=len(self)):
            self.add_module(str(i), module)
        return self

    def __mul__(self, other: int) -> Sequential:
        """Return a new ``Sequential`` repeating the children ``other`` times.

        The same module objects are reused for every repetition.

        Raises:
            TypeError: if ``other`` is not an ``int``.
            ValueError: if ``other`` is not positive.
        """
        if not isinstance(other, int):
            raise TypeError(
                f"unsupported operand type(s) for *: {type(self)} and {type(other)}"
            )
        if other <= 0:
            raise ValueError(
                f"Non-positive multiplication factor {other} for {type(self)}"
            )
        combined = Sequential()
        offset = 0
        for _ in range(other):
            for module in self:
                combined.add_module(str(offset), module)
                offset += 1
        return combined

    def __rmul__(self, other: int) -> Sequential:
        """Support ``n * seq`` by delegating to ``__mul__``."""
        return self.__mul__(other)

    def __imul__(self, other: int) -> Self:
        """Repeat the current children in place until there are ``other`` copies.

        Raises:
            TypeError: if ``other`` is not an ``int``.
            ValueError: if ``other`` is not positive.
        """
        if not isinstance(other, int):
            raise TypeError(
                f"unsupported operand type(s) for *: {type(self)} and {type(other)}"
            )
        if other <= 0:
            raise ValueError(
                f"Non-positive multiplication factor {other} for {type(self)}"
            )
        # Replicate a positional snapshot instead of looking children up by
        # the keys "0".."n-1": those keys do not exist when the container was
        # built from an OrderedDict with custom names.
        originals = list(self._modules.values())
        offset = len(originals)
        for _ in range(other - 1):
            for module in originals:
                self.add_module(str(offset), module)
                offset += 1
        return self

    def __dir__(self) -> list[str]:
        """Return the attribute names, hiding the numeric child keys."""
        keys = super().__dir__()
        return [key for key in keys if not key.isdigit()]

    def __iter__(self) -> Iterator[Module]:
        """Iterate over the children in order."""
        return iter(self._modules.values())

    def forward(self, input):
        """Apply every child to the running value and return the final result."""
        for module in self:
            input = module(input)
        return input

    def append(self, module: Module) -> Self:
        """Add ``module`` at the end under the key ``str(len(self))``."""
        self.add_module(str(len(self)), module)
        return self

    def insert(self, index: int, module: Module) -> Self:
        """Insert ``module`` before position ``index``.

        Assumes the children are keyed ``"0".."n-1"`` (positional layout).

        Raises:
            AssertionError: if ``module`` is not a ``Module``.
            IndexError: if ``index`` is outside ``[-len(self), len(self)]``.
        """
        if not isinstance(module, Module):
            raise AssertionError(f"module should be of type: {Module}")
        n = len(self._modules)
        if not (-n <= index <= n):
            raise IndexError(f"Index out of range: {index}")
        if index < 0:
            index += n
        # Shift the tail one slot to the right, starting from the end so no
        # entry is overwritten before it has been moved.
        for i in range(n, index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self._modules[str(index)] = module
        return self

    def extend(self, sequential: Iterable[Module]) -> Self:
        """Append every module produced by ``sequential``."""
        # Snapshot so that ``seq.extend(seq)`` does not mutate the mapping
        # it is currently iterating.
        for layer in list(sequential):
            self.append(layer)
        return self

class ModuleList(Module):
    """List-like container of child modules addressed by position.

    Adapted from ``torch.nn.ModuleList``. Children are stored in
    ``self._modules`` under the string keys ``"0"``, ``"1"``, ...
    """

    _modules: dict[str, Module]  # type: ignore[assignment]

    def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
        """Create the list, optionally filling it from ``modules``."""
        super().__init__()
        if modules is not None:
            self += modules

    def _get_abs_string_index(self, idx):
        """Return the non-negative string key for ``idx``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError(f"index {idx} is out of range")
        if idx < 0:
            idx += len(self)
        return str(idx)

    @overload
    def __getitem__(self, idx: slice) -> ModuleList: ...

    @overload
    def __getitem__(self, idx: int) -> Module: ...

    def __getitem__(self, idx: Union[int, slice]) -> Union[Module, ModuleList]:
        """Return one child (int) or a new list of the same class (slice)."""
        if isinstance(idx, slice):
            return self.__class__(list(self._modules.values())[idx])
        return self._modules[self._get_abs_string_index(idx)]

    def __setitem__(self, idx: int, module: Module) -> None:
        """Replace the child at position ``idx``."""
        idx = self._get_abs_string_index(idx)
        return setattr(self, str(idx), module)

    def __delitem__(self, idx: Union[int, slice]) -> None:
        """Delete the child(ren) at ``idx`` and renumber the remaining ones."""
        if isinstance(idx, slice):
            for k in range(len(self._modules))[idx]:
                delattr(self, str(k))
        else:
            delattr(self, self._get_abs_string_index(idx))
        # To preserve numbering, self._modules is rebuilt from the modules
        # that are left after deletion.
        str_indices = [str(i) for i in range(len(self._modules))]
        self._modules = OrderedDict(zip(str_indices, self._modules.values()))

    def __len__(self) -> int:
        """Return the number of children."""
        return len(self._modules)

    def __iter__(self) -> Iterator[Module]:
        """Iterate over the children in order."""
        return iter(self._modules.values())

    def __iadd__(self, modules: Iterable[Module]) -> Self:
        """Extend the list in place with ``modules``."""
        return self.extend(modules)

    def __add__(self, other: Iterable[Module]) -> ModuleList:
        """Return a new ``ModuleList`` with this list's children then ``other``'s."""
        combined = ModuleList()
        for i, module in enumerate(chain(self, other)):
            combined.add_module(str(i), module)
        return combined

    def __repr__(self) -> str:
        """Return a repr that collapses runs of identical child reprs."""
        list_of_reprs = [repr(item) for item in self]
        if len(list_of_reprs) == 0:
            return self._get_name() + "()"

        # Group consecutive equal reprs; each group keeps its [start, end]
        # index pair so that it can be printed as "(start-end): n x repr".
        start_end_indices = [[0, 0]]
        repeated_blocks = [list_of_reprs[0]]
        for i, r in enumerate(list_of_reprs[1:], 1):
            if r == repeated_blocks[-1]:
                start_end_indices[-1][1] += 1
                continue

            start_end_indices.append([i, i])
            repeated_blocks.append(r)

        lines = []
        main_str = self._get_name() + "("
        for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
            local_repr = f"({start_id}): {b}"  # default repr

            if start_id != end_id:
                n = end_id - start_id + 1
                local_repr = f"({start_id}-{end_id}): {n} x {b}"

            local_repr = _addindent(local_repr, 2)
            lines.append(local_repr)

        main_str += "\n  " + "\n  ".join(lines) + "\n"
        main_str += ")"
        return main_str

    def __dir__(self) -> list[str]:
        """Return the attribute names, hiding the numeric child keys."""
        keys = super().__dir__()
        return [key for key in keys if not key.isdigit()]

    def insert(self, index: int, module: Module) -> None:
        """Insert ``module`` before position ``index``.

        Follows ``list.insert`` semantics: a negative ``index`` counts from
        the end, and an out-of-range ``index`` is clamped to the nearest end.
        """
        size = len(self._modules)
        index = operator.index(index)
        # Normalise before touching the mapping. A raw negative index would
        # look up a key such as "-1" after entries were already shifted, and
        # an index beyond the end would leave a gap in the "0".."n" keys.
        if index < 0:
            index = max(index + size, 0)
        else:
            index = min(index, size)
        # Shift the tail one slot to the right, from the end backwards.
        for i in range(size, index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self._modules[str(index)] = module

    def append(self, module: Module) -> Self:
        """Add ``module`` at the end of the list."""
        self.add_module(str(len(self)), module)
        return self

    def pop(self, key: Union[int, slice]) -> Module:
        """Remove and return the child (or sub-list for a slice) at ``key``."""
        v = self[key]
        del self[key]
        return v

    def extend(self, modules: Iterable[Module]) -> Self:
        """Append every module produced by ``modules``.

        Raises:
            TypeError: if ``modules`` is not iterable.
        """
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError(
                "ModuleList.extend should be called with an "
                "iterable, but got " + type(modules).__name__
            )
        # Snapshot so that ``ml.extend(ml)`` / ``ml += ml`` does not add
        # children while iterating this list's own live mapping.
        incoming = list(modules)
        offset = len(self)
        for i, module in enumerate(incoming):
            self.add_module(str(offset + i), module)
        return self
5 changes: 5 additions & 0 deletions front/py/deepx/nn/modules/mlp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Import the public class explicitly; a star import would also pull the
# names that ``gatedmlp`` itself imports into this package's namespace.
from .gatedmlp import GatedMLP

__all__ = [
    "GatedMLP",
]
5 changes: 5 additions & 0 deletions front/py/deepx/nn/modules/mlp/actfn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from deepx.nn.functional import swish as swish_fn

# Registry of activation functions, looked up by their configuration name.
ACT2FN = {
    "silu": swish_fn,
}
26 changes: 26 additions & 0 deletions front/py/deepx/nn/modules/mlp/gatedmlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from deepx.nn.modules import Module,Linear
from .actfn import ACT2FN

class GatedMLP(Module):
    """Gated MLP block: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``.

    Required configuration entries: ``hidden_size``, ``intermediate_size``,
    ``mlp_bias`` and ``hidden_act`` (a key of ``ACT2FN``).
    """

    def __init__(self, config: object):
        """Build the three projections and pick the activation function.

        Args:
            config: either a ``dict`` holding the entries listed on the class,
                or any object exposing them as attributes.

        Raises:
            KeyError: if a dict ``config`` lacks an entry, or ``hidden_act``
                is not a key of ``ACT2FN``.
            AttributeError: if an object ``config`` lacks an attribute.
        """
        super().__init__()

        def read(name):
            # Accept plain dicts as well as attribute-style config objects.
            if isinstance(config, dict):
                return config[name]
            return getattr(config, name)

        use_bias = read("mlp_bias")
        # Input (and output) feature size.
        self.hidden_size = read("hidden_size")
        # Width of the intermediate projection.
        self.intermediate_size = read("intermediate_size")
        # Gate projection.
        self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        # Up projection.
        self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        # Down projection.
        self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias=use_bias)
        # Activation applied to the gate branch.
        hidden_act = read("hidden_act")
        try:
            self.act_fn = ACT2FN[hidden_act]
        except KeyError:
            raise KeyError(
                f"unsupported hidden_act {hidden_act!r}; "
                f"expected one of {sorted(ACT2FN)}"
            ) from None

    def forward(self, x):
        """Return ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``."""
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        act = self.act_fn(gate)
        out = act * up
        out = self.down_proj(out)
        return out
Loading
Loading