Skip to content
This repository was archived by the owner on Apr 8, 2024. It is now read-only.
Draft
210 changes: 210 additions & 0 deletions src/common/lightgbm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
This classes provide help to integrate lightgbm
"""
import lightgbm
import numpy as np
import logging
from typing import List
import threading
import time
import traceback

class LightGBMCallbackHandler():
""" This class handles LightGBM callbacks for recording metrics. """
Expand Down Expand Up @@ -47,3 +52,208 @@ def callback(self, env: lightgbm.callback.CallbackEnv) -> None:
value=result,
step=env.iteration # provide iteration as step in mlflow
)


class LightGBMDistributedCallbackHandler():
    """ This class handles LightGBM callbacks for recording metrics. """

    def __init__(self, metrics_logger, mpi_comm, world_size=1, world_rank=0):
        """Constructor

        Args:
            metrics_logger (common.metrics.MetricsLogger): class to log metrics using MLFlow
            mpi_comm (MPI.COMM_WORLD): communicator
            world_size (int): mpi world size
            world_rank (int): mpi world rank of this node
        """
        # all metric collection/reporting is delegated to a background thread,
        # started right away and shut down in finalize()
        self.recording_thread = DistributedMetricCollectionThread(
            metrics_logger,
            mpi_comm,
            world_size=world_size,
            world_rank=world_rank,
        )
        self.recording_thread.start()
        self.logger = logging.getLogger(__name__)

    def finalize(self):
        """Asks internal thread to finalize"""
        thread = self.recording_thread

        # do one last report
        thread.aggregate_and_report_loop()

        # set status to kill and join
        thread.killed = True
        thread.join()

    def callback(self, env: lightgbm.callback.CallbackEnv) -> None:
        """Callback method to collect metrics produced by LightGBM.

        See https://lightgbm.readthedocs.io/en/latest/_modules/lightgbm/callback.html
        """
        # hand the metrics over to the background thread (queue + MPI send)
        self.recording_thread.send_distributed_metric(env)
        self.logger.info("End of callback")


class DistributedMetricCollectionThread(threading.Thread):
    """ This class handles MPI communication of LightGBM callback metrics.
    NOTE: We needed to put this in a thread because having callback()
    do the recv/send directly was interacting with LightGBM's own MPI communication somehow.
    """
    COMM_TAG_METRIC = 209834  # "random tag" to keep metric traffic separate from LightGBM's own MPI messages

    def __init__(self, metrics_logger, mpi_comm, world_size=1, world_rank=0):
        """Constructor

        Args:
            metrics_logger (common.metrics.MetricsLogger): class to log metrics using MLFlow
            mpi_comm (MPI.COMM_WORLD): communicator
            world_size (int): mpi world size
            world_rank (int): mpi world rank of this node
        """
        threading.Thread.__init__(self)
        self.killed = False  # flag, set to True to kill from the inside

        self.logger = logging.getLogger(__name__)
        self.metrics_logger = metrics_logger

        # internal sync storage
        # distributed_metrics maps iteration -> {node_rank: report dict}
        self.distributed_metrics = {}
        self.record_lock = threading.Lock()
        # send_queue is filled by the lightgbm callback thread, drained by this thread
        self.send_queue = []
        self.send_lock = threading.Lock()

        # MPI communication
        self.mpi_comm = mpi_comm
        self.world_size = world_size
        self.world_rank = world_rank


    #####################
    ### RUN FUNCTIONS ###
    #####################

    def run_head(self):
        """Run function for node 0 only: collects reports from all nodes,
        aggregates and logs each iteration once every node has reported it."""
        while not self.killed:
            time.sleep(1)

            # collect everything from other nodes into internal record
            for i in range(1, self.world_size):
                self.logger.info(f"Probing metric from node {i}")

                try:
                    if self.mpi_comm.iprobe(source=i, tag=DistributedMetricCollectionThread.COMM_TAG_METRIC): # non-blocking
                        self.logger.info(f"Collecting metric from node {i}")
                        remote_node_metrics = self.mpi_comm.recv(source=i, tag=DistributedMetricCollectionThread.COMM_TAG_METRIC) # blocking
                    else:
                        self.logger.info(f"NO metric from node {i}")
                        continue
                except BaseException:
                    self.logger.warning(f"Exception while listening to other nodes:\n{traceback.format_exc()}")
                    # FIX: skip this node on failure, otherwise remote_node_metrics
                    # below would be unbound and raise NameError
                    continue

                self.record_distributed_metric(i, remote_node_metrics)

            # record node_0's own metrics in internal storage
            with self.send_lock:
                while self.send_queue:
                    # FIX: drain FIFO (pop(0)) so iterations are recorded in
                    # chronological order instead of LIFO
                    entry = self.send_queue.pop(0)
                    self.record_distributed_metric(0, entry)

            # then aggregate whatever is in the internal record
            self.aggregate_and_report_loop()

    def run_worker(self):
        """Run function for all other nodes: forwards queued reports to node 0."""
        while not self.killed:
            time.sleep(1)
            # all other nodes send to node_0
            with self.send_lock:
                while self.send_queue:
                    # FIX: drain FIFO (pop(0)) so node 0 receives iterations in order
                    entry = self.send_queue.pop(0)
                    self.logger.info(f"Reporting metric back to node 0: {entry}")
                    self.mpi_comm.isend(entry, 0, tag=DistributedMetricCollectionThread.COMM_TAG_METRIC) # non-blocking

    def run(self):
        """Main function of the thread"""
        if self.world_rank == 0:
            self.run_head()
        else:
            self.run_worker()

    ###################
    ### SEND / RECV ###
    ###################

    def send_distributed_metric(self, env: "lightgbm.callback.CallbackEnv"):
        """Stores a metric report in the internal queue
        to be sent by thread using MPI.

        Args:
            env: callback environment provided by LightGBM, exposing
                `iteration` and `evaluation_result_list`
        """

        if self.world_rank == 0: # node_0 also record as mlflow
            # loop on all the evaluation results tuples
            for data_name, eval_name, result, _ in env.evaluation_result_list:
                # log each as a distinct metric
                self.metrics_logger.log_metric(
                    key=f"node_0/{data_name}.{eval_name}",
                    value=result,
                    step=env.iteration # provide iteration as step in mlflow
                )

        self.logger.info(f"Queueing metric to send to node_0: iteration={env.iteration}")
        with self.send_lock:
            # filtering out what we don't need to send
            # in particular, we don't want to send the model!
            self.send_queue.append({
                "iteration":env.iteration,
                "evaluation_result_list":env.evaluation_result_list
            })

    def record_distributed_metric(self, node, report):
        """Records a metric report internally to node 0.

        Args:
            node (int): rank of the node the report came from
            report (dict): {"iteration": int, "evaluation_result_list": list}
        """
        self.logger.info(f"Recorded metric from node {node}: {report}")
        with self.record_lock:
            iteration = report['iteration']
            if iteration not in self.distributed_metrics:
                self.distributed_metrics[iteration] = {}
            self.distributed_metrics[iteration][node] = report


    ##################
    ### PROCESSING ###
    ##################

    def aggregate_and_report_task(self, key: str, iteration: int, eval_name: str, results: List[float]):
        """Logs mean/min/max of one metric across nodes for one iteration."""
        # TODO: devise aggregation method per eval_name
        self.metrics_logger.log_metric(
            key=key,
            value=np.mean(results),
            step=iteration # provide iteration as step in mlflow
        )
        self.metrics_logger.log_metric(
            key=key+"_min",
            value=np.min(results),
            step=iteration # provide iteration as step in mlflow
        )
        self.metrics_logger.log_metric(
            key=key+"_max",
            value=np.max(results),
            step=iteration # provide iteration as step in mlflow
        )

    def aggregate_and_report_loop(self):
        """Aggregates and logs every iteration for which all nodes have reported,
        then removes it from the internal record."""
        aggregation_tasks = {}

        with self.record_lock:
            for iteration in list(self.distributed_metrics.keys()):
                if len(self.distributed_metrics[iteration]) < self.world_size:
                    # not every node has reported this iteration yet, keep waiting
                    continue

                # loop on all the evaluation results tuples
                for node_id, node_metrics in self.distributed_metrics[iteration].items():
                    for data_name, eval_name, result, _ in node_metrics['evaluation_result_list']:
                        key = f"{data_name}.{eval_name}"
                        # FIX: key tasks by (iteration, key) — keying by metric name
                        # alone merged results from distinct iterations completed in
                        # the same pass, reporting them all under the first iteration
                        task_id = (iteration, key)
                        if task_id not in aggregation_tasks:
                            # record name of metric for aggregation method
                            aggregation_tasks[task_id] = (key, iteration, eval_name, [])

                        # add value in the list
                        aggregation_tasks[task_id][3].append(result)

                # once done, remove the data from the "queue"
                del self.distributed_metrics[iteration]

        for key, iteration, eval_name, results in aggregation_tasks.values():
            self.aggregate_and_report_task(key, iteration, eval_name, results)
2 changes: 1 addition & 1 deletion src/common/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def log_metric(self, key, value, step=None):

key = self._remove_non_allowed_chars(key)

self._logger.debug(f"mlflow[session={self._session_name}].log_metric({key},{value})")
self._logger.debug(f"mlflow[session={self._session_name}].log_metric({key},{value},step={step})")
# NOTE: there's a limit to the name of a metric
if len(key) > 50:
key = key[:50]
Expand Down
16 changes: 12 additions & 4 deletions src/scripts/training/lightgbm_python/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
import lightgbm
from collections import namedtuple

import lightgbm
import mpi4py

# Add the right path to PYTHONPATH
# so that you can import from common.*
COMMON_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
Expand All @@ -25,7 +28,7 @@
# useful imports from common
from common.components import RunnableScript
from common.io import get_all_files
from common.lightgbm_utils import LightGBMCallbackHandler
from common.lightgbm_utils import LightGBMDistributedCallbackHandler
from common.distributed import MultiNodeScript

class LightGBMPythonMpiTrainingScript(MultiNodeScript):
Expand Down Expand Up @@ -182,10 +185,12 @@ def run(self, args, logger, metrics_logger, unknown_args):
# figure out the lgbm params from cli args + mpi config
lgbm_params = self.load_lgbm_params_from_cli(args, mpi_config)

# create a handler for the metrics callbacks
callbacks_handler = LightGBMCallbackHandler(
# create a handler
callbacks_handler = LightGBMDistributedCallbackHandler(
metrics_logger=metrics_logger,
metrics_prefix=f"node_{mpi_config.world_rank}/"
mpi_comm = self.mpi_config.mpi_comm,
world_rank=self.mpi_config.world_rank,
world_size=self.mpi_config.world_size
)

# make sure the output argument exists
Expand Down Expand Up @@ -242,6 +247,9 @@ def run(self, args, logger, metrics_logger, unknown_args):
logger.info(f"Writing model in {args.export_model}")
booster.save_model(args.export_model)

# finalize all remaining metrics
callbacks_handler.finalize()


def get_arg_parser(parser=None):
""" To ensure compatibility with shrike unit tests """
Expand Down