diff --git a/lir/persistence.py b/lir/persistence.py index a395b21..1841fd8 100644 --- a/lir/persistence.py +++ b/lir/persistence.py @@ -1,4 +1,5 @@ import pickle +from os import PathLike from pathlib import Path from lir.aggregation import Aggregation, AggregationData @@ -26,12 +27,30 @@ def save_model(path: Path, model: LRSystem) -> None: class SaveModel(Aggregation): """Write the model to a file.""" - def __init__(self, path: Path): - self.path = path + def __init__(self, output_dir: Path, filename: PathLike | str = 'model.pkl') -> None: + """ + Initialize the aggregation object. + + The model is saved as a pickle file, in a file named `filename`, that is written to a subdirectory of + `output_dir`, that is created for each run. + + If `filename` is an absolute path, or if `filename` is relative to `output_dir`, then the model is saved to this + file as-is, instead of to a file in a newly created subdirectory. + + :param output_dir: the directory where the model should be written + :param filename: the filename to be created for the model + """ + self.output_dir = output_dir + self.filename = Path(filename) def report(self, data: AggregationData) -> None: - """Write the trained LR system model to file.""" - save_model(self.path, data.lrsystem) + """Create a directory for the run and write the trained LR system model to file.""" + if self.filename.is_absolute() or self.filename.is_relative_to(self.output_dir): + save_model(self.filename, data.lrsystem) + else: + dirname = self.output_dir / data.run_name if data.run_name else self.output_dir + dirname.mkdir(parents=True, exist_ok=True) + save_model(dirname / self.filename, data.lrsystem) @config_parser @@ -44,4 +63,4 @@ def parse_save_model(config: ContextAwareDict, output_dir: Path) -> SaveModel: """ filename = pop_field(config, 'filename', default='model.pkl', validate=str) check_is_empty(config) - return SaveModel(output_dir / filename) + return SaveModel(output_dir, filename)