diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 26672bafd..8bd577ee8 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -382,7 +382,22 @@ def _load_core_library(): build_rocm_version = list(filter(lambda f: f.startswith("ROCM_VERSION:"), build_info)) if build_rocm_version: build_rocm_version = build_rocm_version[0].split(":")[1].strip().split('.')[:2] - assert (rocm_version == build_rocm_version), f"ROCm {'.'.join(rocm_version)} is detected but the library is built for {'.'.join(build_rocm_version)}" + # Strict by default. Set NVTE_ALLOW_ROCM_MISMATCH=1 to bypass. + allow_rocm_mismatch = os.getenv("NVTE_ALLOW_ROCM_MISMATCH", "0").strip().lower() in ("1", "true", "yes") + mismatch_detected = rocm_version != build_rocm_version + if allow_rocm_mismatch and mismatch_detected: + _logger.warning( + "NVTE_ALLOW_ROCM_MISMATCH is enabled. Proceeding despite ROCm runtime/build " + "version mismatch (runtime=%s, build=%s).", + ".".join(rocm_version), + ".".join(build_rocm_version), + ) + elif mismatch_detected: + raise RuntimeError( + f"ROCm {'.'.join(rocm_version)} is detected but the library is built for " + f"{'.'.join(build_rocm_version)}. Set NVTE_ALLOW_ROCM_MISMATCH=1 to bypass " + "this check at your own risk." + ) except FileNotFoundError: pass