Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions transformer_engine/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand Down Expand Up @@ -133,7 +133,7 @@ def load_framework_extension(framework: str) -> None:
if framework == "torch":
extra_dep_name = "pytorch"

te_cuda_vers = "rocm" if te_rocm_build else "cu12"
te_core_tag = "rocm" if te_rocm_build else "cu12"

# If the framework extension pip package is installed, it means that TE is installed via
# PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
Expand All @@ -143,24 +143,24 @@ def load_framework_extension(framework: str) -> None:
"transformer_engine"
), "Could not find `transformer-engine`."
assert _is_pip_package_installed(
f"transformer_engine_{te_cuda_vers}"
), f"Could not find `transformer-engine-{te_cuda_vers}`."
f"transformer_engine_{te_core_tag}"
), f"Could not find `transformer-engine-{te_core_tag}`."
assert (
version(module_name)
== version("transformer-engine")
== version(f"transformer-engine-{te_cuda_vers}")
== version(f"transformer-engine-{te_core_tag}")
), (
"TransformerEngine package version mismatch. Found"
f" {module_name} v{version(module_name)}, transformer-engine"
f" v{version('transformer-engine')}, and transformer-engine-{te_cuda_vers}"
f" v{version(f'transformer-engine-{te_cuda_vers}')}. Install transformer-engine using "
f" v{version('transformer-engine')}, and transformer-engine-{te_core_tag}"
f" v{version(f'transformer-engine-{te_core_tag}')}. Install transformer-engine using "
f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'"
)

# If the core package is installed via PyPI, log if
# the framework extension is not found from PyPI.
# Note: Should we error? This is a rare use case.
if _is_pip_package_installed(f"transformer-engine-{te_cuda_vers}"):
if _is_pip_package_installed(f"transformer-engine-{te_core_tag}"):
if not _is_pip_package_installed(module_name):
_logger.info(
"Could not find package %s. Install transformer-engine using "
Expand Down Expand Up @@ -382,7 +382,7 @@ def _load_core_library():
build_rocm_version = list(filter(lambda f: f.startswith("ROCM_VERSION:"), build_info))
if build_rocm_version:
build_rocm_version = build_rocm_version[0].split(":")[1].strip().split('.')[:2]
assert (rocm_version == build_rocm_version), f"ROCm {'.'.join(rocm_version)} is detected but the library is built for {'.'.join(build_rocm_version)}"
assert (rocm_version[0] == build_rocm_version[0]), f"ROCm {'.'.join(rocm_version)} is detected but the library is built for {'.'.join(build_rocm_version)}"
except FileNotFoundError:
pass