From 909bd3274ed582c60a6eb21ed0cbf9b7d3d4667c Mon Sep 17 00:00:00 2001 From: Luke Hoffmann Date: Tue, 18 Nov 2025 14:47:28 +1100 Subject: [PATCH 1/6] Refactor module checks into common area --- CONTRIBUTING.md | 1 + improver/calibration/__init__.py | 19 +++++++++++++++ .../calibration/rainforest_calibration.py | 23 ++++--------------- 3 files changed, 24 insertions(+), 19 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 50c34365d8..8933491e1e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -52,6 +52,7 @@ below: - Ben Fitzpatrick (Met Office, UK) - Tom Gale (Bureau of Meteorology, Australia) - Sam Griffiths (Met Office, UK) + - Luke Hoffmann (Bureau of Meteorology, Australia) - Ben Hooper (Met Office, UK) - Aaron Hopkinson (Met Office, UK) - Kathryn Howard (Met Office, UK) diff --git a/improver/calibration/__init__.py b/improver/calibration/__init__.py index 972c51f9ae..9efde58ccb 100644 --- a/improver/calibration/__init__.py +++ b/improver/calibration/__init__.py @@ -68,6 +68,25 @@ def __init__(self): ) +def treelite_packages_available(): + """Return True if treelite packages are available, False otherwise.""" + try: + import tl2cgen # noqa: F401 + import treelite # noqa: F401 + except ModuleNotFoundError: + return False + return True + + +def lightgbm_package_available(): + """Return True if LightGBM package is available, False otherwise.""" + try: + import lightgbm # noqa: F401 + except ModuleNotFoundError: + return False + return True + + def split_forecasts_and_truth( cubes: List[Cube], truth_attribute: str ) -> Tuple[Cube, Cube, Optional[Cube]]: diff --git a/improver/calibration/rainforest_calibration.py b/improver/calibration/rainforest_calibration.py index 0cbba8223d..3dd1b8ee23 100644 --- a/improver/calibration/rainforest_calibration.py +++ b/improver/calibration/rainforest_calibration.py @@ -25,6 +25,10 @@ from numpy import ndarray from improver import PostProcessingPlugin +from improver.calibration import ( + lightgbm_package_available, + treelite_packages_available, +) from improver.constants import MINUTES_IN_HOUR, SECONDS_IN_MINUTE from improver.ensemble_copula_coupling.utilities import ( get_bounds_of_distribution, @@ -39,25 +43,6 @@ Model = Literal["lightgbm_model", "treelite_model"] -def treelite_packages_available(): - """Return True if treelite packages are available, False otherwise.""" - try: - import tl2cgen # noqa: F401 - import treelite # noqa: F401 - except ModuleNotFoundError: - return False - return True - - -def lightgbm_package_available(): - """Return True if LightGBM package is available, False otherwise.""" - try: - import lightgbm # noqa: F401 - except ModuleNotFoundError: - return False - return True - - class ModelFileNotFoundError(Exception): """Used when the path to a treelite/LightGBM model object is invalid.""" From 5a67ea0e3611125c7890457c6f6aab946c391773 Mon Sep 17 00:00:00 2001 From: Luke Hoffmann Date: Mon, 24 Nov 2025 13:52:02 +1100 Subject: [PATCH 2/6] Plugins for training and compiling Rainforests models --- improver/calibration/rainforest_compiler.py | 73 +++++++++++++++ improver/calibration/rainforest_training.py | 70 +++++++++++++++ .../rainforests_training/__init__.py | 0 .../rainforests_training/conftest.py | 89 +++++++++++++++++++ .../test_CompileRainForestsCalibration.py | 42 +++++++++ .../test_TrainRainForestsCalibration.py | 83 +++++++++++++++++ 6 files changed, 357 insertions(+) create mode 100644 improver/calibration/rainforest_compiler.py create mode 100644 improver/calibration/rainforest_training.py create mode 100644 improver_tests/calibration/rainforests_training/__init__.py create mode 100644 improver_tests/calibration/rainforests_training/conftest.py create mode 100644 improver_tests/calibration/rainforests_training/test_CompileRainForestsCalibration.py create mode 100644 improver_tests/calibration/rainforests_training/test_TrainRainForestsCalibration.py diff --git a/improver/calibration/rainforest_compiler.py b/improver/calibration/rainforest_compiler.py new file mode 100644 index 0000000000..2bf791623d --- /dev/null +++ b/improver/calibration/rainforest_compiler.py @@ -0,0 +1,73 @@ +# (C) Crown Copyright, Met Office. All rights reserved. +# +# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license. +# See LICENSE in the root of the repository for full licensing details. + +"""RainForests model compilation plugin.""" + +from improver import BasePlugin +from improver.calibration import ( + treelite_packages_available, +) + +LIGHTGBM_EXTENSION = ".txt" +TREELITE_EXTENSION = ".so" + + +class RainforestsCompiler(BasePlugin): + """Class to compile Rainforests tree models""" + + def __init__(self, toolchain="gcc", verbose=False, parallel_comp=0): + """Initialise the options used when compiling models. + + Args: + toolchain (str): + Toolchain to use for Treelite model compilation. + 'gcc' (default), 'msvc', 'clang' or a specific variation of clang or gcc + (e.g. 'gcc-7'). + verbose (bool): + Print verbose output during compilation + parallel_comp (int): + Enables parallel compilation to reduce time and memory consumption. + Value is the number of processes to use. + Defaults to 0 (no parallel compilation) + """ + + self.treelite_available = treelite_packages_available() + if not self.treelite_available: + raise ModuleNotFoundError("Could not find TreeLite module") + + self.toolchain = toolchain + self.verbose = verbose + self.treelight_params = {"parallel_comp": parallel_comp, "quantize": 1} + + def process(self, model_file, output_dir): + """Compile a lightgbm model with Treelite. + + Args: + model_file (pathlib.Path): + Path to LightGBM Booster file. + output_dir (pathlib.Path): + Directory where the compiled Treelite predictor file will be created. + """ + + import tl2cgen + import treelite + + # Input validation + if model_file.suffix.lower() != LIGHTGBM_EXTENSION: + raise ValueError(f"Input path must have the extension {LIGHTGBM_EXTENSION}") + if not output_dir.is_dir(): + raise ValueError("Output path must be a directory") + + output_filepath = output_dir / f"{model_file.stem}{TREELITE_EXTENSION}" + + model = treelite.frontend.load_lightgbm_model(model_file) + + tl2cgen.export_lib( + model, + libpath=output_filepath, + toolchain=self.toolchain, + verbose=self.verbose, + params=self.treelight_params, + ) diff --git a/improver/calibration/rainforest_training.py b/improver/calibration/rainforest_training.py new file mode 100644 index 0000000000..959762aabb --- /dev/null +++ b/improver/calibration/rainforest_training.py @@ -0,0 +1,70 @@ +# (C) Crown Copyright, Met Office. All rights reserved. +# +# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license. +# See LICENSE in the root of the repository for full licensing details. + +"""RainForests model training plugin.""" + +from improver import BasePlugin +from improver.calibration import ( + lightgbm_package_available, +) + + +class TrainRainForestsCalibration(BasePlugin): + lightgbm_params = { + "objective": "binary", + "num_leaves": 5, + "num_boost_round": 10, + "verbose": -1, + "seed": 0, + } + + def __init__(self, training_data, observation_column, training_columns): + """Initialise the options used when compiling models. + + Args: + training_data (pandas.DataFrame): + Combined data set used to train models + observation_column (str): + The column in the data set to be used + training_columns (List(str)): + Set of columns from the data set to be used as training data. + """ + self.lightgbm_available = lightgbm_package_available() + if not self.lightgbm_available: + raise ModuleNotFoundError("Could not find LightGBM module") + + self.observation_column = observation_column + self.training_columns = training_columns + + expected_columns = training_columns + [observation_column] + for col in expected_columns: + if col not in training_data: + raise KeyError(f"Column {col} not found in training data") + + self.training_data = training_data[expected_columns] + + def process( + self, threshold, output_path=None + ): + """Train a model for a particular threshold. + + Args: + threshold (float): + Threshold for which the observation column is trained. + output_path (str or Path): + If provided, the model will be exported to this file path. + """ + import lightgbm + + threshold_met = (self.training_data[self.observation_column] >= threshold).astype( + int + ) + dataset = lightgbm.Dataset(self.training_data, label=threshold_met) + + model = lightgbm.train(self.lightgbm_params, dataset) + if output_path: + model.save_model(output_path) + + return model.model_to_string() diff --git a/improver_tests/calibration/rainforests_training/__init__.py b/improver_tests/calibration/rainforests_training/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/improver_tests/calibration/rainforests_training/conftest.py b/improver_tests/calibration/rainforests_training/conftest.py new file mode 100644 index 0000000000..bf3e2cc1cd --- /dev/null +++ b/improver_tests/calibration/rainforests_training/conftest.py @@ -0,0 +1,89 @@ +# (C) Crown Copyright, Met Office. All rights reserved. +# +# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license. +# See LICENSE in the root of the repository for full licensing details. +import sys + +import pytest + +from improver.calibration import lightgbm_package_available, treelite_packages_available + +from ..rainforests_calibration.conftest import ( + deterministic_features, + deterministic_forecast, + dummy_lightgbm_models, + ensemble_features, + ensemble_forecast, + lead_times, + prepare_dummy_training_data, + thresholds, +) + +_ = ( + deterministic_features, + deterministic_forecast, + dummy_lightgbm_models, + ensemble_features, + ensemble_forecast, + lead_times, + prepare_dummy_training_data, + thresholds, +) + +dummy_lightgbm_models = dummy_lightgbm_models + + +@pytest.fixture(params=[True, False]) +def lightgbm_available(request, monkeypatch): + """Make lightgbm module available or unavailable""" + + available = request.param and lightgbm_package_available() + if not available: + monkeypatch.setitem(sys.modules, "lightgbm", None) + return available + + +@pytest.fixture(params=[True, False]) +def treelite_available(request, monkeypatch): + """Make treelite module available or unavailable""" + + available = request.param and treelite_packages_available() + if not available: + monkeypatch.setitem(sys.modules, "treelite", None) + return available + + +@pytest.fixture +def deterministic_training_data( + deterministic_features, deterministic_forecast, lead_times +): + training_data, fcst_column, observation_column, training_columns = ( + prepare_dummy_training_data( + deterministic_features, deterministic_forecast, lead_times + ) + ) + + # This data contains several lead times. Filter the data to one leadtime. + lead_time = 24 + curr_training_data = training_data.loc[ + training_data["lead_time_hours"] == lead_time + ] + + return curr_training_data, observation_column, training_columns + + +@pytest.fixture +def rainforests_model_files(dummy_lightgbm_models, tmp_path): + """Export some LightGBM Boosters to file""" + + tree_models, lead_times, thresholds = dummy_lightgbm_models + + output_dir = tmp_path / "models" + output_dir.mkdir(exist_ok=True) + + def saved_path(lead_time, threshold): + path = output_dir / f"model_{lead_time:0}_{threshold:06.4f}.txt" + tree_models[lead_time, threshold].save_model(path) + return path + + return [saved_path(l, t) for l in lead_times for t in thresholds] diff --git a/improver_tests/calibration/rainforests_training/test_CompileRainForestsCalibration.py b/improver_tests/calibration/rainforests_training/test_CompileRainForestsCalibration.py new file mode 100644 index 0000000000..3d681b2d3c --- /dev/null +++ b/improver_tests/calibration/rainforests_training/test_CompileRainForestsCalibration.py @@ -0,0 +1,42 @@ +# (C) Crown Copyright, Met Office. All rights reserved. +# +# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license. +# See LICENSE in the root of the repository for full licensing details. + +from pathlib import Path + +import pytest + +from improver.calibration.rainforest_compiler import ( + RainforestsCompiler, +) + +tl2cgen = pytest.importorskip("tl2cgen") +treelite = pytest.importorskip("treelite") + + +def test__init__(treelite_available, tmp_path): + """Test class is created if treelight libraries are available. + Test class is not created if treelight libraries not available.""" + + if treelite_available: + expected_class = "RainforestsCompiler" + result = RainforestsCompiler() + assert type(result).__name__ == expected_class + else: + with pytest.raises(ModuleNotFoundError): + result = RainforestsCompiler() + + +def test_process(rainforests_model_files, tmp_path): + """Test models are compiled.""" + + compiler = RainforestsCompiler(parallel_comp=8) + + output_dir = Path(tmp_path) / 'compiled' + output_dir.mkdir(exist_ok=True) + + for model_file in rainforests_model_files: + compiler.process(model_file, output_dir) + + assert Path.exists(output_dir / f"{model_file.stem}.so") diff --git a/improver_tests/calibration/rainforests_training/test_TrainRainForestsCalibration.py b/improver_tests/calibration/rainforests_training/test_TrainRainForestsCalibration.py new file mode 100644 index 0000000000..72e94b5069 --- /dev/null +++ b/improver_tests/calibration/rainforests_training/test_TrainRainForestsCalibration.py @@ -0,0 +1,83 @@ +# (C) Crown Copyright, Met Office. All rights reserved. +# +# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license. +# See LICENSE in the root of the repository for full licensing details. + +from pathlib import Path + +import pytest + +from improver.calibration.rainforest_training import ( + TrainRainForestsCalibration, +) + +lightgbm = pytest.importorskip("lightgbm") + + +def test__init__lightgmb_available(lightgbm_available, deterministic_training_data): + """Test class is created if lightgbm library is available. + Test class is not created if lightgbm library not available.""" + + training_data, observation_column, training_columns = ( + deterministic_training_data + ) + + if lightgbm_available: + expected_class = "TrainRainForestsCalibration" + result = TrainRainForestsCalibration( + training_data, observation_column, training_columns + ) + assert type(result).__name__ == expected_class + else: + with pytest.raises(ModuleNotFoundError): + result = TrainRainForestsCalibration( + training_data, observation_column, training_columns + ) + + +def test__init__(deterministic_training_data): + """Test class is created with training data.""" + training_data, observation_column, training_columns = ( + deterministic_training_data + ) + + result = TrainRainForestsCalibration( + training_data, observation_column, training_columns + ) + assert result.training_columns == training_columns + assert result.observation_column == observation_column + + +def test_process(thresholds, deterministic_training_data): + """Test lightgbm models are created.""" + + training_data, observation_column, training_columns = ( + deterministic_training_data + ) + + threshold = thresholds[0] + + trainer = TrainRainForestsCalibration( + training_data, observation_column, training_columns + ) + result = trainer.process(threshold) + assert isinstance(result, str) + + +def test_process_with_path(thresholds, deterministic_training_data, tmp_path): + """Test lightgbm models are created at specified path.""" + + training_data, observation_column, training_columns = ( + deterministic_training_data + ) + + trainer = TrainRainForestsCalibration( + training_data, observation_column, training_columns + ) + + result_path = tmp_path / "output.txt" + + threshold = thresholds[0] + trainer.process(threshold, result_path) + + assert Path.exists(result_path) From a0a284f2ef442e08ffa79424a5c69eebf77729a1 Mon Sep 17 00:00:00 2001 From: Luke Hoffmann Date: Tue, 25 Nov 2025 10:33:50 +1100 Subject: [PATCH 3/6] Add some testing for column selection --- improver/calibration/rainforest_training.py | 31 +++++++----- .../test_CompileRainForestsCalibration.py | 2 +- .../test_TrainRainForestsCalibration.py | 50 ++++++++++++++----- 3 files changed, 58 insertions(+), 25 deletions(-) diff --git a/improver/calibration/rainforest_training.py b/improver/calibration/rainforest_training.py index 959762aabb..bdb2b7ab29 100644 --- a/improver/calibration/rainforest_training.py +++ b/improver/calibration/rainforest_training.py @@ -25,9 +25,9 @@ def __init__(self, training_data, observation_column, training_columns): Args: training_data (pandas.DataFrame): - Combined data set used to train models + Combined data set used to train models. observation_column (str): - The column in the data set to be used + The column in the data set to be used. training_columns (List(str)): Set of columns from the data set to be used as training data. """ @@ -35,19 +35,26 @@ def __init__(self, training_data, observation_column, training_columns): if not self.lightgbm_available: raise ModuleNotFoundError("Could not find LightGBM module") - self.observation_column = observation_column - self.training_columns = training_columns - expected_columns = training_columns + [observation_column] + + # Check all specified columns exist in the data. for col in expected_columns: if col not in training_data: - raise KeyError(f"Column {col} not found in training data") + raise KeyError(f"Column '{col}' not found in training data.") + + # Check the observation column is not also a training column. + if observation_column in training_columns: + raise KeyError( + f"Observation column '{observation_column}' appears in training columns." + ) + + self.observation_column = observation_column + self.training_columns = training_columns + # Keep only the columns relevant for training. self.training_data = training_data[expected_columns] - def process( - self, threshold, output_path=None - ): + def process(self, threshold, output_path=None): """Train a model for a particular threshold. Args: @@ -58,9 +65,9 @@ def process( """ import lightgbm - threshold_met = (self.training_data[self.observation_column] >= threshold).astype( - int - ) + threshold_met = ( + self.training_data[self.observation_column] >= threshold + ).astype(int) dataset = lightgbm.Dataset(self.training_data, label=threshold_met) model = lightgbm.train(self.lightgbm_params, dataset) diff --git a/improver_tests/calibration/rainforests_training/test_CompileRainForestsCalibration.py b/improver_tests/calibration/rainforests_training/test_CompileRainForestsCalibration.py index 3d681b2d3c..b91debeafb 100644 --- a/improver_tests/calibration/rainforests_training/test_CompileRainForestsCalibration.py +++ b/improver_tests/calibration/rainforests_training/test_CompileRainForestsCalibration.py @@ -33,7 +33,7 @@ def test_process(rainforests_model_files, tmp_path): compiler = RainforestsCompiler(parallel_comp=8) - output_dir = Path(tmp_path) / 'compiled' + output_dir = Path(tmp_path) / "compiled" output_dir.mkdir(exist_ok=True) for model_file in rainforests_model_files: diff --git a/improver_tests/calibration/rainforests_training/test_TrainRainForestsCalibration.py b/improver_tests/calibration/rainforests_training/test_TrainRainForestsCalibration.py index 72e94b5069..cd4e5f9b5c 100644 --- a/improver_tests/calibration/rainforests_training/test_TrainRainForestsCalibration.py +++ b/improver_tests/calibration/rainforests_training/test_TrainRainForestsCalibration.py @@ -18,9 +18,7 @@ def test__init__lightgmb_available(lightgbm_available, deterministic_training_da """Test class is created if lightgbm library is available. Test class is not created if lightgbm library not available.""" - training_data, observation_column, training_columns = ( - deterministic_training_data - ) + training_data, observation_column, training_columns = deterministic_training_data if lightgbm_available: expected_class = "TrainRainForestsCalibration" @@ -37,9 +35,7 @@ def test__init__lightgmb_available(lightgbm_available, deterministic_training_da def test__init__(deterministic_training_data): """Test class is created with training data.""" - training_data, observation_column, training_columns = ( - deterministic_training_data - ) + training_data, observation_column, training_columns = deterministic_training_data result = TrainRainForestsCalibration( training_data, observation_column, training_columns @@ -48,12 +44,44 @@ def test__init__(deterministic_training_data): assert result.observation_column == observation_column +def test__init__missing_obs_column(deterministic_training_data): + """Test class creation fails when observation column isn't present in the training data.""" + training_data, observation_column, training_columns = deterministic_training_data + + dummy_obs_column = "dummy_obs_column" + + with pytest.raises(KeyError) as e: + TrainRainForestsCalibration(training_data, dummy_obs_column, training_columns) + assert dummy_obs_column in str(e) + + +def test__init__missing_train_column(deterministic_training_data): + """Test class creation fails when one of the training columns isn't present in the training data.""" + training_data, observation_column, training_columns = deterministic_training_data + + dummy_train_column = "dummy_train_column" + training_columns.append(dummy_train_column) + + with pytest.raises(KeyError) as e: + TrainRainForestsCalibration(training_data, observation_column, training_columns) + assert dummy_train_column in str(e) + + +def test__init__obs_column_is_train_column(deterministic_training_data): + """Test class creation fails when the observation column is one of the training columns.""" + training_data, observation_column, training_columns = deterministic_training_data + + dummy_obs_column = training_columns[2] + + with pytest.raises(KeyError) as e: + TrainRainForestsCalibration(training_data, dummy_obs_column, training_columns) + assert dummy_obs_column not in str(e) + + def test_process(thresholds, deterministic_training_data): """Test lightgbm models are created.""" - training_data, observation_column, training_columns = ( - deterministic_training_data - ) + training_data, observation_column, training_columns = deterministic_training_data threshold = thresholds[0] @@ -67,9 +95,7 @@ def test_process(thresholds, deterministic_training_data): def test_process_with_path(thresholds, deterministic_training_data, tmp_path): """Test lightgbm models are created at specified path.""" - training_data, observation_column, training_columns = ( - deterministic_training_data - ) + training_data, observation_column, training_columns = deterministic_training_data trainer = TrainRainForestsCalibration( training_data, observation_column, training_columns From 8bea5dd1aae1b0fbc62fc94ce62eac3d522607d1 Mon Sep 17 00:00:00 2001 From: Luke Hoffmann Date: Tue, 25 Nov 2025 16:59:41 +1100 Subject: [PATCH 4/6] Standardise some naming and documentation --- ..._compiler.py => rainforest_compilation.py} | 4 +-- improver/calibration/rainforest_training.py | 16 ++++----- ...ion.py => test_CompileRainForestsModel.py} | 12 +++---- ...ation.py => test_TrainRainForestsModel.py} | 34 ++++++------------- 4 files changed, 25 insertions(+), 41 deletions(-) rename improver/calibration/{rainforest_compiler.py => rainforest_compilation.py} (96%) rename improver_tests/calibration/rainforests_training/{test_CompileRainForestsCalibration.py => test_CompileRainForestsModel.py} (79%) rename improver_tests/calibration/rainforests_training/{test_TrainRainForestsCalibration.py => test_TrainRainForestsModel.py} (73%) diff --git a/improver/calibration/rainforest_compiler.py b/improver/calibration/rainforest_compilation.py similarity index 96% rename from improver/calibration/rainforest_compiler.py rename to improver/calibration/rainforest_compilation.py index 2bf791623d..5a88040ef7 100644 --- a/improver/calibration/rainforest_compiler.py +++ b/improver/calibration/rainforest_compilation.py @@ -14,8 +14,8 @@ TREELITE_EXTENSION = ".so" -class RainforestsCompiler(BasePlugin): - """Class to compile Rainforests tree models""" +class CompileRainForestsModel(BasePlugin): + """Class to compile RainForests tree models""" def __init__(self, toolchain="gcc", verbose=False, parallel_comp=0): """Initialise the options used when compiling models. diff --git a/improver/calibration/rainforest_training.py b/improver/calibration/rainforest_training.py index bdb2b7ab29..75e6c56918 100644 --- a/improver/calibration/rainforest_training.py +++ b/improver/calibration/rainforest_training.py @@ -11,7 +11,7 @@ ) -class TrainRainForestsCalibration(BasePlugin): +class TrainRainForestsModel(BasePlugin): lightgbm_params = { "objective": "binary", "num_leaves": 5, @@ -27,9 +27,9 @@ def __init__(self, training_data, observation_column, training_columns): training_data (pandas.DataFrame): Combined data set used to train models. observation_column (str): - The column in the data set to be used. + The column in the data set to be trained for. training_columns (List(str)): - Set of columns from the data set to be used as training data. + Set of columns from the data set to be trained from. """ self.lightgbm_available = lightgbm_package_available() if not self.lightgbm_available: @@ -45,7 +45,7 @@ def __init__(self, training_data, observation_column, training_columns): # Check the observation column is not also a training column. if observation_column in training_columns: raise KeyError( - f"Observation column '{observation_column}' appears in training columns." + f"Observation column '{observation_column}' appears in training data." ) self.observation_column = observation_column @@ -54,14 +54,14 @@ def __init__(self, training_data, observation_column, training_columns): # Keep only the columns relevant for training. self.training_data = training_data[expected_columns] - def process(self, threshold, output_path=None): + def process(self, threshold, output_path): """Train a model for a particular threshold. Args: threshold (float): Threshold for which the observation column is trained. output_path (str or Path): - If provided, the model will be exported to this file path. + The model will be exported to this file path. """ import lightgbm @@ -71,7 +71,5 @@ def process(self, threshold, output_path=None): dataset = lightgbm.Dataset(self.training_data, label=threshold_met) model = lightgbm.train(self.lightgbm_params, dataset) - if output_path: - model.save_model(output_path) - return model.model_to_string() + model.save_model(output_path) diff --git a/improver_tests/calibration/rainforests_training/test_CompileRainForestsCalibration.py b/improver_tests/calibration/rainforests_training/test_CompileRainForestsModel.py similarity index 79% rename from improver_tests/calibration/rainforests_training/test_CompileRainForestsCalibration.py rename to improver_tests/calibration/rainforests_training/test_CompileRainForestsModel.py index b91debeafb..2546f28800 100644 --- a/improver_tests/calibration/rainforests_training/test_CompileRainForestsCalibration.py +++ b/improver_tests/calibration/rainforests_training/test_CompileRainForestsModel.py @@ -8,30 +8,30 @@ import pytest from improver.calibration.rainforest_compiler import ( - RainforestsCompiler, + CompileRainForestsModel, ) tl2cgen = pytest.importorskip("tl2cgen") treelite = pytest.importorskip("treelite") -def test__init__(treelite_available, tmp_path): +def test__init__(treelite_available): """Test class is created if treelight libraries are available. Test class is not created if treelight libraries not available.""" if treelite_available: - expected_class = "RainforestsCompiler" - result = RainforestsCompiler() + expected_class = "CompileRainForestsModel" + result = CompileRainForestsModel() assert type(result).__name__ == expected_class else: with pytest.raises(ModuleNotFoundError): - result = RainforestsCompiler() + result = CompileRainForestsModel() def test_process(rainforests_model_files, tmp_path): """Test models are compiled.""" - compiler = RainforestsCompiler(parallel_comp=8) + compiler = CompileRainForestsModel(parallel_comp=8) output_dir = Path(tmp_path) / "compiled" output_dir.mkdir(exist_ok=True) diff --git a/improver_tests/calibration/rainforests_training/test_TrainRainForestsCalibration.py b/improver_tests/calibration/rainforests_training/test_TrainRainForestsModel.py similarity index 73% rename from improver_tests/calibration/rainforests_training/test_TrainRainForestsCalibration.py rename to improver_tests/calibration/rainforests_training/test_TrainRainForestsModel.py index cd4e5f9b5c..520d590de0 100644 --- a/improver_tests/calibration/rainforests_training/test_TrainRainForestsCalibration.py +++ b/improver_tests/calibration/rainforests_training/test_TrainRainForestsModel.py @@ -8,7 +8,7 @@ import pytest from improver.calibration.rainforest_training import ( - TrainRainForestsCalibration, + TrainRainForestsModel, ) lightgbm = pytest.importorskip("lightgbm") @@ -21,14 +21,14 @@ def test__init__lightgmb_available(lightgbm_available, deterministic_training_da training_data, observation_column, training_columns = deterministic_training_data if lightgbm_available: - expected_class = "TrainRainForestsCalibration" - result = TrainRainForestsCalibration( + expected_class = "TrainRainForestsModel" + result = TrainRainForestsModel( training_data, observation_column, training_columns ) assert type(result).__name__ == expected_class else: with pytest.raises(ModuleNotFoundError): - result = TrainRainForestsCalibration( + result = TrainRainForestsModel( training_data, observation_column, training_columns ) @@ -37,7 +37,7 @@ def test__init__(deterministic_training_data): """Test class is created with training data.""" training_data, observation_column, training_columns = deterministic_training_data - result = TrainRainForestsCalibration( + result = TrainRainForestsModel( training_data, observation_column, training_columns ) assert result.training_columns == training_columns @@ -51,7 +51,7 @@ def test__init__missing_obs_column(deterministic_training_data): dummy_obs_column = "dummy_obs_column" with pytest.raises(KeyError) as e: - TrainRainForestsCalibration(training_data, dummy_obs_column, training_columns) + TrainRainForestsModel(training_data, dummy_obs_column, training_columns) assert dummy_obs_column in str(e) @@ -63,7 +63,7 @@ def test__init__missing_train_column(deterministic_training_data): training_columns.append(dummy_train_column) with pytest.raises(KeyError) as e: - TrainRainForestsCalibration(training_data, observation_column, training_columns) + TrainRainForestsModel(training_data, observation_column, training_columns) assert dummy_train_column in str(e) @@ -74,30 +74,16 @@ def test__init__obs_column_is_train_column(deterministic_training_data): dummy_obs_column = training_columns[2] with pytest.raises(KeyError) as e: - TrainRainForestsCalibration(training_data, dummy_obs_column, training_columns) + TrainRainForestsModel(training_data, dummy_obs_column, training_columns) assert dummy_obs_column not in str(e) -def test_process(thresholds, deterministic_training_data): - """Test lightgbm models are created.""" - - training_data, observation_column, training_columns = deterministic_training_data - - threshold = thresholds[0] - - trainer = TrainRainForestsCalibration( - training_data, observation_column, training_columns - ) - result = trainer.process(threshold) - assert isinstance(result, str) - - -def test_process_with_path(thresholds, deterministic_training_data, tmp_path): +def test_process(thresholds, deterministic_training_data, tmp_path): """Test lightgbm models are created at specified path.""" training_data, observation_column, training_columns = deterministic_training_data - trainer = TrainRainForestsCalibration( + trainer = TrainRainForestsModel( training_data, observation_column, training_columns ) From 2148d81fb5777319ece20ad471b06a987982baa0 Mon Sep 17 00:00:00 2001 From: Luke Hoffmann Date: Tue, 25 Nov 2025 17:52:37 +1100 Subject: [PATCH 5/6] Fix import --- .../rainforests_training/test_CompileRainForestsModel.py | 2 +- .../rainforests_training/test_TrainRainForestsModel.py | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/improver_tests/calibration/rainforests_training/test_CompileRainForestsModel.py b/improver_tests/calibration/rainforests_training/test_CompileRainForestsModel.py index 2546f28800..24bd264c19 100644 --- a/improver_tests/calibration/rainforests_training/test_CompileRainForestsModel.py +++ b/improver_tests/calibration/rainforests_training/test_CompileRainForestsModel.py @@ -7,7 +7,7 @@ import pytest -from improver.calibration.rainforest_compiler import ( +from improver.calibration.rainforest_compilation import ( CompileRainForestsModel, ) diff --git a/improver_tests/calibration/rainforests_training/test_TrainRainForestsModel.py b/improver_tests/calibration/rainforests_training/test_TrainRainForestsModel.py index 520d590de0..8086bb89da 100644 --- a/improver_tests/calibration/rainforests_training/test_TrainRainForestsModel.py +++ b/improver_tests/calibration/rainforests_training/test_TrainRainForestsModel.py @@ -37,9 +37,7 @@ def test__init__(deterministic_training_data): """Test class is created with training data.""" training_data, observation_column, training_columns = deterministic_training_data - result = TrainRainForestsModel( - training_data, observation_column, training_columns - ) + result = TrainRainForestsModel(training_data, observation_column, training_columns) assert result.training_columns == training_columns assert result.observation_column == observation_column @@ -83,9 +81,7 @@ def test_process(thresholds, deterministic_training_data, tmp_path): training_data, observation_column, training_columns = deterministic_training_data - trainer = TrainRainForestsModel( - training_data, observation_column, training_columns - ) + trainer = TrainRainForestsModel(training_data, observation_column, training_columns) result_path = tmp_path / "output.txt" From fb7bce71e6f9173b8a2b8001902641dc8c3bdbfc Mon Sep 17 00:00:00 2001 From: Luke Hoffmann Date: Wed, 26 Nov 2025 10:35:15 +1100 Subject: [PATCH 6/6] Bump up test coverage --- .../test_CompileRainForestsModel.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/improver_tests/calibration/rainforests_training/test_CompileRainForestsModel.py b/improver_tests/calibration/rainforests_training/test_CompileRainForestsModel.py index 24bd264c19..b5de5d8ff7 100644 --- a/improver_tests/calibration/rainforests_training/test_CompileRainForestsModel.py +++ b/improver_tests/calibration/rainforests_training/test_CompileRainForestsModel.py @@ -3,6 +3,7 @@ # This file is part of 'IMPROVER' and is released under the BSD 3-Clause license. # See LICENSE in the root of the repository for full licensing details. +import shutil from pathlib import Path import pytest @@ -40,3 +41,33 @@ def test_process(rainforests_model_files, tmp_path): compiler.process(model_file, output_dir) assert Path.exists(output_dir / f"{model_file.stem}.so") + + +def test_process_wrong_extension(rainforests_model_files, tmp_path): + """Test models are not compiled when model file has unexpected extension.""" + + compiler = CompileRainForestsModel(parallel_comp=8) + + output_dir = Path(tmp_path) / "compiled" + output_dir.mkdir(exist_ok=True) + + model_file = Path(rainforests_model_files[0]) + wrong_extension = model_file.with_suffix(".dat") + + shutil.copyfile(model_file, wrong_extension) + + with pytest.raises(ValueError): + compiler.process(wrong_extension, output_dir) + + +def test_process_invalid_dir(rainforests_model_files, tmp_path): + """Test models are not compiled when output dir doesn't exist.""" + + compiler = CompileRainForestsModel(parallel_comp=8) + + model_file = Path(rainforests_model_files[0]) + missing_dir = Path(tmp_path) / "missing" + assert not missing_dir.is_dir() + + with pytest.raises(ValueError): + compiler.process(model_file, missing_dir)