From 64840640ed931870c86b8bd50212bf76d074bb96 Mon Sep 17 00:00:00 2001 From: Ilya Panfilov Date: Mon, 2 Mar 2026 12:18:32 -0500 Subject: [PATCH] Check only major ROCm version on load --- transformer_engine/common/__init__.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 26672bafd..519b2c88a 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -133,7 +133,7 @@ def load_framework_extension(framework: str) -> None: if framework == "torch": extra_dep_name = "pytorch" - te_cuda_vers = "rocm" if te_rocm_build else "cu12" + te_core_tag = "rocm" if te_rocm_build else "cu12" # If the framework extension pip package is installed, it means that TE is installed via # PyPI. For this case we need to make sure that the metapackage, the core lib, and framework @@ -143,24 +143,24 @@ def load_framework_extension(framework: str) -> None: "transformer_engine" ), "Could not find `transformer-engine`." assert _is_pip_package_installed( - f"transformer_engine_{te_cuda_vers}" - ), f"Could not find `transformer-engine-{te_cuda_vers}`." + f"transformer_engine_{te_core_tag}" + ), f"Could not find `transformer-engine-{te_core_tag}`." assert ( version(module_name) == version("transformer-engine") - == version(f"transformer-engine-{te_cuda_vers}") + == version(f"transformer-engine-{te_core_tag}") ), ( "TransformerEngine package version mismatch. Found" f" {module_name} v{version(module_name)}, transformer-engine" - f" v{version('transformer-engine')}, and transformer-engine-{te_cuda_vers}" - f" v{version(f'transformer-engine-{te_cuda_vers}')}. Install transformer-engine using " + f" v{version('transformer-engine')}, and transformer-engine-{te_core_tag}" + f" v{version(f'transformer-engine-{te_core_tag}')}. Install transformer-engine using " f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'" ) # If the core package is installed via PyPI, log if # the framework extension is not found from PyPI. # Note: Should we error? This is a rare use case. - if _is_pip_package_installed(f"transformer-engine-{te_cuda_vers}"): + if _is_pip_package_installed(f"transformer-engine-{te_core_tag}"): if not _is_pip_package_installed(module_name): _logger.info( "Could not find package %s. Install transformer-engine using " @@ -382,7 +382,7 @@ def _load_core_library(): build_rocm_version = list(filter(lambda f: f.startswith("ROCM_VERSION:"), build_info)) if build_rocm_version: build_rocm_version = build_rocm_version[0].split(":")[1].strip().split('.')[:2] - assert (rocm_version == build_rocm_version), f"ROCm {'.'.join(rocm_version)} is detected but the library is built for {'.'.join(build_rocm_version)}" + assert (rocm_version[0] == build_rocm_version[0]), f"ROCm {'.'.join(rocm_version)} is detected but the library is built for {'.'.join(build_rocm_version)}" except FileNotFoundError: pass