Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions nnf/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from nnf.layers.base import Layer
from nnf.losses.base import Loss
from nnf.optimizers.base import Optimizer
from nnf.utils import LAYER_CLASSES
from nnf.utils import LAYER_REGISTERY

"""
This module defines a neural network Model class that combines layers and
Expand Down Expand Up @@ -401,16 +401,16 @@ def set_model_attrs(self, attrs : Dict):
Args:
attrs (dict): A dictionary of attribute names and values to set.
If a value is a dictionary with keys "type" and "attrs",
it initializes a layer object from `LAYER_CLASSES` and sets its parameters.
it initializes a layer object from `LAYER_REGISTERY` and sets its parameters.
"""

for key, val in attrs.items():
if isinstance(val, dict):
obj = LAYER_CLASSES[val["type"]]()
obj = LAYER_REGISTERY[val["type"]]()
obj.set_params(val["attrs"])
setattr(self, key, obj)
elif isinstance(val, str) and val in LAYER_CLASSES:
setattr(self, key, LAYER_CLASSES[val]())
elif isinstance(val, str) and val in LAYER_REGISTERY:
setattr(self, key, LAYER_REGISTERY[val]())
else:
setattr(self, key, val)

Expand Down Expand Up @@ -508,13 +508,13 @@ def load_model(file_path : str):

attrs = layer["attrs"] if layer["attrs"] else {}

obj : Layer = LAYER_CLASSES[layer_type]()
obj : Layer = LAYER_REGISTERY[layer_type]()
obj.set_params(attrs)

layers.append(obj)

model_attrs = {
key: LAYER_CLASSES[val]() if isinstance(val, str) and val in LAYER_CLASSES else val
key: LAYER_REGISTERY[val]() if isinstance(val, str) and val in LAYER_REGISTERY else val
for key, val in model_attrs.items()
}

Expand Down
4 changes: 2 additions & 2 deletions nnf/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from nnf.utils.layer_config import LAYER_CLASSES
from nnf.utils.layer_config import LAYER_REGISTERY

__all__ = ['LAYER_CLASSES']
__all__ = ['LAYER_REGISTERY']
2 changes: 1 addition & 1 deletion nnf/utils/layer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from nnf.optimizers import GradientDescent
from nnf.optimizers import Momentum

LAYER_CLASSES = {
LAYER_REGISTERY = {
"Dense" : Dense,
"ReLU" : ReLU,
"LeakyRelU" : LeakyReLU,
Expand Down