Skip to content
This repository was archived by the owner on Apr 8, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -140,4 +140,5 @@ outputs/
# ignore aml references
conf/aml/
conf/compute/
conf/experiments/prod/
src/scripts/inferencing/custom_win_cli/static_binaries/
45 changes: 31 additions & 14 deletions src/common/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,21 @@
"""
LightGBM/Python training script
"""
import os
import logging
import traceback
from .components import RunnableScript
from dataclasses import dataclass
from omegaconf import MISSING

from .perf import PerformanceMetricsCollector, PerfReportPlotter

@dataclass
class mpi_config_class:
    """Describes the MPI layout of the current process.

    The defaults describe a single, non-distributed process, so an instance
    built with no arguments is a valid "no MPI" configuration.
    """
    # total number of processes in the job
    world_size: int = 1
    # index of this process within the job
    world_rank: int = 0
    # True when more than one process takes part in the job
    mpi_available: bool = False
    # True for the process with rank 0
    main_node: bool = True


class MPIHandler():
"""Handling MPI initialization in a separate class
Expand All @@ -42,22 +43,39 @@ def _mpi_import(cls):
def initialize(self):
    """Resolve the MPI configuration and store it in ``self._mpi_config``.

    Two modes, selected by ``self._mpi_init_mode``:

    * ``None``: MPI is not initialized. World size and rank are read from
      the ``OMPI_COMM_WORLD_SIZE`` / ``OMPI_COMM_WORLD_RANK`` environment
      variables (defaulting to "1" and "0") and ``self.comm`` is set to None.
    * anything else: ``Init_thread(required=self._mpi_init_mode)`` is called
      on the imported MPI module, ``self.comm`` is set to its ``COMM_WORLD``
      and the configuration comes from ``self.detect_mpi_config()``.

    The MPI module is imported in both modes and kept in
    ``self._mpi_module``.
    """
    # doing our own initialization of MPI to have fine-grain control
    self._mpi_module = self._mpi_import()

    if self._mpi_init_mode is None:
        # use simple env vars instead
        self.logger.info("no MPI init, using environment variables instead")
        world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", "1"))
        world_rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", "0"))

        self._mpi_config = mpi_config_class(
            world_size, # world_size
            world_rank, # world_rank
            (world_size > 1), # mpi_available
            (world_rank == 0), # main_node
        )
        # no communicator exists in this mode
        self.comm = None
    else:
        # use mpi to detect mpi config
        self.logger.info(f"Running MPI.Init_thread(required={self._mpi_init_mode})")
        try:
            self._mpi_module.Init_thread(required=self._mpi_init_mode)
        except self._mpi_module.Exception:
            # best effort: log the failure and still try to detect the config
            self.logger.warning(f"Exception occured during MPI initialization:\n{traceback.format_exc()}")

        self.comm = self._mpi_module.COMM_WORLD
        self._mpi_config = self.detect_mpi_config()

    logging.getLogger().info(f"MPI detection results: {self._mpi_config}")

def finalize(self):
    """Call ``Finalize()`` on the MPI module if it is initialized and not yet finalized.

    In any other state nothing is finalized and a warning reporting both
    ``Is_initialized()`` and ``Is_finalized()`` is logged instead.
    """
    if self._mpi_module.Is_initialized() and not self._mpi_module.Is_finalized():
        self.logger.info("MPI was initialized, calling MPI.finalize()")
        self._mpi_module.Finalize()
    else:
        self.logger.warning(f"MPIHandler.finalize() was called, but MPI.Is_initialized={self._mpi_module.Is_initialized()} and MPI.Is_finalized={self._mpi_module.Is_finalized()}")

def mpi_config(self):
    """Return the configuration object currently stored in ``self._mpi_config``."""
    current_config = self._mpi_config
    return current_config
Expand All @@ -77,7 +95,6 @@ def detect_mpi_config(self):
(self.comm.Get_size() > 1), # mpi_available
(self.comm.Get_rank() == 0), # main_node
)
logging.getLogger().info(f"MPI detection results: {mpi_config}")
except:
mpi_config = mpi_config_class(
1, # world_size
Expand Down
2 changes: 1 addition & 1 deletion src/scripts/training/lightgbm_python/default.dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM mcr.microsoft.com/azureml/openmpi3.1.2-ubuntu18.04:20210615.v1
FROM mcr.microsoft.com/azureml/openmpi4.1.0-ubuntu20.04:latest
LABEL lightgbmbenchmark.linux.cpu.mpi.pip.version="3.3.0/20211111.1"

# Those arguments will NOT be used by AzureML
Expand Down
29 changes: 27 additions & 2 deletions tests/common/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_multi_node_script_failure(mpi_handler_mock):
)


def test_mpi_handler():
def test_mpi_handler_mpi_init():
"""Tests the MPIHandler class"""
# create MPI module mock
mpi_module_mock = Mock()
Expand All @@ -113,7 +113,7 @@ def test_mpi_handler():
with patch.object(MPIHandler, "_mpi_import") as mpi_import_mock:
mpi_import_mock.return_value = mpi_module_mock

mpi_handler = MPIHandler()
mpi_handler = MPIHandler(mpi_init_mode=3) # MPI.THREAD_MULTIPLE
mpi_handler.initialize()
mpi_config = mpi_handler.mpi_config()
mpi_handler.finalize()
Expand All @@ -123,3 +123,28 @@ def test_mpi_handler():
assert mpi_config.world_size == 10
assert mpi_config.mpi_available == True
assert mpi_config.main_node == False

def test_mpi_handler_no_mpi_init():
    """Tests the MPIHandler class when created with mpi_init_mode=None.

    The mocked communicator reports size 10 / rank 3 while the patched
    environment says size 6 / rank 2; the assertions expect the values
    taken from the environment.
    """
    # MPI module stand-in whose communicator disagrees with the environment
    fake_comm = Mock()
    fake_comm.Get_size.return_value = 10 # different value just to make the point
    fake_comm.Get_rank.return_value = 3 # different value just to make the point
    fake_mpi = Mock()
    fake_mpi.COMM_WORLD = fake_comm
    fake_mpi.THREAD_MULTIPLE = 3

    ompi_env = {"OMPI_COMM_WORLD_SIZE": "6", "OMPI_COMM_WORLD_RANK": "2"}

    # patch _mpi_import to return our MPI module stand-in, and the env vars
    with patch.object(MPIHandler, "_mpi_import") as import_patch, patch.dict(os.environ, ompi_env):
        import_patch.return_value = fake_mpi

        handler = MPIHandler(mpi_init_mode=None)
        handler.initialize()
        detected = handler.mpi_config()
        handler.finalize()

    # the config must reflect the environment, not the communicator
    assert detected.world_rank == 2
    assert detected.world_size == 6
    assert detected.mpi_available == True
    assert detected.main_node == False