diff --git a/.gitignore b/.gitignore
index 28060d7e..794998ac 100644
--- a/.gitignore
+++ b/.gitignore
@@ -140,4 +140,5 @@ outputs/
 # ignore aml references
 conf/aml/
 conf/compute/
+conf/experiments/prod/
 src/scripts/inferencing/custom_win_cli/static_binaries/
diff --git a/src/common/distributed.py b/src/common/distributed.py
index 5cd6c6af..9b60e2f8 100644
--- a/src/common/distributed.py
+++ b/src/common/distributed.py
@@ -4,20 +4,21 @@
 """
 LightGBM/Python training script
 """
+import os
 import logging
 import traceback
 from .components import RunnableScript
 from dataclasses import dataclass
-from omegaconf import MISSING
 from .perf import PerformanceMetricsCollector, PerfReportPlotter
 
 
 @dataclass
 class mpi_config_class:
-    world_size: int = MISSING
-    world_rank: int = MISSING
-    mpi_available: bool = MISSING
-    main_node: bool = MISSING
+    world_size: int = 1
+    world_rank: int = 0
+    mpi_available: bool = False
+    main_node: bool = True
+
 
 class MPIHandler():
     """Handling MPI initialization in a separate class
@@ -42,22 +43,39 @@ def _mpi_import(cls):
     def initialize(self):
         # doing our own initialization of MPI to have fine-grain control
         self._mpi_module = self._mpi_import()
-        self.comm = self._mpi_module.COMM_WORLD
 
         if self._mpi_init_mode is None:
-            self._mpi_init_mode = self._mpi_module.THREAD_MULTIPLE
+            # use simple env vars instead
+            self.logger.info(f"no MPI init, using environment variables instead")
+            world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", "1"))
+            world_rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", "0"))
+
+            self._mpi_config = mpi_config_class(
+                world_size, # world_size
+                world_rank, # world_rank
+                (world_size > 1), # mpi_available
+                (world_rank == 0), # main_node
+            )
+            self.comm = None
+        else:
+            # use mpi to detect mpi config
+            self.logger.info(f"Running MPI.Init_thread(required={self._mpi_init_mode})")
+            try:
+                self._mpi_module.Init_thread(required=self._mpi_init_mode)
+            except self._mpi_module.Exception:
+                self.logger.warning(f"Exception occured during MPI initialization:\n{traceback.format_exc()}")
 
-        try:
-            self._mpi_module.Init_thread(required=self._mpi_init_mode)
-        except self._mpi_module.Exception:
-            self.logger.warning(f"Exception occured during MPI initialization:\n{traceback.format_exc()}")
+            self.comm = self._mpi_module.COMM_WORLD
+            self._mpi_config = self.detect_mpi_config()
 
-        self._mpi_config = self.detect_mpi_config()
+        logging.getLogger().info(f"MPI detection results: {self._mpi_config}")
 
     def finalize(self):
-        if self._mpi_module.Is_initialized():
+        if self._mpi_module.Is_initialized() and not self._mpi_module.Is_finalized():
             self.logger.info("MPI was initialized, calling MPI.finalize()")
             self._mpi_module.Finalize()
+        else:
+            self.logger.warning(f"MPIHandler.finalize() was called, but MPI.Is_initialized={self._mpi_module.Is_initialized()} and MPI.Is_finalized={self._mpi_module.Is_finalized()}")
 
     def mpi_config(self):
         return self._mpi_config
@@ -77,7 +95,6 @@ def detect_mpi_config(self):
                 (self.comm.Get_size() > 1), # mpi_available
                 (self.comm.Get_rank() == 0), # main_node
             )
-            logging.getLogger().info(f"MPI detection results: {mpi_config}")
         except:
             mpi_config = mpi_config_class(
                 1, # world_size
diff --git a/src/scripts/training/lightgbm_python/default.dockerfile b/src/scripts/training/lightgbm_python/default.dockerfile
index 080f2b56..9f0db16d 100644
--- a/src/scripts/training/lightgbm_python/default.dockerfile
+++ b/src/scripts/training/lightgbm_python/default.dockerfile
@@ -1,4 +1,4 @@
-FROM mcr.microsoft.com/azureml/openmpi3.1.2-ubuntu18.04:20210615.v1
+FROM mcr.microsoft.com/azureml/openmpi4.1.0-ubuntu20.04:latest
 LABEL lightgbmbenchmark.linux.cpu.mpi.pip.version="3.3.0/20211111.1"
 
 # Those arguments will NOT be used by AzureML
diff --git a/tests/common/test_distributed.py b/tests/common/test_distributed.py
index f9ba53f2..05423e6a 100644
--- a/tests/common/test_distributed.py
+++ b/tests/common/test_distributed.py
@@ -100,7 +100,7 @@ def test_multi_node_script_failure(mpi_handler_mock):
     )
 
 
-def test_mpi_handler():
+def test_mpi_handler_mpi_init():
     """Tests the MPIHandler class"""
     # create MPI module mock
     mpi_module_mock = Mock()
@@ -113,7 +113,7 @@ def test_mpi_handler():
     with patch.object(MPIHandler, "_mpi_import") as mpi_import_mock:
         mpi_import_mock.return_value = mpi_module_mock
 
-        mpi_handler = MPIHandler()
+        mpi_handler = MPIHandler(mpi_init_mode=3) # MPI.THREAD_MULTIPLE
         mpi_handler.initialize()
         mpi_config = mpi_handler.mpi_config()
         mpi_handler.finalize()
@@ -123,3 +123,28 @@ def test_mpi_handler():
     assert mpi_config.world_size == 10
     assert mpi_config.mpi_available == True
     assert mpi_config.main_node == False
+
+def test_mpi_handler_no_mpi_init():
+    """Tests the MPIHandler class"""
+    # create MPI module mock
+    mpi_module_mock = Mock()
+    mpi_module_mock.COMM_WORLD = Mock()
+    mpi_module_mock.COMM_WORLD.Get_size.return_value = 10 # different value just to make the point
+    mpi_module_mock.COMM_WORLD.Get_rank.return_value = 3 # different value just to make the point
+    mpi_module_mock.THREAD_MULTIPLE = 3
+
+    # patch _mpi_import to return our MPI module mock
+    with patch.object(MPIHandler, "_mpi_import") as mpi_import_mock:
+        with patch.dict(os.environ, {"OMPI_COMM_WORLD_SIZE": "6", "OMPI_COMM_WORLD_RANK": "2"}):
+            mpi_import_mock.return_value = mpi_module_mock
+
+            mpi_handler = MPIHandler(mpi_init_mode=None)
+            mpi_handler.initialize()
+            mpi_config = mpi_handler.mpi_config()
+            mpi_handler.finalize()
+
+            # test this random config
+            assert mpi_config.world_rank == 2
+            assert mpi_config.world_size == 6
+            assert mpi_config.mpi_available == True
+            assert mpi_config.main_node == False