Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
[submodule "3rdparty/composable_kernel"]
path = 3rdparty/composable_kernel
url = https://github.com/ROCmSoftwarePlatform/composable_kernel.git
branch = develop
branch = navi3_rel
[submodule "3rdparty/picojson"]
path = 3rdparty/picojson
url = https://github.com/kazuho/picojson.git
2 changes: 1 addition & 1 deletion 3rdparty/composable_kernel
4 changes: 4 additions & 0 deletions fx2ait/fx2ait/csrc/AITModelImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,18 @@

#include "ATen/Context.h" // @manual
#ifdef __HIP_PLATFORM_HCC__
#include "rocm_device_functions.h"
#include "ATen/hip/HIPContext.h"
#include "c10/core/CPUAllocator.h"
#include "c10/hip/HIPStream.h"
#else
#include "cuda_device_functions.h"
#include "ATen/cuda/CUDAContext.h"
#include "c10/core/CPUAllocator.h"
#include "c10/cuda/CUDAStream.h"
#endif


#ifdef FBCODE_AIT
#include "folly/MapUtil.h"
#endif
Expand Down Expand Up @@ -667,6 +670,7 @@ void AITModelImpl::updateConstantsWithWeights(
decltype(&cudaStreamDestroy)>;
StreamGuard constants_stream_guard{constants_stream, cudaStreamDestroy};
#endif

AIT_CHECK(setManyConstantsDoubleBufferFunc_(
model_handle_,
/*stream=*/reinterpret_cast<AITemplateStreamOpaque*>(constants_stream),
Expand Down
3 changes: 2 additions & 1 deletion python/aitemplate/backend/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,10 +1033,11 @@ def generate_source(self) -> Dict[str, str]:
The dictionary returned is a map from filename -> contents.
"""
device_functions_header_name = f"{self.target.name()}_device_functions.h"
includes_header_name = f"{self.target.name()}_includes.h"
result = {}
result[
"device_functions-generated.h"
] = f'#include "{device_functions_header_name}"'
] = f'#include "{device_functions_header_name}"\n#include "{includes_header_name}"'

result["model-generated.h"] = self.generate_model()

Expand Down
1 change: 1 addition & 0 deletions python/aitemplate/backend/profiler_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def push(self, cmds: List[str], process_result_callback: Callable):
future = self._executor.submit(
run_task, cmds, self._device_queue, self._dev_select_flag
)
_LOGGER.info(f"The result of profile executor is {future.result()}")

# done callbacks are used to collect profiler results for postprocessing
# they are launched asynchronously, in a separate thread,
Expand Down
6 changes: 5 additions & 1 deletion python/aitemplate/backend/rocm/conv2d/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,11 @@

HEADER_CODE = jinja2.Template(
"""
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp"
#if 1
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
#else
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp"
#endif
"""
)

Expand Down
6 changes: 5 additions & 1 deletion python/aitemplate/backend/rocm/conv2d/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""
from aitemplate.backend import registry
from aitemplate.backend.rocm.conv2d import common
from aitemplate.backend.target import Target

# pylint: disable=C0103,C0415,W0613

Expand All @@ -39,7 +40,10 @@ def conv2d_config(func_attrs):
"""
import ck_lib

op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasRelu
if Target.current().get_device_name() == "gfx1100":
op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluWmma
else:
op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluXlops
extra_kind = ck_lib.library.TensorOperation.PassThrough
func_attrs["op_instance"] = common.extract_config(op_kind, extra_kind)

Expand Down
6 changes: 5 additions & 1 deletion python/aitemplate/backend/rocm/conv2d/conv2d_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""
from aitemplate.backend import registry
from aitemplate.backend.rocm.conv2d import common
from aitemplate.backend.target import Target

# pylint: disable=C0103,C0415,W0613

Expand All @@ -39,7 +40,10 @@ def conv2d_config(func_attrs):
"""
import ck_lib

op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasRelu
if Target.current().get_device_name() == "gfx1100":
op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluWmma
else:
op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluXlops
extra_kind = ck_lib.library.TensorOperation.Add
func_attrs["op_instance"] = common.extract_config(op_kind, extra_kind)

Expand Down
6 changes: 5 additions & 1 deletion python/aitemplate/backend/rocm/conv2d/conv2d_bias_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""
from ... import registry
from . import common
from aitemplate.backend.target import Target

# pylint: disable=C0103,C0415,W0613,C0301

Expand All @@ -25,7 +26,10 @@
def conv2d_config(func_attrs):
import ck_lib

op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasRelu
if Target.current().get_device_name() == "gfx1100":
op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluWmma
else:
op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluXlops
extra_kind = ck_lib.library.TensorOperation.AddAdd
func_attrs["op_instance"] = common.extract_config(op_kind, extra_kind)

Expand Down
12 changes: 10 additions & 2 deletions python/aitemplate/backend/rocm/conv2d/conv2d_bias_add_relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@

from aitemplate.backend import registry
from aitemplate.backend.rocm.conv2d import common
from aitemplate.backend.target import Target

# pylint: disable=C0103,C0415,W0613

EXTRA_CODE = jinja2.Template(
"""
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp"
#if 1
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
#else
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp"
#endif


namespace ck {
Expand Down Expand Up @@ -65,7 +70,10 @@ def conv2d_config(func_attrs):
"""
import ck_lib

op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasRelu
if Target.current().get_device_name() == "gfx1100":
op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluWmma
else:
op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluXlops
extra_kind = ck_lib.library.TensorOperation.AddAddRelu
func_attrs["op_instance"] = common.extract_config(op_kind, extra_kind)

Expand Down
6 changes: 5 additions & 1 deletion python/aitemplate/backend/rocm/conv2d/conv2d_bias_relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""
from aitemplate.backend import registry
from aitemplate.backend.rocm.conv2d import common
from aitemplate.backend.target import Target

# pylint: disable=C0103,C0415,W0613

Expand All @@ -40,7 +41,10 @@ def conv2d_config(func_attrs):
"""
import ck_lib

op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasRelu
if Target.current().get_device_name() == "gfx1100":
op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluWmma
else:
op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluXlops
extra_kind = ck_lib.library.TensorOperation.AddRelu
func_attrs["op_instance"] = common.extract_config(op_kind, extra_kind)

Expand Down
6 changes: 5 additions & 1 deletion python/aitemplate/backend/rocm/conv2d/conv2d_bias_sigmoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from aitemplate.backend import registry
from aitemplate.backend.rocm.conv2d import common
from aitemplate.backend.target import Target

# pylint: disable=C0103,C0415,W0613

Expand Down Expand Up @@ -85,7 +86,10 @@ def conv2d_config(func_attrs):
"""
import ck_lib

op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasRelu
if Target.current().get_device_name() == "gfx1100":
op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluWmma
else:
op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluXlops
extra_kind = ck_lib.library.TensorOperation.AddSigmoid
func_attrs["op_instance"] = common.extract_config(op_kind, extra_kind)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,11 @@
# Jinja2 template for the extra C++ includes emitted ahead of the batched
# gemm + softmax + gemm permute kernel source.
# NOTE(review): the `#if 1` makes the preprocessor always select the WMMA
# (presumably Navi3/gfx1100) cshuffle header, leaving the XDL branch dead.
# This looks like a temporary port-time switch — confirm whether it should
# instead be conditional on the target arch (cf. the `rocm_device_name`
# handling used by the gemm common templates).
EXTRA_HEADER_TEMPLATE = jinja2.Template(
    """
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#if 1
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp"
#else
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#endif
"""
)

Expand Down
25 changes: 23 additions & 2 deletions python/aitemplate/backend/rocm/gemm/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,34 @@
EXTRA_HEADER_TEMPLATE = jinja2.Template(
"""
{% if gemm_flag == "" %}
{% if rocm_device_name == "gfx1100" %}
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"
{% else %}
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
{% endif %}
{% elif gemm_flag == "permute_m2n3" %}
{% if rocm_device_name == "gfx1100" %}
#include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp"
{% else %}
#include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
{% endif %}
{% elif "bias" in gemm_flag or has_d0 %}
{% if rocm_device_name == "gfx1100" %}
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp"
{% else %}
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"
{% endif %}
{% if gemm_flag == "bias_permute" %}
{% if rocm_device_name != "gfx1100" %}
#include "ck/tensor_operation/gpu/device/impl/device_gemm_bias_e_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/device/impl/gemm_specialization.hpp"
{% endif %}
{% elif gemm_flag in ["bias_permute_m2n3", "bias_permute_m3n2"] %}
{% if rocm_device_name == "gfx1100" %}
#include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp"
{% else %}
#include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
{% endif %}
{% endif %}
{% endif %}
"""
Expand Down Expand Up @@ -652,6 +670,7 @@ def gen_profiler(
file_pairs = []
has_d0_flag = has_d0(func_attrs)
has_d1_flag = has_d1(func_attrs)
rocm_device_name = Target.current().get_device_name()

for op_name, op in op_instance.items():
config = emit_instance(op)
Expand All @@ -672,7 +691,7 @@ def gen_profiler(
is_profiler=True,
)
extra_header = extra_header_template.render(
gemm_flag=gemm_flag, has_d0=has_d0_flag
rocm_device_name=rocm_device_name, gemm_flag=gemm_flag, has_d0=has_d0_flag
)
op_func = SRC_TEMPLATE.render(
instances=instance,
Expand Down Expand Up @@ -786,6 +805,8 @@ def gen_function(
instance_decl = ""
has_d0_flag = has_d0(func_attrs)
has_d1_flag = has_d1(func_attrs)
rocm_device_name = Target.current().get_device_name()

for key, value in exec_path.items():
fname = "f" + sha1(key.encode()).hexdigest()
algo = value.algo
Expand Down Expand Up @@ -829,7 +850,7 @@ def gen_function(
exec_inst = exec_cond_template.render(indent=" ", cond=key, program=program)
exec_paths += exec_inst
extra_header = extra_header_template.render(
gemm_flag=gemm_flag, has_d0=has_d0(func_attrs)
rocm_device_name=rocm_device_name, gemm_flag=gemm_flag, has_d0=has_d0(func_attrs)
)
pdims = len(func_attrs["shape"]) if func_attrs.get("shape") is not None else 0
return SRC_TEMPLATE.render(
Expand Down
10 changes: 10 additions & 0 deletions python/aitemplate/backend/rocm/target_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def _build_compile_options(self):
"-fvisibility=hidden",
"-std=c++17",
"-w",
"-mcumode",
"-mno-wavefrontsize64",
"-DCK_TIME_KERNEL=0",
"-Xclang -mlink-builtin-bitcode -Xclang {}/amdgcn/bitcode/oclc_abi_version_400.bc".format(
self._pkg_path()
Expand All @@ -132,6 +134,9 @@ def _build_compile_options(self):
elif self._arch in {"GFX90a", "gfx90a"}:
options.append("-DCK_AMD_GPU_GFX90A")
options.append("--offload-arch=gfx90a")
elif self._arch in {"GFX1100", "gfx1100"}:
options.append("-DCK_AMD_GPU_GFX1100")
options.append("--offload-arch=gfx1100")
else:
raise RuntimeError("Unsupported GPU Arch")
for path in ck_paths:
Expand Down Expand Up @@ -297,6 +302,8 @@ def _build_compile_options(self):
"-fvisibility=hidden",
"-std=c++17",
"-w",
"-mcumode",
"-mno-wavefrontsize64",
"-DCK_TIME_KERNEL=0",
"--hip-version=5.2.0",
]
Expand All @@ -310,6 +317,9 @@ def _build_compile_options(self):
elif self._arch in {"GFX90a", "gfx90a"}:
options.append("-DCK_AMD_GPU_GFX90A")
options.append("--cuda-gpu-arch=gfx90a")
elif self._arch in {"GFX1100", "gfx1100"}:
options.append("-DCK_AMD_GPU_GFX1100")
options.append("--amdgpu-target=gfx1100")
else:
raise RuntimeError("Unsupported GPU Arch")
for path in ck_paths:
Expand Down
19 changes: 18 additions & 1 deletion python/aitemplate/backend/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(self, static_files_path: str):
Absolute path to the AIT static/ directory
"""
self._target_type = -1
self._device_name = ""
self._template_path = ""
self._compile_cmd = ""
self._cache_path = ""
Expand All @@ -84,7 +85,7 @@ def __enter__(self):
self._load_profile_cache()
global CURRENT_TARGET
if CURRENT_TARGET is not None:
raise RuntimeError("Target has been set.")
raise RuntimeError(f"Target has been set {CURRENT_TARGET}")
assert self._target_type > 0
CURRENT_TARGET = self

Expand Down Expand Up @@ -138,6 +139,22 @@ def name(self) -> str:
"""
return TargetType(self._target_type).name

def get_device_name(self) -> str:
    """Return the detected GPU device name of this target (e.g. ``gfx1100``).

    The name is resolved via the backend-specific detection helper
    (``_detect_rocm`` for the rocm target, ``_detect_cuda`` otherwise)
    and cached in ``self._device_name``, so repeated calls — e.g. one per
    op during config selection — do not re-run the detection probe.

    Returns
    -------
    str
        The device name of the target.
    """
    # `_device_name` is initialized to "" in __init__; only probe once.
    if not self._device_name:
        # Imported lazily to avoid a circular import between
        # backend.target and testing.detect_target at module load time.
        from ..testing.detect_target import _detect_cuda, _detect_rocm

        if self.name() == "rocm":
            self._device_name = _detect_rocm()
        else:
            self._device_name = _detect_cuda()
    return self._device_name

def cc(self):
"""Compiler for this target.

Expand Down
2 changes: 2 additions & 0 deletions python/aitemplate/testing/detect_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ def _detect_rocm():
proc = Popen(["rocminfo"], stdout=PIPE, stderr=PIPE)
stdout, stderr = proc.communicate()
stdout = stdout.decode("utf-8")
if "gfx1100" in stdout:
return "gfx1100"
if "gfx90a" in stdout:
return "gfx90a"
if "gfx908" in stdout:
Expand Down
Loading