diff --git a/easybuild/tools/toolchain/mpi.py b/easybuild/tools/toolchain/mpi.py index 2f501d2fec..12bd43bbea 100644 --- a/easybuild/tools/toolchain/mpi.py +++ b/easybuild/tools/toolchain/mpi.py @@ -48,7 +48,7 @@ _log = fancylogger.getLogger('tools.toolchain.mpi', fname=False) -def get_mpi_cmd_template(mpi_family, params, mpi_version=None): +def get_mpi_cmd_template(mpi_family, params, mpi_version=None, oversubscribe=False): """ Return template for MPI command, for specified MPI family. @@ -123,6 +123,38 @@ def get_mpi_cmd_template(mpi_family, params, mpi_version=None): else: raise EasyBuildError("Don't know which template MPI command to use for MPI family '%s'", mpi_family) + if oversubscribe: + osub_cmd = '' + if mpi_family in [toolchain.OPENMPI]: + if mpi_version is None: + raise EasyBuildError("OpenMPI version unknown, can't determine how to handle oversubscription!") + if LooseVersion(mpi_version) < '5': + varname = 'OMPI_MCA_rmaps_base_oversubscribe' + varvalue = os.getenv(varname) + if varvalue and varvalue != '1': + _log.warning("Overwriting existing %s=%s with %s=1", varname, varvalue, varname) + osub_cmd = f'{varname}=1' + else: + varname = 'PRTE_MCA_rmaps_default_mapping_policy' + varvalue = os.getenv(varname, '') + + # This logic should account for adding a `:oversubscribe` flag to the mapping policy if not + # already present. + # See https://docs.open-mpi.org/en/main/man-openmpi/man1/mpirun.1.html#the-map-by-option + flags = varvalue.lower().split(':') + if 'oversubscribe' not in flags: + varvalue = f'{varvalue}:oversubscribe' + osub_cmd = f'{varname}={varvalue}' + elif mpi_family in [toolchain.INTELMPI]: + _log.info("INTELMPI always oversubscribe by default, nothing to do...") + elif mpi_family in [toolchain.MVAPICH2, toolchain.MPICH, toolchain.MPICH2]: + _log.info("MPICH always oversubscribe by default, nothing to do...") + else: + raise EasyBuildError("Oversubscribe not supported for MPI family '%s'", mpi_family) + + mpi_cmd_template = f'%(oversubscribe)s {mpi_cmd_template}' + params.update({'oversubscribe': osub_cmd}) + missing = [] for key in sorted(params.keys()): tmpl = '%(' + key + ')s' @@ -270,7 +302,7 @@ def mpi_cmd_prefix(self, nr_ranks=1): return result - def mpi_cmd_for(self, cmd, nr_ranks): + def mpi_cmd_for(self, cmd, nr_ranks, oversubscribe=False): """Construct an MPI command for the given command and number of ranks.""" # parameter values for mpirun command @@ -281,20 +313,23 @@ def mpi_cmd_for(self, cmd, nr_ranks): mpi_family = self.mpi_family() - mpi_version = None + # this fails when it's done too early (before modules for toolchain/dependencies are loaded), + # but it's safe to ignore this + if self.MPI_MODULE_NAME is None: + mpi_version = None + else: + mpi_version = self.get_software_version(self.MPI_MODULE_NAME, required=False)[0] if mpi_family == toolchain.INTELMPI: - # for Intel MPI, try to determine impi version - # this fails when it's done too early (before modules for toolchain/dependencies are loaded), - # but it's safe to ignore this - mpi_version = self.get_software_version(self.MPI_MODULE_NAME, required=False)[0] if not mpi_version: self.log.debug("Ignoring error when trying to determine %s version", self.MPI_MODULE_NAME) # impi version is required to determine correct MPI command template, # so we have to return early if we couldn't determine the impi version... return None - mpi_cmd_template, params = get_mpi_cmd_template(mpi_family, params, mpi_version=mpi_version) + mpi_cmd_template, params = get_mpi_cmd_template( + mpi_family, params, mpi_version=mpi_version, oversubscribe=oversubscribe + ) self.log.info("Using MPI command template '%s' (params: %s)", mpi_cmd_template, params) try: diff --git a/test/framework/toolchain.py b/test/framework/toolchain.py index e306e9df30..e3b6b29f68 100644 --- a/test/framework/toolchain.py +++ b/test/framework/toolchain.py @@ -1884,6 +1884,41 @@ def test_get_mpi_cmd_template(self): self.assertTrue(regex.match(nodesfile), "'%s' should match pattern '%s'" % (nodesfile, regex.pattern)) self.assertExists(nodesfile.split(' ')[1]) + # Test oversubscription + # With OpenMPI < 5 + mpi_cmd_tmpl, params = get_mpi_cmd_template(toolchain.OPENMPI, {}, mpi_version='4.9', oversubscribe=True) + self.assertTrue('%(oversubscribe)s' in mpi_cmd_tmpl) + self.assertEqual(params['oversubscribe'], 'OMPI_MCA_rmaps_base_oversubscribe=1') + + # With OpenMPI >= 5 + mpi_cmd_tmpl, params = get_mpi_cmd_template(toolchain.OPENMPI, {}, mpi_version='5', oversubscribe=True) + self.assertTrue('%(oversubscribe)s' in mpi_cmd_tmpl) + self.assertEqual(params['oversubscribe'], 'PRTE_MCA_rmaps_default_mapping_policy=:oversubscribe') + + # With OpenMPI >= 5 and pre-existing PRTE_MCA_rmaps_default_mapping_policy (override ppr) + os.environ['PRTE_MCA_rmaps_default_mapping_policy'] = 'ppr:4:package' + mpi_cmd_tmpl, params = get_mpi_cmd_template(toolchain.OPENMPI, {}, mpi_version='5', oversubscribe=True) + self.assertTrue('%(oversubscribe)s' in mpi_cmd_tmpl) + self.assertEqual(params['oversubscribe'], 'PRTE_MCA_rmaps_default_mapping_policy=ppr:4:package:oversubscribe') + + # With OpenMPI >= 5 and pre-existing PRTE_MCA_rmaps_default_mapping_policy (add to unit) + os.environ['PRTE_MCA_rmaps_default_mapping_policy'] = 'package' + mpi_cmd_tmpl, params = get_mpi_cmd_template(toolchain.OPENMPI, {}, mpi_version='5', oversubscribe=True) + self.assertTrue('%(oversubscribe)s' in mpi_cmd_tmpl) + self.assertEqual(params['oversubscribe'], 'PRTE_MCA_rmaps_default_mapping_policy=package:oversubscribe') + + # With OpenMPI >= 5 and pre-existing PRTE_MCA_rmaps_default_mapping_policy (add to unit) + os.environ['PRTE_MCA_rmaps_default_mapping_policy'] = 'core:oversubscribe' + mpi_cmd_tmpl, params = get_mpi_cmd_template(toolchain.OPENMPI, {}, mpi_version='5', oversubscribe=True) + self.assertTrue('%(oversubscribe)s' in mpi_cmd_tmpl) + self.assertEqual(params['oversubscribe'], 'PRTE_MCA_rmaps_default_mapping_policy=core:oversubscribe') + + # With IntelMPI and MPICH + for mpi_fam in [toolchain.INTELMPI, toolchain.MPICH, toolchain.MPICH2, toolchain.MVAPICH2]: + mpi_cmd_tmpl, params = get_mpi_cmd_template(mpi_fam, input_params, mpi_version='1.0', oversubscribe=True) + self.assertTrue('%(oversubscribe)s' in mpi_cmd_tmpl) + self.assertEqual(params['oversubscribe'], '') + def test_prepare_deps(self): """Test preparing for a toolchain when dependencies are involved.""" tc = self.get_toolchain('GCC', version='6.4.0-2.28')