Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions lir/persistence.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pickle
from os import PathLike
from pathlib import Path

from lir.aggregation import Aggregation, AggregationData
Expand Down Expand Up @@ -26,12 +27,30 @@ def save_model(path: Path, model: LRSystem) -> None:
class SaveModel(Aggregation):
"""Write the model to a file."""

def __init__(self, path: Path):
self.path = path
def __init__(self, output_dir: Path, filename: PathLike | str = 'model.pkl') -> None:
"""
Initialize the aggregation object.

The model is saved as a pickle file, in a file named `filename`, that is written to a subdirectory of
`output_dir`, that is created for each run.

If `filename` is an absolute path, or if `filename` is relative to `output_dir`, then the model is saved to this
file as-is, instead of to a file in a newly created subdirectory.

:param output_dir: the directory where the model should be written
:param filename: the filename to be created for the model
"""
self.output_dir = output_dir
self.filename = Path(filename)

def report(self, data: AggregationData) -> None:
"""Write the trained LR system model to file."""
save_model(self.path, data.lrsystem)
"""Create a directory for the run and write the trained LR system model to file."""
if self.filename.is_absolute() or self.filename.is_relative_to(self.output_dir):
save_model(self.filename, data.lrsystem)
else:
dirname = self.output_dir / data.run_name if data.run_name else self.output_dir
dirname.mkdir(parents=True, exist_ok=True)
save_model(dirname / self.filename, data.lrsystem)


@config_parser
Expand All @@ -44,4 +63,4 @@ def parse_save_model(config: ContextAwareDict, output_dir: Path) -> SaveModel:
"""
filename = pop_field(config, 'filename', default='model.pkl', validate=str)
check_is_empty(config)
return SaveModel(output_dir / filename)
return SaveModel(output_dir, filename)