From 1285e5aab0ffda03842a31bd3110c6009f2ca42f Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Fri, 13 Jun 2025 20:48:12 +0200
Subject: [PATCH] add out_dim to hparams explicitly

---
 chebai/models/base.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/chebai/models/base.py b/chebai/models/base.py
index fd02c6ce..7d6e124d 100644
--- a/chebai/models/base.py
+++ b/chebai/models/base.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, Optional, Union, Iterable
+from typing import Any, Dict, Iterable, Optional, Union
 
 import torch
 from lightning.pytorch.core.module import LightningModule
@@ -49,6 +49,11 @@ def __init__(
         if exclude_hyperparameter_logging is None:
             exclude_hyperparameter_logging = tuple()
         self.criterion = criterion
+        assert out_dim is not None, "out_dim must be specified"
+        assert input_dim is not None, "input_dim must be specified"
+        self.out_dim = out_dim
+        self.input_dim = input_dim
+
         self.save_hyperparameters(
             ignore=[
                 "criterion",
@@ -59,10 +64,8 @@ def __init__(
             ]
         )
 
-        self.out_dim = out_dim
-        self.input_dim = input_dim
-        assert out_dim is not None, "out_dim must be specified"
-        assert input_dim is not None, "input_dim must be specified"
+        self.hparams["out_dim"] = out_dim
+        self.hparams["input_dim"] = input_dim
 
         if optimizer_kwargs:
             self.optimizer_kwargs = optimizer_kwargs
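
Note (not part of the patch): a minimal sketch of the effect of this change. The class name TinyModel, its simplified constructor signature, and the single-entry ignore list are illustrative assumptions, not the actual chebai base class; the real ignore list contains more entries than the diff context shows, which is presumably why the patch records out_dim and input_dim in self.hparams explicitly. The point is that writing these keys after save_hyperparameters() guarantees they appear in the logged and checkpointed hyperparameters regardless of how the automatic capture filtered the constructor arguments.

    import torch
    from lightning.pytorch.core.module import LightningModule


    class TinyModel(LightningModule):
        # Hypothetical toy model, for illustration only (not chebai's actual base class).
        def __init__(self, out_dim=None, input_dim=None, criterion=None):
            super().__init__()
            # Fail fast, as in the patch: both dimensions are required.
            assert out_dim is not None, "out_dim must be specified"
            assert input_dim is not None, "input_dim must be specified"
            self.out_dim = out_dim
            self.input_dim = input_dim

            # Capture constructor arguments; "criterion" is excluded from automatic
            # hyperparameter capture, mirroring the ignore list in the patched file.
            self.save_hyperparameters(ignore=["criterion"])

            # Record the dimensions explicitly so they are always present in hparams.
            self.hparams["out_dim"] = out_dim
            self.hparams["input_dim"] = input_dim

            self.layer = torch.nn.Linear(input_dim, out_dim)


    model = TinyModel(out_dim=4, input_dim=8, criterion=torch.nn.BCEWithLogitsLoss())
    print(dict(model.hparams))  # contains 'out_dim' and 'input_dim', but not 'criterion'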