diff --git a/src/common/lightgbm_utils.py b/src/common/lightgbm_utils.py
index d490806a..29f2f057 100644
--- a/src/common/lightgbm_utils.py
+++ b/src/common/lightgbm_utils.py
@@ -5,7 +5,12 @@
 This classes provide help to integrate lightgbm
 """
 import lightgbm
+import numpy as np
 import logging
+from typing import List
+import threading
+import time
+import traceback
 
 class LightGBMCallbackHandler():
     """ This class handles LightGBM callbacks for recording metrics. """
@@ -47,3 +52,208 @@ def callback(self, env: lightgbm.callback.CallbackEnv) -> None:
                 value=result,
                 step=env.iteration # provide iteration as step in mlflow
             )
+
+
+class LightGBMDistributedCallbackHandler():
+    """ This class handles LightGBM callbacks for recording metrics across MPI nodes. """
+    def __init__(self, metrics_logger, mpi_comm, world_size=1, world_rank=0):
+        """Constructor
+
+        Args:
+            metrics_logger (common.metrics.MetricsLogger): class to log metrics using MLFlow
+            mpi_comm (MPI.COMM_WORLD): communicator
+            world_size (int): mpi world size
+            world_rank (int): mpi world rank of this node
+        """
+        self.recording_thread = DistributedMetricCollectionThread(
+            metrics_logger, mpi_comm, world_size=world_size, world_rank=world_rank
+        )
+        self.recording_thread.start()
+        self.logger = logging.getLogger(__name__)
+
+    def finalize(self):
+        """Asks the internal thread to finalize."""
+        # do one last report
+        self.recording_thread.aggregate_and_report_loop()
+
+        # set the kill flag and join the thread
+        self.recording_thread.killed = True
+        self.recording_thread.join()
+
+    def callback(self, env: lightgbm.callback.CallbackEnv) -> None:
+        """Callback method to collect metrics produced by LightGBM.
+
+        See https://lightgbm.readthedocs.io/en/latest/_modules/lightgbm/callback.html
+        """
+        # queue the metrics for the collection thread to send/record asynchronously
+        self.recording_thread.send_distributed_metric(env)
+        self.logger.info("End of callback")
+
+
+class DistributedMetricCollectionThread(threading.Thread):
+    """ This class handles MPI communication of LightGBM callback metrics.
+
+    NOTE: We needed to put this in a thread because having callback()
+    do the recv/send directly was interfering with LightGBM's own MPI communication.
+ """ + COMM_TAG_METRIC = 209834 # "random tag" + + def __init__(self, metrics_logger, mpi_comm, world_size=1, world_rank=0): + """Constructor + + Args: + metrics_logger (common.metrics.MetricsLogger): class to log metrics using MLFlow + mpi_comm (MPI.COMM_WORLD): communicator + world_size (int): mpi world size + world_rank (int): mpi world rank of this node + """ + threading.Thread.__init__(self) + self.killed = False # flag, set to True to kill from the inside + + self.logger = logging.getLogger(__name__) + self.metrics_logger = metrics_logger + + # internal sync storage + self.distributed_metrics = {} + self.record_lock = threading.Lock() + self.send_queue = [] + self.send_lock = threading.Lock() + + # MPI communication + self.mpi_comm = mpi_comm + self.world_size = world_size + self.world_rank = world_rank + + + ##################### + ### RUN FUNCTIONS ### + ##################### + + def run_head(self): + """Run function for node 0 only""" + while not(self.killed): + time.sleep(1) + + # collect everything from other nodes into internal record + for i in range(1, self.world_size): + self.logger.info(f"Probing metric from node {i}") + + try: + if self.mpi_comm.iprobe(source=i, tag=DistributedMetricCollectionThread.COMM_TAG_METRIC): # non-blocking + self.logger.info(f"Collecting metric from node {i}") + remote_node_metrics = self.mpi_comm.recv(source=i, tag=DistributedMetricCollectionThread.COMM_TAG_METRIC) # blocking + else: + self.logger.info(f"NO metric from node {i}") + continue + except BaseException: + self.logger.warning(f"Exception while listening to other nodes:\n{traceback.format_exc()}") + + self.record_distributed_metric(i, remote_node_metrics) + + # record node_0's own metrics in internal storage + with self.send_lock: + while self.send_queue: + entry = self.send_queue.pop() + self.record_distributed_metric(0, entry) + + # then aggregate whatever is in the internal record + self.aggregate_and_report_loop() + + def run_worker(self): + """Run function for all other nodes""" + while not(self.killed): + time.sleep(1) + # all other nodes send to node_0 + with self.send_lock: + while self.send_queue: + entry = self.send_queue.pop() + self.logger.info(f"Reporting metric back to node 0: {entry}") + self.mpi_comm.isend(entry, 0, tag=DistributedMetricCollectionThread.COMM_TAG_METRIC) # non-blocking + + def run(self): + """Main function of the thread""" + if self.world_rank == 0: + self.run_head() + else: + self.run_worker() + + ################### + ### SEND / RECV ### + ################### + + def send_distributed_metric(self, env: lightgbm.callback.CallbackEnv): + """Stores a metric report in the internal queue + to be sent by thread using MPI""" + + if self.world_rank == 0: # node_0 also record as mlflow + # loop on all the evaluation results tuples + for data_name, eval_name, result, _ in env.evaluation_result_list: + # log each as a distinct metric + self.metrics_logger.log_metric( + key=f"node_0/{data_name}.{eval_name}", + value=result, + step=env.iteration # provide iteration as step in mlflow + ) + + self.logger.info(f"Queueing metric to send to node_0: iteration={env.iteration}") + with self.send_lock: + # filtering out what we don't need to send + # in particular, we don't want to send the model! 
+            self.send_queue.append({
+                "iteration": env.iteration,
+                "evaluation_result_list": env.evaluation_result_list
+            })
+
+    def record_distributed_metric(self, node, report):
+        """Records a metric report internally (on node 0)."""
+        self.logger.info(f"Recorded metric from node {node}: {report}")
+        with self.record_lock:
+            iteration = report['iteration']
+            if iteration not in self.distributed_metrics:
+                self.distributed_metrics[iteration] = {}
+            self.distributed_metrics[iteration][node] = report
+
+    ##################
+    ### PROCESSING ###
+    ##################
+
+    def aggregate_and_report_task(self, key: str, iteration: int, eval_name: str, results: List[float]):
+        """Aggregates one metric across nodes and logs its mean/min/max."""
+        # TODO: devise an aggregation method per eval_name
+        self.metrics_logger.log_metric(
+            key=key,
+            value=np.mean(results),
+            step=iteration  # provide iteration as step in mlflow
+        )
+        self.metrics_logger.log_metric(
+            key=key + "_min",
+            value=np.min(results),
+            step=iteration
+        )
+        self.metrics_logger.log_metric(
+            key=key + "_max",
+            value=np.max(results),
+            step=iteration
+        )
+
+    def aggregate_and_report_loop(self):
+        """Aggregates and reports every iteration for which all nodes have reported."""
+        aggregation_tasks = {}
+
+        with self.record_lock:
+            for iteration in list(self.distributed_metrics.keys()):
+                # only report once all nodes have sent this iteration
+                if len(self.distributed_metrics[iteration]) < self.world_size:
+                    continue
+
+                # loop on all the evaluation results tuples
+                for node_id, node_metrics in self.distributed_metrics[iteration].items():
+                    for data_name, eval_name, result, _ in node_metrics['evaluation_result_list']:
+                        # key by (iteration, metric) so values from distinct
+                        # iterations in the backlog are not merged together
+                        task_key = (iteration, f"{data_name}.{eval_name}")
+                        if task_key not in aggregation_tasks:
+                            # record name of metric for aggregation method
+                            aggregation_tasks[task_key] = (eval_name, [])
+
+                        # add the value to the list
+                        aggregation_tasks[task_key][1].append(result)
+
+                # once done, remove the data from the "queue"
+                del self.distributed_metrics[iteration]
+
+        for (iteration, key), (eval_name, results) in aggregation_tasks.items():
+            self.aggregate_and_report_task(key, iteration, eval_name, results)
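
To make the aggregation concrete, here is a minimal, self-contained sketch of what aggregate_and_report_loop computes once all nodes have reported a given iteration. The StubMetricsLogger and the sample payloads are hypothetical stand-ins for common.metrics.MetricsLogger and for the {"iteration", "evaluation_result_list"} dicts queued by send_distributed_metric:

```python
import numpy as np

class StubMetricsLogger:
    """Hypothetical stand-in for common.metrics.MetricsLogger."""
    def log_metric(self, key, value, step=None):
        print(f"log_metric(key={key}, value={value:.4f}, step={step})")

# what node 0 would have recorded for iteration 5, one report per node;
# each report mirrors the payload queued by send_distributed_metric
distributed_metrics = {
    5: {
        0: {"iteration": 5, "evaluation_result_list": [("valid_0", "rmse", 0.42, False)]},
        1: {"iteration": 5, "evaluation_result_list": [("valid_0", "rmse", 0.44, False)]},
    }
}

metrics_logger = StubMetricsLogger()
for iteration, node_reports in distributed_metrics.items():
    # group values by "{data_name}.{eval_name}" across nodes
    grouped = {}
    for node_id, report in node_reports.items():
        for data_name, eval_name, result, _ in report["evaluation_result_list"]:
            grouped.setdefault(f"{data_name}.{eval_name}", []).append(result)
    # log mean/min/max, as aggregate_and_report_task does
    for key, values in grouped.items():
        metrics_logger.log_metric(key, np.mean(values), step=iteration)
        metrics_logger.log_metric(key + "_min", np.min(values), step=iteration)
        metrics_logger.log_metric(key + "_max", np.max(values), step=iteration)
```

Reporting the mean alongside min/max gives a quick read on metric skew across workers without shipping every per-node series to MLFlow.
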
diff --git a/src/common/metrics.py b/src/common/metrics.py
index 909761f6..2d88cbba 100644
--- a/src/common/metrics.py
+++ b/src/common/metrics.py
@@ -76,7 +76,7 @@ def log_metric(self, key, value, step=None):
         key = self._remove_non_allowed_chars(key)
 
-        self._logger.debug(f"mlflow[session={self._session_name}].log_metric({key},{value})")
+        self._logger.debug(f"mlflow[session={self._session_name}].log_metric({key},{value},step={step})")
 
         # NOTE: there's a limit to the name of a metric
         if len(key) > 50:
             key = key[:50]
diff --git a/src/scripts/training/lightgbm_python/train.py b/src/scripts/training/lightgbm_python/train.py
index d9d3f84c..e6f8242a 100644
--- a/src/scripts/training/lightgbm_python/train.py
+++ b/src/scripts/training/lightgbm_python/train.py
@@ -14,6 +14,9 @@
 import lightgbm
 from collections import namedtuple
+
+import mpi4py
+
 
 # Add the right path to PYTHONPATH
 # so that you can import from common.*
 COMMON_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
@@ -25,7 +28,7 @@
 # useful imports from common
 from common.components import RunnableScript
 from common.io import get_all_files
-from common.lightgbm_utils import LightGBMCallbackHandler
+from common.lightgbm_utils import LightGBMDistributedCallbackHandler
 from common.distributed import MultiNodeScript
 
 class LightGBMPythonMpiTrainingScript(MultiNodeScript):
@@ -182,10 +185,12 @@ def run(self, args, logger, metrics_logger, unknown_args):
         # figure out the lgbm params from cli args + mpi config
         lgbm_params = self.load_lgbm_params_from_cli(args, mpi_config)
 
-        # create a handler for the metrics callbacks
-        callbacks_handler = LightGBMCallbackHandler(
+        # create a handler for the distributed metrics callbacks
+        callbacks_handler = LightGBMDistributedCallbackHandler(
             metrics_logger=metrics_logger,
-            metrics_prefix=f"node_{mpi_config.world_rank}/"
+            mpi_comm=self.mpi_config.mpi_comm,
+            world_rank=self.mpi_config.world_rank,
+            world_size=self.mpi_config.world_size
         )
 
         # make sure the output argument exists
@@ -242,6 +247,9 @@ def run(self, args, logger, metrics_logger, unknown_args):
         logger.info(f"Writing model in {args.export_model}")
         booster.save_model(args.export_model)
 
+        # finalize all remaining metrics
+        callbacks_handler.finalize()
+
 
 def get_arg_parser(parser=None):
     """ To ensure compatibility with shrike unit tests """
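
For context, here is a minimal usage sketch of how the distributed handler could be wired into a lightgbm.train call. The stub logger and synthetic dataset are illustrative assumptions; in train.py the metrics_logger, parameters, and communicator come from the script framework (MultiNodeScript / self.mpi_config). Launch under mpirun to exercise the multi-node path:

```python
import numpy as np
import lightgbm
from mpi4py import MPI

from common.lightgbm_utils import LightGBMDistributedCallbackHandler

class StubMetricsLogger:
    """Hypothetical stand-in for common.metrics.MetricsLogger."""
    def log_metric(self, key, value, step=None):
        print(f"[step {step}] {key} = {value}")

comm = MPI.COMM_WORLD

handler = LightGBMDistributedCallbackHandler(
    metrics_logger=StubMetricsLogger(),
    mpi_comm=comm,
    world_size=comm.Get_size(),
    world_rank=comm.Get_rank(),
)

# tiny synthetic regression problem, just to have something to train on
rng = np.random.default_rng(42)
X, y = rng.normal(size=(500, 10)), rng.normal(size=500)
train_set = lightgbm.Dataset(X[:400], label=y[:400])
valid_set = lightgbm.Dataset(X[400:], label=y[400:], reference=train_set)

booster = lightgbm.train(
    {"objective": "regression", "metric": "rmse"},
    train_set,
    num_boost_round=10,
    valid_sets=[valid_set],
    callbacks=[handler.callback],  # LightGBM invokes this at every iteration
)

handler.finalize()  # one last aggregation pass, then stop the collection thread
```

Calling finalize() after training matters: it runs a final aggregation pass and joins the collection thread, so metrics from the last iterations are not lost when the process exits.
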