From f17f47333a71ccbbaf756a2285b2c1042906bec9 Mon Sep 17 00:00:00 2001 From: jiagaoxiang Date: Sat, 28 Feb 2026 00:01:40 +0000 Subject: [PATCH 1/6] Enhance ROCm version check: allow mismatch with NVTE_ALLOW_ROCM_MISMATCH environment variable --- transformer_engine/common/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 26672bafd..e032ac72a 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -382,7 +382,10 @@ def _load_core_library(): build_rocm_version = list(filter(lambda f: f.startswith("ROCM_VERSION:"), build_info)) if build_rocm_version: build_rocm_version = build_rocm_version[0].split(":")[1].strip().split('.')[:2] - assert (rocm_version == build_rocm_version), f"ROCm {'.'.join(rocm_version)} is detected but the library is built for {'.'.join(build_rocm_version)}" + # Strict by default. Set NVTE_ALLOW_ROCM_MISMATCH=1 to bypass. + allow_rocm_mismatch = os.getenv("NVTE_ALLOW_ROCM_MISMATCH", "0").lower() in ("1", "true", "yes") + if not allow_rocm_mismatch: + assert (rocm_version == build_rocm_version), f"ROCm {'.'.join(rocm_version)} is detected but the library is built for {'.'.join(build_rocm_version)}" except FileNotFoundError: pass From e35e02c10620d2e3e480f02f1f61ebc9c1bf424d Mon Sep 17 00:00:00 2001 From: jiagaoxiang Date: Sat, 28 Feb 2026 01:58:55 +0000 Subject: [PATCH 2/6] Clarify ROCm mismatch bypass in assertion error. Mention NVTE_ALLOW_ROCM_MISMATCH=1 directly in the mismatch assertion message so users can self-serve when runtime/build ROCm versions differ. Made-with: Cursor --- transformer_engine/common/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index e032ac72a..894bbd49e 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -385,7 +385,11 @@ def _load_core_library(): # Strict by default. Set NVTE_ALLOW_ROCM_MISMATCH=1 to bypass. allow_rocm_mismatch = os.getenv("NVTE_ALLOW_ROCM_MISMATCH", "0").lower() in ("1", "true", "yes") if not allow_rocm_mismatch: - assert (rocm_version == build_rocm_version), f"ROCm {'.'.join(rocm_version)} is detected but the library is built for {'.'.join(build_rocm_version)}" + assert rocm_version == build_rocm_version, ( + f"ROCm {'.'.join(rocm_version)} is detected but the library is built for " + f"{'.'.join(build_rocm_version)}. Set NVTE_ALLOW_ROCM_MISMATCH=1 to bypass " + "this check at your own risk." + ) except FileNotFoundError: pass From be43d54400f09f47d9cffe2a164ea3f1a3b373d6 Mon Sep 17 00:00:00 2001 From: Doug J Date: Fri, 27 Feb 2026 18:54:10 -0800 Subject: [PATCH 3/6] Update transformer_engine/common/__init__.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- transformer_engine/common/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 894bbd49e..4f51b700c 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -384,8 +384,8 @@ def _load_core_library(): build_rocm_version = build_rocm_version[0].split(":")[1].strip().split('.')[:2] # Strict by default. Set NVTE_ALLOW_ROCM_MISMATCH=1 to bypass. allow_rocm_mismatch = os.getenv("NVTE_ALLOW_ROCM_MISMATCH", "0").lower() in ("1", "true", "yes") - if not allow_rocm_mismatch: - assert rocm_version == build_rocm_version, ( + if not allow_rocm_mismatch and rocm_version != build_rocm_version: + raise RuntimeError( f"ROCm {'.'.join(rocm_version)} is detected but the library is built for " f"{'.'.join(build_rocm_version)}. Set NVTE_ALLOW_ROCM_MISMATCH=1 to bypass " "this check at your own risk." From bfe4c907494636383ff5991003086756bb6e6dbd Mon Sep 17 00:00:00 2001 From: jiagaoxiang Date: Sat, 28 Feb 2026 08:43:28 +0000 Subject: [PATCH 4/6] Log ROCm mismatch status when bypass is enabled. Emit a warning whenever NVTE_ALLOW_ROCM_MISMATCH is set and include whether a mismatch was detected plus runtime/build versions. Made-with: Cursor --- transformer_engine/common/__init__.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 4f51b700c..9df0ccbe1 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -384,7 +384,16 @@ def _load_core_library(): build_rocm_version = build_rocm_version[0].split(":")[1].strip().split('.')[:2] # Strict by default. Set NVTE_ALLOW_ROCM_MISMATCH=1 to bypass. allow_rocm_mismatch = os.getenv("NVTE_ALLOW_ROCM_MISMATCH", "0").lower() in ("1", "true", "yes") - if not allow_rocm_mismatch and rocm_version != build_rocm_version: + mismatch_detected = rocm_version != build_rocm_version + if allow_rocm_mismatch: + _logger.warning( + "NVTE_ALLOW_ROCM_MISMATCH is enabled. ROCm runtime/build mismatch detected=%s " + "(runtime=%s, build=%s).", + mismatch_detected, + ".".join(rocm_version), + ".".join(build_rocm_version), + ) + elif mismatch_detected: raise RuntimeError( f"ROCm {'.'.join(rocm_version)} is detected but the library is built for " f"{'.'.join(build_rocm_version)}. Set NVTE_ALLOW_ROCM_MISMATCH=1 to bypass " From 50ab2e81fcb870f6d7817f60a21a3bcde03e6cd6 Mon Sep 17 00:00:00 2001 From: jiagaoxiang Date: Sat, 28 Feb 2026 08:56:40 +0000 Subject: [PATCH 5/6] Only warn when mismatch bypass actually applies. Emit the ROCm mismatch warning only when NVTE_ALLOW_ROCM_MISMATCH is enabled and a runtime/build mismatch is detected, avoiding noisy logs when versions already match. Made-with: Cursor --- transformer_engine/common/__init__.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 9df0ccbe1..10a694f05 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -385,11 +385,10 @@ def _load_core_library(): # Strict by default. Set NVTE_ALLOW_ROCM_MISMATCH=1 to bypass. allow_rocm_mismatch = os.getenv("NVTE_ALLOW_ROCM_MISMATCH", "0").lower() in ("1", "true", "yes") mismatch_detected = rocm_version != build_rocm_version - if allow_rocm_mismatch: + if allow_rocm_mismatch and mismatch_detected: _logger.warning( - "NVTE_ALLOW_ROCM_MISMATCH is enabled. ROCm runtime/build mismatch detected=%s " - "(runtime=%s, build=%s).", - mismatch_detected, + "NVTE_ALLOW_ROCM_MISMATCH is enabled. Proceeding despite ROCm runtime/build " + "version mismatch (runtime=%s, build=%s).", ".".join(rocm_version), ".".join(build_rocm_version), ) From e117c9d5a7f28b124fd5db4fd1a87d1eb0238af3 Mon Sep 17 00:00:00 2001 From: Doug J Date: Sat, 28 Feb 2026 01:05:10 -0800 Subject: [PATCH 6/6] Update transformer_engine/common/__init__.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- transformer_engine/common/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 10a694f05..8bd577ee8 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -383,7 +383,7 @@ def _load_core_library(): if build_rocm_version: build_rocm_version = build_rocm_version[0].split(":")[1].strip().split('.')[:2] # Strict by default. Set NVTE_ALLOW_ROCM_MISMATCH=1 to bypass. - allow_rocm_mismatch = os.getenv("NVTE_ALLOW_ROCM_MISMATCH", "0").lower() in ("1", "true", "yes") + allow_rocm_mismatch = os.getenv("NVTE_ALLOW_ROCM_MISMATCH", "0").strip().lower() in ("1", "true", "yes") mismatch_detected = rocm_version != build_rocm_version if allow_rocm_mismatch and mismatch_detected: _logger.warning(