diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 50c34365d8..8933491e1e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -52,6 +52,7 @@ below: - Ben Fitzpatrick (Met Office, UK) - Tom Gale (Bureau of Meteorology, Australia) - Sam Griffiths (Met Office, UK) + - Luke Hoffmann (Bureau of Meteorology, Australia) - Ben Hooper (Met Office, UK) - Aaron Hopkinson (Met Office, UK) - Kathryn Howard (Met Office, UK) diff --git a/improver/calibration/__init__.py b/improver/calibration/__init__.py index 972c51f9ae..9efde58ccb 100644 --- a/improver/calibration/__init__.py +++ b/improver/calibration/__init__.py @@ -68,6 +68,25 @@ def __init__(self): ) +def treelite_packages_available(): + """Return True if treelite packages are available, False otherwise.""" + try: + import tl2cgen # noqa: F401 + import treelite # noqa: F401 + except ModuleNotFoundError: + return False + return True + + +def lightgbm_package_available(): + """Return True if LightGBM package is available, False otherwise.""" + try: + import lightgbm # noqa: F401 + except ModuleNotFoundError: + return False + return True + + def split_forecasts_and_truth( cubes: List[Cube], truth_attribute: str ) -> Tuple[Cube, Cube, Optional[Cube]]: diff --git a/improver/calibration/rainforest_calibration.py b/improver/calibration/rainforest_calibration.py index 0cbba8223d..3dd1b8ee23 100644 --- a/improver/calibration/rainforest_calibration.py +++ b/improver/calibration/rainforest_calibration.py @@ -25,6 +25,10 @@ from numpy import ndarray from improver import PostProcessingPlugin +from improver.calibration import ( + lightgbm_package_available, + treelite_packages_available, +) from improver.constants import MINUTES_IN_HOUR, SECONDS_IN_MINUTE from improver.ensemble_copula_coupling.utilities import ( get_bounds_of_distribution, @@ -39,25 +43,6 @@ Model = Literal["lightgbm_model", "treelite_model"] -def treelite_packages_available(): - """Return True if treelite packages are available, False otherwise.""" - try: - import tl2cgen # noqa: F401 - import treelite # noqa: F401 - except ModuleNotFoundError: - return False - return True - - -def lightgbm_package_available(): - """Return True if LightGBM package is available, False otherwise.""" - try: - import lightgbm # noqa: F401 - except ModuleNotFoundError: - return False - return True - - class ModelFileNotFoundError(Exception): """Used when the path to a treelite/LightGBM model object is invalid.""" diff --git a/improver/calibration/rainforest_compilation.py b/improver/calibration/rainforest_compilation.py new file mode 100644 index 0000000000..5a88040ef7 --- /dev/null +++ b/improver/calibration/rainforest_compilation.py @@ -0,0 +1,73 @@ +# (C) Crown Copyright, Met Office. All rights reserved. +# +# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license. +# See LICENSE in the root of the repository for full licensing details. + +"""RainForests model compilation plugin.""" + +from improver import BasePlugin +from improver.calibration import ( + treelite_packages_available, +) + +LIGHTGBM_EXTENSION = ".txt" +TREELITE_EXTENSION = ".so" + + +class CompileRainForestsModel(BasePlugin): + """Class to compile RainForests tree models""" + + def __init__(self, toolchain="gcc", verbose=False, parallel_comp=0): + """Initialise the options used when compiling models. + + Args: + toolchain (str): + Toolchain to use for Treelite model compilation. + 'gcc' (default), 'msvc', 'clang' or a specific variation of clang or gcc + (e.g. 'gcc-7'). + verbose (bool): + Print verbose output during compilation + parallel_comp (int): + Enables parallel compilation to reduce time and memory consumption. + Value is the number of processes to use. + Defaults to 0 (no parallel compilation) + """ + + self.treelite_available = treelite_packages_available() + if not self.treelite_available: + raise ModuleNotFoundError("Could not find TreeLite module") + + self.toolchain = toolchain + self.verbose = verbose + self.treelight_params = {"parallel_comp": parallel_comp, "quantize": 1} + + def process(self, model_file, output_dir): + """Compile a lightgbm model with Treelite. + + Args: + model_file (pathlib.Path): + Path to LightGBM Booster file. + output_dir (pathlib.Path): + Directory where the compiled Treelite predictor file will be created. + """ + + import tl2cgen + import treelite + + # Input validation + if model_file.suffix.lower() != LIGHTGBM_EXTENSION: + raise ValueError(f"Input path must have the extension {LIGHTGBM_EXTENSION}") + if not output_dir.is_dir(): + raise ValueError("Output path must be a directory") + + output_filepath = output_dir / f"{model_file.stem}{TREELITE_EXTENSION}" + + model = treelite.frontend.load_lightgbm_model(model_file) + + tl2cgen.export_lib( + model, + libpath=output_filepath, + toolchain=self.toolchain, + verbose=self.verbose, + params=self.treelight_params, + ) diff --git a/improver/calibration/rainforest_training.py b/improver/calibration/rainforest_training.py new file mode 100644 index 0000000000..75e6c56918 --- /dev/null +++ b/improver/calibration/rainforest_training.py @@ -0,0 +1,75 @@ +# (C) Crown Copyright, Met Office. All rights reserved. +# +# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license. +# See LICENSE in the root of the repository for full licensing details. + +"""RainForests model training plugin.""" + +from improver import BasePlugin +from improver.calibration import ( + lightgbm_package_available, +) + + +class TrainRainForestsModel(BasePlugin): + lightgbm_params = { + "objective": "binary", + "num_leaves": 5, + "num_boost_round": 10, + "verbose": -1, + "seed": 0, + } + + def __init__(self, training_data, observation_column, training_columns): + """Initialise the options used when compiling models. + + Args: + training_data (pandas.DataFrame): + Combined data set used to train models. + observation_column (str): + The column in the data set to be trained for. + training_columns (List(str)): + Set of columns from the data set to be trained from. + """ + self.lightgbm_available = lightgbm_package_available() + if not self.lightgbm_available: + raise ModuleNotFoundError("Could not find LightGBM module") + + expected_columns = training_columns + [observation_column] + + # Check all specified columns exist in the data. + for col in expected_columns: + if col not in training_data: + raise KeyError(f"Column '{col}' not found in training data.") + + # Check the observation column is not also a training column. + if observation_column in training_columns: + raise KeyError( + f"Observation column '{observation_column}' appears in training data." + ) + + self.observation_column = observation_column + self.training_columns = training_columns + + # Keep only the columns relevant for training. + self.training_data = training_data[expected_columns] + + def process(self, threshold, output_path): + """Train a model for a particular threshold. + + Args: + threshold (float): + Threshold for which the observation column is trained. + output_path (str or Path): + The model will be exported to this file path. + """ + import lightgbm + + threshold_met = ( + self.training_data[self.observation_column] >= threshold + ).astype(int) + dataset = lightgbm.Dataset(self.training_data, label=threshold_met) + + model = lightgbm.train(self.lightgbm_params, dataset) + + model.save_model(output_path) diff --git a/improver_tests/calibration/rainforests_training/__init__.py b/improver_tests/calibration/rainforests_training/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/improver_tests/calibration/rainforests_training/conftest.py b/improver_tests/calibration/rainforests_training/conftest.py new file mode 100644 index 0000000000..bf3e2cc1cd --- /dev/null +++ b/improver_tests/calibration/rainforests_training/conftest.py @@ -0,0 +1,89 @@ +# (C) Crown Copyright, Met Office. All rights reserved. +# +# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license. +# See LICENSE in the root of the repository for full licensing details. +import sys + +import pytest + +from improver.calibration import lightgbm_package_available, treelite_packages_available + +from ..rainforests_calibration.conftest import ( + deterministic_features, + deterministic_forecast, + dummy_lightgbm_models, + ensemble_features, + ensemble_forecast, + lead_times, + prepare_dummy_training_data, + thresholds, +) + +_ = ( + deterministic_features, + deterministic_forecast, + dummy_lightgbm_models, + ensemble_features, + ensemble_forecast, + lead_times, + prepare_dummy_training_data, + thresholds, +) + +dummy_lightgbm_models = dummy_lightgbm_models + + +@pytest.fixture(params=[True, False]) +def lightgbm_available(request, monkeypatch): + """Make lightgbm module available or unavailable""" + + available = request.param and lightgbm_package_available() + if not available: + monkeypatch.setitem(sys.modules, "lightgbm", None) + return available + + +@pytest.fixture(params=[True, False]) +def treelite_available(request, monkeypatch): + """Make treelite module available or unavailable""" + + available = request.param and treelite_packages_available() + if not available: + monkeypatch.setitem(sys.modules, "treelite", None) + return available + + +@pytest.fixture +def deterministic_training_data( + deterministic_features, deterministic_forecast, lead_times +): + training_data, fcst_column, observation_column, training_columns = ( + prepare_dummy_training_data( + deterministic_features, deterministic_forecast, lead_times + ) + ) + + # This data contains several lead times. Filter the data to one leadtime. + lead_time = 24 + curr_training_data = training_data.loc[ + training_data["lead_time_hours"] == lead_time + ] + + return curr_training_data, observation_column, training_columns + + +@pytest.fixture +def rainforests_model_files(dummy_lightgbm_models, tmp_path): + """Export some LightGBM Boosters to file""" + + tree_models, lead_times, thresholds = dummy_lightgbm_models + + output_dir = tmp_path / "models" + output_dir.mkdir(exist_ok=True) + + def saved_path(lead_time, threshold): + path = output_dir / f"model_{lead_time:0}_{threshold:06.4f}.txt" + tree_models[lead_time, threshold].save_model(path) + return path + + return [saved_path(l, t) for l in lead_times for t in thresholds] diff --git a/improver_tests/calibration/rainforests_training/test_CompileRainForestsModel.py b/improver_tests/calibration/rainforests_training/test_CompileRainForestsModel.py new file mode 100644 index 0000000000..b5de5d8ff7 --- /dev/null +++ b/improver_tests/calibration/rainforests_training/test_CompileRainForestsModel.py @@ -0,0 +1,73 @@ +# (C) Crown Copyright, Met Office. All rights reserved. +# +# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license. +# See LICENSE in the root of the repository for full licensing details. + +import shutil +from pathlib import Path + +import pytest + +from improver.calibration.rainforest_compilation import ( + CompileRainForestsModel, +) + +tl2cgen = pytest.importorskip("tl2cgen") +treelite = pytest.importorskip("treelite") + + +def test__init__(treelite_available): + """Test class is created if treelight libraries are available. + Test class is not created if treelight libraries not available.""" + + if treelite_available: + expected_class = "CompileRainForestsModel" + result = CompileRainForestsModel() + assert type(result).__name__ == expected_class + else: + with pytest.raises(ModuleNotFoundError): + result = CompileRainForestsModel() + + +def test_process(rainforests_model_files, tmp_path): + """Test models are compiled.""" + + compiler = CompileRainForestsModel(parallel_comp=8) + + output_dir = Path(tmp_path) / "compiled" + output_dir.mkdir(exist_ok=True) + + for model_file in rainforests_model_files: + compiler.process(model_file, output_dir) + + assert Path.exists(output_dir / f"{model_file.stem}.so") + + +def test_process_wrong_extension(rainforests_model_files, tmp_path): + """Test models are not compiled when model file has unexpected extension.""" + + compiler = CompileRainForestsModel(parallel_comp=8) + + output_dir = Path(tmp_path) / "compiled" + output_dir.mkdir(exist_ok=True) + + model_file = Path(rainforests_model_files[0]) + wrong_extension = model_file.with_suffix(".dat") + + shutil.copyfile(model_file, wrong_extension) + + with pytest.raises(ValueError): + compiler.process(wrong_extension, output_dir) + + +def test_process_invalid_dir(rainforests_model_files, tmp_path): + """Test models are not compiled when output dir doesn't exist.""" + + compiler = CompileRainForestsModel(parallel_comp=8) + + model_file = Path(rainforests_model_files[0]) + missing_dir = Path(tmp_path) / "missing" + assert not missing_dir.is_dir() + + with pytest.raises(ValueError): + compiler.process(model_file, missing_dir) diff --git a/improver_tests/calibration/rainforests_training/test_TrainRainForestsModel.py b/improver_tests/calibration/rainforests_training/test_TrainRainForestsModel.py new file mode 100644 index 0000000000..8086bb89da --- /dev/null +++ b/improver_tests/calibration/rainforests_training/test_TrainRainForestsModel.py @@ -0,0 +1,91 @@ +# (C) Crown Copyright, Met Office. All rights reserved. +# +# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license. +# See LICENSE in the root of the repository for full licensing details. + +from pathlib import Path + +import pytest + +from improver.calibration.rainforest_training import ( + TrainRainForestsModel, +) + +lightgbm = pytest.importorskip("lightgbm") + + +def test__init__lightgmb_available(lightgbm_available, deterministic_training_data): + """Test class is created if lightgbm library is available. + Test class is not created if lightgbm library not available.""" + + training_data, observation_column, training_columns = deterministic_training_data + + if lightgbm_available: + expected_class = "TrainRainForestsModel" + result = TrainRainForestsModel( + training_data, observation_column, training_columns + ) + assert type(result).__name__ == expected_class + else: + with pytest.raises(ModuleNotFoundError): + result = TrainRainForestsModel( + training_data, observation_column, training_columns + ) + + +def test__init__(deterministic_training_data): + """Test class is created with training data.""" + training_data, observation_column, training_columns = deterministic_training_data + + result = TrainRainForestsModel(training_data, observation_column, training_columns) + assert result.training_columns == training_columns + assert result.observation_column == observation_column + + +def test__init__missing_obs_column(deterministic_training_data): + """Test class creation fails when observation column isn't present in the training data.""" + training_data, observation_column, training_columns = deterministic_training_data + + dummy_obs_column = "dummy_obs_column" + + with pytest.raises(KeyError) as e: + TrainRainForestsModel(training_data, dummy_obs_column, training_columns) + assert dummy_obs_column in str(e) + + +def test__init__missing_train_column(deterministic_training_data): + """Test class creation fails when one of the training columns isn't present in the training data.""" + training_data, observation_column, training_columns = deterministic_training_data + + dummy_train_column = "dummy_train_column" + training_columns.append(dummy_train_column) + + with pytest.raises(KeyError) as e: + TrainRainForestsModel(training_data, observation_column, training_columns) + assert dummy_train_column in str(e) + + +def test__init__obs_column_is_train_column(deterministic_training_data): + """Test class creation fails when the observation column is one of the training columns.""" + training_data, observation_column, training_columns = deterministic_training_data + + dummy_obs_column = training_columns[2] + + with pytest.raises(KeyError) as e: + TrainRainForestsModel(training_data, dummy_obs_column, training_columns) + assert dummy_obs_column not in str(e) + + +def test_process(thresholds, deterministic_training_data, tmp_path): + """Test lightgbm models are created at specified path.""" + + training_data, observation_column, training_columns = deterministic_training_data + + trainer = TrainRainForestsModel(training_data, observation_column, training_columns) + + result_path = tmp_path / "output.txt" + + threshold = thresholds[0] + trainer.process(threshold, result_path) + + assert Path.exists(result_path)