diff --git a/.gitmodules b/.gitmodules index 1272127de..8e4e174ef 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,7 +7,7 @@ [submodule "3rdparty/composable_kernel"] path = 3rdparty/composable_kernel url = https://github.com/ROCmSoftwarePlatform/composable_kernel.git - branch = develop + branch = navi3_rel [submodule "3rdparty/picojson"] path = 3rdparty/picojson url = https://github.com/kazuho/picojson.git diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel index db49fc437..1fb4a4740 160000 --- a/3rdparty/composable_kernel +++ b/3rdparty/composable_kernel @@ -1 +1 @@ -Subproject commit db49fc43797f80be1db2399dcd1a082dbf447736 +Subproject commit 1fb4a4740fdd81899521a0344f76503a16783292 diff --git a/fx2ait/fx2ait/csrc/AITModelImpl.cpp b/fx2ait/fx2ait/csrc/AITModelImpl.cpp index d47444302..1b2172c9f 100644 --- a/fx2ait/fx2ait/csrc/AITModelImpl.cpp +++ b/fx2ait/fx2ait/csrc/AITModelImpl.cpp @@ -21,15 +21,18 @@ #include "ATen/Context.h" // @manual #ifdef __HIP_PLATFORM_HCC__ +#include "rocm_device_functions.h" #include "ATen/hip/HIPContext.h" #include "c10/core/CPUAllocator.h" #include "c10/hip/HIPStream.h" #else +#include "cuda_device_functions.h" #include "ATen/cuda/CUDAContext.h" #include "c10/core/CPUAllocator.h" #include "c10/cuda/CUDAStream.h" #endif + #ifdef FBCODE_AIT #include "folly/MapUtil.h" #endif @@ -667,6 +670,7 @@ void AITModelImpl::updateConstantsWithWeights( decltype(&cudaStreamDestroy)>; StreamGuard constants_stream_guard{constants_stream, cudaStreamDestroy}; #endif + AIT_CHECK(setManyConstantsDoubleBufferFunc_( model_handle_, /*stream=*/reinterpret_cast(constants_stream), diff --git a/python/aitemplate/backend/codegen.py b/python/aitemplate/backend/codegen.py index bb37202c7..0f02a0e67 100644 --- a/python/aitemplate/backend/codegen.py +++ b/python/aitemplate/backend/codegen.py @@ -1033,10 +1033,11 @@ def generate_source(self) -> Dict[str, str]: The dictionary returned is a map from filename -> contents. 
""" device_functions_header_name = f"{self.target.name()}_device_functions.h" + includes_header_name = f"{self.target.name()}_includes.h" result = {} result[ "device_functions-generated.h" - ] = f'#include "{device_functions_header_name}"' + ] = f'#include "{device_functions_header_name}"\n#include "{includes_header_name}"' result["model-generated.h"] = self.generate_model() diff --git a/python/aitemplate/backend/profiler_runner.py b/python/aitemplate/backend/profiler_runner.py index a364aa771..9a6cf84b0 100644 --- a/python/aitemplate/backend/profiler_runner.py +++ b/python/aitemplate/backend/profiler_runner.py @@ -309,6 +309,7 @@ def push(self, cmds: List[str], process_result_callback: Callable): future = self._executor.submit( run_task, cmds, self._device_queue, self._dev_select_flag ) + _LOGGER.info(f"The result of profile executor is {future.result()}") # done callbacks are used to collect profiler results for postprocessing # they are launched asynchronously, in a separate thread, diff --git a/python/aitemplate/backend/rocm/conv2d/common.py b/python/aitemplate/backend/rocm/conv2d/common.py index b71a20bce..74913100e 100644 --- a/python/aitemplate/backend/rocm/conv2d/common.py +++ b/python/aitemplate/backend/rocm/conv2d/common.py @@ -104,7 +104,11 @@ HEADER_CODE = jinja2.Template( """ -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#if 1 + #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" +#else + #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#endif """ ) diff --git a/python/aitemplate/backend/rocm/conv2d/conv2d.py b/python/aitemplate/backend/rocm/conv2d/conv2d.py index 8c9df0f5f..5192fda2c 100644 --- a/python/aitemplate/backend/rocm/conv2d/conv2d.py +++ b/python/aitemplate/backend/rocm/conv2d/conv2d.py @@ -17,6 +17,7 @@ """ from aitemplate.backend import registry from aitemplate.backend.rocm.conv2d import 
common +from aitemplate.backend.target import Target # pylint: disable=C0103,C0415,W0613 @@ -39,7 +40,10 @@ def conv2d_config(func_attrs): """ import ck_lib - op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasRelu + if Target.current().get_device_name() == "gfx1100": + op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluWmma + else: + op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluXlops extra_kind = ck_lib.library.TensorOperation.PassThrough func_attrs["op_instance"] = common.extract_config(op_kind, extra_kind) diff --git a/python/aitemplate/backend/rocm/conv2d/conv2d_bias.py b/python/aitemplate/backend/rocm/conv2d/conv2d_bias.py index 91506f2f9..9dad106f5 100644 --- a/python/aitemplate/backend/rocm/conv2d/conv2d_bias.py +++ b/python/aitemplate/backend/rocm/conv2d/conv2d_bias.py @@ -17,6 +17,7 @@ """ from aitemplate.backend import registry from aitemplate.backend.rocm.conv2d import common +from aitemplate.backend.target import Target # pylint: disable=C0103,C0415,W0613 @@ -39,7 +40,10 @@ def conv2d_config(func_attrs): """ import ck_lib - op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasRelu + if Target.current().get_device_name() == "gfx1100": + op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluWmma + else: + op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluXlops extra_kind = ck_lib.library.TensorOperation.Add func_attrs["op_instance"] = common.extract_config(op_kind, extra_kind) diff --git a/python/aitemplate/backend/rocm/conv2d/conv2d_bias_add.py b/python/aitemplate/backend/rocm/conv2d/conv2d_bias_add.py index 3c8f0e3ba..60effd69d 100644 --- a/python/aitemplate/backend/rocm/conv2d/conv2d_bias_add.py +++ b/python/aitemplate/backend/rocm/conv2d/conv2d_bias_add.py @@ -17,6 +17,7 @@ """ from ... import registry from . 
import common +from aitemplate.backend.target import Target # pylint: disable=C0103,C0415,W0613,C0301 @@ -25,7 +26,10 @@ def conv2d_config(func_attrs): import ck_lib - op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasRelu + if Target.current().get_device_name() == "gfx1100": + op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluWmma + else: + op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluXlops extra_kind = ck_lib.library.TensorOperation.AddAdd func_attrs["op_instance"] = common.extract_config(op_kind, extra_kind) diff --git a/python/aitemplate/backend/rocm/conv2d/conv2d_bias_add_relu.py b/python/aitemplate/backend/rocm/conv2d/conv2d_bias_add_relu.py index 190f85694..fcdf2a16e 100644 --- a/python/aitemplate/backend/rocm/conv2d/conv2d_bias_add_relu.py +++ b/python/aitemplate/backend/rocm/conv2d/conv2d_bias_add_relu.py @@ -19,12 +19,17 @@ from aitemplate.backend import registry from aitemplate.backend.rocm.conv2d import common +from aitemplate.backend.target import Target # pylint: disable=C0103,C0415,W0613 EXTRA_CODE = jinja2.Template( """ -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#if 1 + #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" +#else + #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#endif namespace ck { @@ -65,7 +70,10 @@ def conv2d_config(func_attrs): """ import ck_lib - op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasRelu + if Target.current().get_device_name() == "gfx1100": + op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluWmma + else: + op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluXlops extra_kind = ck_lib.library.TensorOperation.AddAddRelu func_attrs["op_instance"] = common.extract_config(op_kind, extra_kind) diff --git a/python/aitemplate/backend/rocm/conv2d/conv2d_bias_relu.py b/python/aitemplate/backend/rocm/conv2d/conv2d_bias_relu.py index 
b33561394..2be64e607 100644 --- a/python/aitemplate/backend/rocm/conv2d/conv2d_bias_relu.py +++ b/python/aitemplate/backend/rocm/conv2d/conv2d_bias_relu.py @@ -17,6 +17,7 @@ """ from aitemplate.backend import registry from aitemplate.backend.rocm.conv2d import common +from aitemplate.backend.target import Target # pylint: disable=C0103,C0415,W0613 @@ -40,7 +41,10 @@ def conv2d_config(func_attrs): """ import ck_lib - op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasRelu + if Target.current().get_device_name() == "gfx1100": + op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluWmma + else: + op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluXlops extra_kind = ck_lib.library.TensorOperation.AddRelu func_attrs["op_instance"] = common.extract_config(op_kind, extra_kind) diff --git a/python/aitemplate/backend/rocm/conv2d/conv2d_bias_sigmoid.py b/python/aitemplate/backend/rocm/conv2d/conv2d_bias_sigmoid.py index f43e42317..aa3ca199c 100644 --- a/python/aitemplate/backend/rocm/conv2d/conv2d_bias_sigmoid.py +++ b/python/aitemplate/backend/rocm/conv2d/conv2d_bias_sigmoid.py @@ -19,6 +19,7 @@ from aitemplate.backend import registry from aitemplate.backend.rocm.conv2d import common +from aitemplate.backend.target import Target # pylint: disable=C0103,C0415,W0613 @@ -85,7 +86,10 @@ def conv2d_config(func_attrs): """ import ck_lib - op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasRelu + if Target.current().get_device_name() == "gfx1100": + op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluWmma + else: + op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasReluXlops extra_kind = ck_lib.library.TensorOperation.AddSigmoid func_attrs["op_instance"] = common.extract_config(op_kind, extra_kind) diff --git a/python/aitemplate/backend/rocm/gemm/bmm_softmax_bmm_permute.py b/python/aitemplate/backend/rocm/gemm/bmm_softmax_bmm_permute.py index 10337922c..ece3f6c76 100644 --- a/python/aitemplate/backend/rocm/gemm/bmm_softmax_bmm_permute.py +++ 
b/python/aitemplate/backend/rocm/gemm/bmm_softmax_bmm_permute.py @@ -76,7 +76,11 @@ EXTRA_HEADER_TEMPLATE = jinja2.Template( """ #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#if 1 +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp" +#else #include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" +#endif """ ) diff --git a/python/aitemplate/backend/rocm/gemm/common.py b/python/aitemplate/backend/rocm/gemm/common.py index 857029d87..8cc8ccdae 100644 --- a/python/aitemplate/backend/rocm/gemm/common.py +++ b/python/aitemplate/backend/rocm/gemm/common.py @@ -89,16 +89,34 @@ EXTRA_HEADER_TEMPLATE = jinja2.Template( """ {% if gemm_flag == "" %} + {% if rocm_device_name == "gfx1100" %} +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp" + {% else %} #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" + {% endif %} {% elif gemm_flag == "permute_m2n3" %} + {% if rocm_device_name == "gfx1100" %} +#include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp" + {% else %} #include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp" + {% endif %} {% elif "bias" in gemm_flag or has_d0 %} + {% if rocm_device_name == "gfx1100" %} +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" + {% else %} #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" + {% endif %} {% if gemm_flag == "bias_permute" %} + {% if rocm_device_name != "gfx1100" %} #include "ck/tensor_operation/gpu/device/impl/device_gemm_bias_e_permute_xdl.hpp" #include "ck/tensor_operation/gpu/device/impl/gemm_specialization.hpp" + {% endif %} {% elif gemm_flag in ["bias_permute_m2n3", "bias_permute_m3n2"] %} + {% if rocm_device_name == "gfx1100" %} +#include 
"ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp" + {% else %} #include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp" + {% endif %} {% endif %} {% endif %} """ @@ -652,6 +670,7 @@ def gen_profiler( file_pairs = [] has_d0_flag = has_d0(func_attrs) has_d1_flag = has_d1(func_attrs) + rocm_device_name = Target.current().get_device_name() for op_name, op in op_instance.items(): config = emit_instance(op) @@ -672,7 +691,7 @@ def gen_profiler( is_profiler=True, ) extra_header = extra_header_template.render( - gemm_flag=gemm_flag, has_d0=has_d0_flag + rocm_device_name=rocm_device_name, gemm_flag=gemm_flag, has_d0=has_d0_flag ) op_func = SRC_TEMPLATE.render( instances=instance, @@ -786,6 +805,8 @@ def gen_function( instance_decl = "" has_d0_flag = has_d0(func_attrs) has_d1_flag = has_d1(func_attrs) + rocm_device_name = Target.current().get_device_name() + for key, value in exec_path.items(): fname = "f" + sha1(key.encode()).hexdigest() algo = value.algo @@ -829,7 +850,7 @@ def gen_function( exec_inst = exec_cond_template.render(indent=" ", cond=key, program=program) exec_paths += exec_inst extra_header = extra_header_template.render( - gemm_flag=gemm_flag, has_d0=has_d0(func_attrs) + rocm_device_name=rocm_device_name, gemm_flag=gemm_flag, has_d0=has_d0(func_attrs) ) pdims = len(func_attrs["shape"]) if func_attrs.get("shape") is not None else 0 return SRC_TEMPLATE.render( diff --git a/python/aitemplate/backend/rocm/target_def.py b/python/aitemplate/backend/rocm/target_def.py index cb8529f31..82f5764a4 100644 --- a/python/aitemplate/backend/rocm/target_def.py +++ b/python/aitemplate/backend/rocm/target_def.py @@ -121,6 +121,8 @@ def _build_compile_options(self): "-fvisibility=hidden", "-std=c++17", "-w", + "-mcumode", + "-mno-wavefrontsize64", "-DCK_TIME_KERNEL=0", "-Xclang -mlink-builtin-bitcode -Xclang {}/amdgcn/bitcode/oclc_abi_version_400.bc".format( self._pkg_path() @@ -132,6 
+134,9 @@ def _build_compile_options(self): elif self._arch in {"GFX90a", "gfx90a"}: options.append("-DCK_AMD_GPU_GFX90A") options.append("--offload-arch=gfx90a") + elif self._arch in {"GFX1100", "gfx1100"}: + options.append("-DCK_AMD_GPU_GFX1100") + options.append("--offload-arch=gfx1100") else: raise RuntimeError("Unsupported GPU Arch") for path in ck_paths: @@ -297,6 +302,8 @@ def _build_compile_options(self): "-fvisibility=hidden", "-std=c++17", "-w", + "-mcumode", + "-mno-wavefrontsize64", "-DCK_TIME_KERNEL=0", "--hip-version=5.2.0", ] @@ -310,6 +317,9 @@ def _build_compile_options(self): elif self._arch in {"GFX90a", "gfx90a"}: options.append("-DCK_AMD_GPU_GFX90A") options.append("--cuda-gpu-arch=gfx90a") + elif self._arch in {"GFX1100", "gfx1100"}: + options.append("-DCK_AMD_GPU_GFX1100") + options.append("--amdgpu-target=gfx1100") else: raise RuntimeError("Unsupported GPU Arch") for path in ck_paths: diff --git a/python/aitemplate/backend/target.py b/python/aitemplate/backend/target.py index a464fddbd..33a1712f0 100644 --- a/python/aitemplate/backend/target.py +++ b/python/aitemplate/backend/target.py @@ -59,6 +59,7 @@ def __init__(self, static_files_path: str): Absolute path to the AIT static/ directory """ self._target_type = -1 + self._device_name = "" self._template_path = "" self._compile_cmd = "" self._cache_path = "" @@ -84,7 +85,7 @@ def __enter__(self): self._load_profile_cache() global CURRENT_TARGET if CURRENT_TARGET is not None: - raise RuntimeError("Target has been set.") + raise RuntimeError(f"Target has been set {CURRENT_TARGET}") assert self._target_type > 0 CURRENT_TARGET = self @@ -138,6 +139,22 @@ def name(self) -> str: """ return TargetType(self._target_type).name + def get_device_name(self) -> str: + """Return the device name of the target. + + Returns + ------- + str + The device name of the target. 
+ """ + from ..testing.detect_target import _detect_cuda, _detect_rocm + + if self.name() == "rocm": + self._device_name = _detect_rocm() + else: + self._device_name = _detect_cuda() + return self._device_name + def cc(self): """Compiler for this target. diff --git a/python/aitemplate/testing/detect_target.py b/python/aitemplate/testing/detect_target.py index 2b2913d6f..ff52cb6d1 100644 --- a/python/aitemplate/testing/detect_target.py +++ b/python/aitemplate/testing/detect_target.py @@ -88,6 +88,8 @@ def _detect_rocm(): proc = Popen(["rocminfo"], stdout=PIPE, stderr=PIPE) stdout, stderr = proc.communicate() stdout = stdout.decode("utf-8") + if "gfx1100" in stdout: + return "gfx1100" if "gfx90a" in stdout: return "gfx90a" if "gfx908" in stdout: diff --git a/python/aitemplate/utils/mk_ck_lib/conv2d_operation.py b/python/aitemplate/utils/mk_ck_lib/conv2d_operation.py index 4c46deeb2..5649ece82 100644 --- a/python/aitemplate/utils/mk_ck_lib/conv2d_operation.py +++ b/python/aitemplate/utils/mk_ck_lib/conv2d_operation.py @@ -48,24 +48,26 @@ class Conv2DSpecialization(enum.Enum): } -class XdlOpType(enum.Enum): +class OpType(enum.Enum): DeviceConv2d_Xdl_CShuffle = auto() DeviceConv2d_Xdl_CShuffle_Bias_Relu = auto() DeviceConv2d_Xdl_CShuffle_Bias_Relu_Add = auto() DeviceConv2d_Xdl_CShuffle_Bias_Sigmoid = auto() DeviceGroupedConv2D_Xdl_CShuffle_Bias_Relu = auto() + DeviceGroupedConv2D_Wmma_CShuffle_Bias_Relu = auto() DeviceConvNdBwdDataNwcKxcNwk_Xdl = auto() DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 = auto() -XdlOpTag = { - XdlOpType.DeviceConv2d_Xdl_CShuffle: "ck::tensor_operation::device::DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K", - XdlOpType.DeviceConv2d_Xdl_CShuffle_Bias_Relu: "ck::tensor_operation::device::DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K", - XdlOpType.DeviceConv2d_Xdl_CShuffle_Bias_Relu_Add: 
"ck::tensor_operation::device::DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K", - XdlOpType.DeviceConv2d_Xdl_CShuffle_Bias_Sigmoid: "ck::tensor_operation::device::DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K", - XdlOpType.DeviceGroupedConv2D_Xdl_CShuffle_Bias_Relu: "ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle", - XdlOpType.DeviceConvNdBwdDataNwcKxcNwk_Xdl: "ck::tensor_operation::device::DeviceConvNdBwdDataNwcKxcNwk_Xdl", - XdlOpType.DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1: "ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1", +OpTag = { + OpType.DeviceConv2d_Xdl_CShuffle: "ck::tensor_operation::device::DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K", + OpType.DeviceConv2d_Xdl_CShuffle_Bias_Relu: "ck::tensor_operation::device::DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K", + OpType.DeviceConv2d_Xdl_CShuffle_Bias_Relu_Add: "ck::tensor_operation::device::DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K", + OpType.DeviceConv2d_Xdl_CShuffle_Bias_Sigmoid: "ck::tensor_operation::device::DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K", + OpType.DeviceGroupedConv2D_Xdl_CShuffle_Bias_Relu: "ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle", + OpType.DeviceGroupedConv2D_Wmma_CShuffle_Bias_Relu: "ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle", + OpType.DeviceConvNdBwdDataNwcKxcNwk_Xdl: "ck::tensor_operation::device::DeviceConvNdBwdDataNwcKxcNwk_Xdl", + OpType.DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1: "ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1", } @@ -122,7 +124,8 @@ def emit(self) -> str: template = jinja2.Template( """ {%for key, value in 
param.items() %} - {{value}}, // {{key}} +{% if value!=0 %} {{value}}, // {{key}} + {% endif %} {% endfor %} """, trim_blocks=True, @@ -228,7 +231,7 @@ def emit(self) -> str: class Conv2DOperation: operation_kind: library.Conv2dKind extra_kind: library.TensorOperation - xdl_op_type: XdlOpType + op_type: OpType A: library.TensorDesc B: library.TensorDesc C: library.TensorDesc @@ -268,9 +271,9 @@ def accumulator_type(self): def emit(self) -> str: template = jinja2.Template( """ -using {{name}} = {{xdl_op_type}}< +using {{name}} = {{op_type}}< 2, // NDimSpatial -{% if "DeviceConvNdBwdDataNwcKxcNwk_Xdl" in xdl_op_type %} +{% if "DeviceConvNdBwdDataNwcKxcNwk_Xdl" in op_type %} {{ADType}}, // InDataType {{BDType}}, // WeiDataType {{CDType}}, // OutDataType @@ -287,7 +290,7 @@ def emit(self) -> str: {% elif func in ["AA", "AAR"] %} ck::Tuple<{{OutLayout}}, {{OutLayout}}>, // BiasLayout {% else %} -{% if "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1" in xdl_op_type %} +{% if "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1" in op_type %} ck::Tuple, // BiasLayouts {% else %} ck::Tuple<{{OutLayout}}>, // BiasLayout @@ -312,7 +315,7 @@ def emit(self) -> str: {{epilogue_functor}}, {{Conv2DSpecialization}}, -{% if "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1" in xdl_op_type %} +{% if "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1" in op_type %} true, true, 1, @@ -323,7 +326,7 @@ def emit(self) -> str: {{tile_config}} {{a_block_transfer}} {{b_block_transfer}} -{% if "DeviceConvNdBwdDataNwcKxcNwk_Xdl" in xdl_op_type %} +{% if "DeviceConvNdBwdDataNwcKxcNwk_Xdl" in op_type %} 7, // CThreadTransferSrcDstVectorDim 1 // GemmCThreadTransferDstScalarPerVector {% else %} @@ -334,7 +337,7 @@ def emit(self) -> str: ) return template.render( name=self.__str__(), - xdl_op_type=XdlOpTag[self.xdl_op_type], + op_type=OpTag[self.op_type], InLayout=library.LayoutTag[self.A.layout], WeiLayout=library.LayoutTag[self.B.layout], OutLayout=library.LayoutTag[self.C.layout], @@ -364,7 
+367,7 @@ def emit(self) -> str: Conv2DOp = Conv2DOperation( operation_kind=library.Conv2dKind.Conv2d, extra_kind=library.TensorOperation.PassThrough, - xdl_op_type=XdlOpType.DeviceConv2d_Xdl_CShuffle, + op_type=OpType.DeviceConv2d_Xdl_CShuffle, A=A, B=B, C=C, diff --git a/python/aitemplate/utils/mk_ck_lib/gemm_operation.py b/python/aitemplate/utils/mk_ck_lib/gemm_operation.py index dc1557a5b..be4089e8d 100644 --- a/python/aitemplate/utils/mk_ck_lib/gemm_operation.py +++ b/python/aitemplate/utils/mk_ck_lib/gemm_operation.py @@ -42,28 +42,36 @@ class GemmSpecialization(enum.Enum): } -class XdlOpType(enum.Enum): +class OpType(enum.Enum): DeviceGemmXdl_CShuffle = 1 # TODO: This sucks + DeviceGemmWmma_CShuffle = auto() DeviceGemmMultipleD_Xdl_CShuffle = auto() + DeviceGemmMultipleD_Wmma_CShuffle = auto() DeviceBatchedGemmXdl = auto() DeviceBatchedGemmCPermuteXdl = auto() DeviceGemmBiasCPermute_Xdl = auto() DeviceBatchedContractionMultipleD_Xdl_CShuffle = auto() + DeviceBatchedContractionMultipleD_Wmma_CShuffle = auto() DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle = auto() DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle = auto() + DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle = auto() DeviceBatchedGemmMultiD_Xdl = auto() -XdlOpTag = { - XdlOpType.DeviceGemmXdl_CShuffle: "ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle", - XdlOpType.DeviceGemmMultipleD_Xdl_CShuffle: "ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle", - XdlOpType.DeviceBatchedGemmXdl: "ck::tensor_operation::device::DeviceBatchedGemmXdl", - XdlOpType.DeviceBatchedGemmCPermuteXdl: "ck::tensor_operation::device::DeviceBatchedGemmEPermuteXdl", - XdlOpType.DeviceGemmBiasCPermute_Xdl: "ck::tensor_operation::device::DeviceGemmBiasEPermute_Xdl", - XdlOpType.DeviceBatchedContractionMultipleD_Xdl_CShuffle: "ck::tensor_operation::device::DeviceBatchedContractionMultipleD_Xdl_CShuffle", - XdlOpType.DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle: 
"ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle", - XdlOpType.DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle: "ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle", - XdlOpType.DeviceBatchedGemmMultiD_Xdl: "ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl", +OpTag = { + OpType.DeviceGemmXdl_CShuffle: "ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle", + OpType.DeviceGemmWmma_CShuffle: "ck::tensor_operation::device::DeviceGemmWmma_CShuffle", + OpType.DeviceGemmMultipleD_Xdl_CShuffle: "ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle", + OpType.DeviceGemmMultipleD_Wmma_CShuffle: "ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle", + OpType.DeviceBatchedGemmXdl: "ck::tensor_operation::device::DeviceBatchedGemmXdl", + OpType.DeviceBatchedGemmCPermuteXdl: "ck::tensor_operation::device::DeviceBatchedGemmEPermuteXdl", + OpType.DeviceGemmBiasCPermute_Xdl: "ck::tensor_operation::device::DeviceGemmBiasEPermute_Xdl", + OpType.DeviceBatchedContractionMultipleD_Xdl_CShuffle: "ck::tensor_operation::device::DeviceBatchedContractionMultipleD_Xdl_CShuffle", + OpType.DeviceBatchedContractionMultipleD_Wmma_CShuffle: "ck::tensor_operation::device::DeviceBatchedContractionMultipleD_Wmma_CShuffle", + OpType.DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle: "ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle", + OpType.DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle: "ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle", + OpType.DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle: "ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle", + OpType.DeviceBatchedGemmMultiD_Xdl: "ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl", } @@ -286,7 +294,7 @@ def emit(self) -> str: class GemmOperation: operation_kind: library.OperationKind extra_kind: library.TensorOperation - xdl_op_type: XdlOpType + op_type: OpType A: 
library.TensorDesc B: library.TensorDesc C: library.TensorDesc @@ -338,8 +346,9 @@ def accumulator_type(self): def emit(self) -> str: template = jinja2.Template( """ -using {{name}} = {{xdl_op_type}}< +using {{name}} = {{op_type}}< +// DeviceGemmXdl_CShuffle + DeviceGemmWmma_CShuffle +{% if op_type_value in [1, 2] %} {{ALayout}}, {{BLayout}}, {{CLayout}}, @@ -348,7 +357,8 @@ def emit(self) -> str: {{CDType}}, {{AccDType}}, {{CShuffleDType}}, -{% elif xdl_op_type_value==2 %} +// DeviceGemmMultipleD_Xdl_CShuffle + DeviceGemmMultipleD_Wmma_CShuffle +{% elif op_type_value in [3, 4] %} {{ALayout}}, {{BLayout}}, ck::Tuple<{{DsLayout}}>, // DsLayout @@ -359,7 +369,8 @@ def emit(self) -> str: {{CShuffleDType}}, ck::Tuple<{{DsDType}}>, // DsType {{CDType}}, -{% elif xdl_op_type_value==3 %} +// DeviceBatchedGemmXdl +{% elif op_type_value == 5 %} {{ADType}}, {{BDType}}, {{CDType}}, @@ -367,7 +378,8 @@ def emit(self) -> str: {{ALayout}}, {{BLayout}}, {{CLayout}}, -{% elif xdl_op_type_value==4 %} +// DeviceBatchedGemmCPermuteXdl +{% elif op_type_value == 6 %} {{ALayout}}, {{BLayout}}, {{CLayout}}, @@ -376,7 +388,8 @@ def emit(self) -> str: {{AccDType}}, {{CShuffleDType}}, {{CDType}}, -{% elif xdl_op_type_value==5 %} +// DeviceGemmBiasCPermute_Xdl +{% elif op_type_value == 7 %} {{ALayout}}, {{BLayout}}, {{CLayout}}, @@ -386,7 +399,8 @@ def emit(self) -> str: {{ADType}}, {{BDType}}, float, // CShuffleDType ck::half_t, ck::half_t, -{% elif xdl_op_type_value==6 %} +// DeviceBatchedContractionMultipleD_Xdl_CShuffle + DeviceBatchedContractionMultipleD_Wmma_CShuffle +{% elif op_type_value in [8, 9] %} {% if gemm_kind == "gemm_permute_m2n3" %} 1, 2, 3, 1, // permute m2n3 {% elif gemm_kind == "gemm_permute_m3n2" %} @@ -402,7 +416,8 @@ def emit(self) -> str: ck::Tuple, {% endif %} ck::half_t, -{% elif xdl_op_type_value == 7 %} +// DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle +{% elif op_type_value == 10 %} {{ALayout}}, {{BLayout}}, {{CLayout}}, @@ -413,7 +428,8 @@ def emit(self) 
-> str: {{CDType}}, {{AccDType}}, float, // CShuffleDType, -{% elif xdl_op_type_value == 8 %} +// DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle +{% elif op_type_value == 11 %} 2, 1, 1, 1, 1, {{ADType}}, {{BDType}}, @@ -423,7 +439,20 @@ def emit(self) -> str: ck::Tuple<>, {{AccDType}}, float, // CShuffleDType, -{% elif xdl_op_type_value == 9 %} +// DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle +{% elif op_type_value == 12 %} + 2, 1, 1, 1, 1, + {{ADType}}, + {{BDType}}, + {{BDType}}, + {{CDType}}, + ck::Tuple<>, + {{AccDType}}, + ck::Tuple<>, + {{AccDType}}, + float, // CShuffleDType, +// DeviceBatchedGemmMultiD_Xdl +{% elif op_type_value == 13 %} {{ALayout}}, {{BLayout}}, ck::Tuple<{{DsLayout}}>, // DsLayout @@ -435,7 +464,9 @@ def emit(self) -> str: ck::Tuple<{{DsDType}}>, // DsType {{EDType}}, {% endif %} -{% if xdl_op_type_value in [7, 8] %} +// DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle +// DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle +{% if op_type_value in [10, 11, 12] %} {{A_elem_op}}, {{B_elem_op}}, ck::tensor_operation::element_wise::ScaleAndResetNaNToMinusInfinity, @@ -444,13 +475,16 @@ def emit(self) -> str: {% endif %} {{B_elem_op}}, {{C_elem_op}}, -{% if xdl_op_type_value!=3 %} +// DeviceBatchedGemmXdl +{% if op_type_value != 5 %} {{GemmSpecialization}}, - {% if xdl_op_type_value==6 %} + // DeviceBatchedContractionMultipleD_Xdl_CShuffle + DeviceBatchedContractionMultipleD_Wmma_CShuffle + {% if op_type_value in [8, 9] %} ck::tensor_operation::device::TensorSpecialization::Packed, ck::tensor_operation::device::TensorSpecialization::Packed, ck::tensor_operation::device::TensorSpecialization::Default, - {% elif xdl_op_type_value==8 %} + // DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle + DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle + {% elif op_type_value in [11, 12] %} ck::tensor_operation::device::TensorSpecialization::Default, ck::tensor_operation::device::TensorSpecialization::Default, 
ck::tensor_operation::device::TensorSpecialization::Default, @@ -461,10 +495,10 @@ def emit(self) -> str: {{tile_config}} {{a_block_transfer}} {{b_block_transfer}} -{% if xdl_op_type_value in [7, 8] %} +{% if op_type_value in [10, 11, 12] %} // DeviceBatchedGemmSoftmaxGemm {{b1_block_transfer}} {% endif %} -{% if xdl_op_type_value!=3 %} +{% if op_type_value != 5 %} // DeviceBatchedGemmXdl {{c_block_transfer}} {% else %} 7, // src_dst_vector_dim @@ -476,8 +510,8 @@ def emit(self) -> str: return template.render( name=self.__str__(), gemm_kind=library.GemmKindNames[self.operation_kind], - xdl_op_type=XdlOpTag[self.xdl_op_type], - xdl_op_type_value=self.xdl_op_type.value, # This sucks + op_type=OpTag[self.op_type], + op_type_value=self.op_type.value, # This sucks ALayout=library.LayoutTag[self.A.layout], BLayout=library.LayoutTag[self.B.layout], CLayout=library.LayoutTag[self.C.layout], @@ -523,7 +557,7 @@ def emit(self) -> str: GemmOp = GemmOperation( operation_kind=library.GemmKind.BatchGemmPermute, extra_kind=library.TensorOperation.PassThrough, - xdl_op_type=XdlOpType.DeviceBatchedGemmCPermuteXdl, + op_type=OpType.DeviceBatchedGemmCPermuteXdl, A=A, B=B, C=C, diff --git a/python/aitemplate/utils/mk_ck_lib/generator.py b/python/aitemplate/utils/mk_ck_lib/generator.py index e8f89f666..7ccb3e7c7 100644 --- a/python/aitemplate/utils/mk_ck_lib/generator.py +++ b/python/aitemplate/utils/mk_ck_lib/generator.py @@ -23,6 +23,7 @@ library, softmax_operation as softmax, ) +from aitemplate.backend.target import Target ########################################################################################################### @@ -40,45 +41,106 @@ def CreateConv2dFwdOperator(manifest, operation_kind, out_element_op, out_data_o in_element_op = library.TensorOperation.PassThrough - tile_descriptions = [ - conv.GroupTileDesc(1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2), - conv.GroupTileDesc(1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4), - conv.GroupTileDesc(1, 128, 128, 128, 32, 8, 8, 32, 
32, 4, 2), - conv.GroupTileDesc(1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2), - conv.GroupTileDesc(1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2), - conv.GroupTileDesc(1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2), - conv.GroupTileDesc(1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2), - conv.GroupTileDesc(1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1), - conv.GroupTileDesc(1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2), - conv.GroupTileDesc(1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2), - conv.GroupTileDesc(1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1), - conv.GroupTileDesc(1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2), - ] + if Target.current().get_device_name() == "gfx1100": + tile_descriptions = [ + conv.GroupTileDesc(1, 256, 256, 128, 64, 8, 0, 16, 16, 8, 2), + conv.GroupTileDesc(1, 256, 256, 128, 32, 8, 0, 16, 16, 8, 2), + conv.GroupTileDesc(1, 256, 256, 64, 64, 8, 0, 16, 16, 8, 1), + conv.GroupTileDesc(1, 256, 128, 256, 32, 8, 0, 16, 16, 4, 4), + conv.GroupTileDesc(1, 256, 128, 128, 64, 8, 0, 16, 16, 4, 2), + conv.GroupTileDesc(1, 256, 64, 256, 64, 8, 0, 16, 16, 2, 4), + + conv.GroupTileDesc(1, 256, 128, 160, 64, 8, 0, 16, 16, 2, 5), + + conv.GroupTileDesc(1, 128, 128, 64, 64, 8, 0, 16, 16, 4, 2), + conv.GroupTileDesc(1, 128, 64, 128, 64, 8, 0, 16, 16, 2, 4), + conv.GroupTileDesc(1, 128, 64, 64, 64, 8, 0, 16, 16, 2, 2), + conv.GroupTileDesc(1, 128, 64, 32, 64, 8, 0, 16, 16, 2, 1), + conv.GroupTileDesc(1, 128, 32, 128, 64, 8, 0, 16, 16, 1, 4), + conv.GroupTileDesc(1, 128, 32, 64, 64, 8, 0, 16, 16, 1, 2), + + conv.GroupTileDesc(1, 128, 64, 80, 64, 8, 0, 16, 16, 1, 5), + + conv.GroupTileDesc(1, 64, 32, 32, 64, 8, 0, 16, 16, 2, 1), + conv.GroupTileDesc(1, 64, 16, 64, 64, 8, 0, 16, 16, 1, 2), + + conv.GroupTileDesc(1, 64, 32, 80, 64, 8, 0, 16, 16, 1, 5), + + conv.GroupTileDesc(1, 32, 16, 32, 64, 8, 0, 16, 16, 1, 2), + conv.GroupTileDesc(1, 32, 16, 16, 64, 8, 0, 16, 16, 1, 1), + ] + c_block_descriptions = [ + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8), + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8), + 
conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8), + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8), + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8), + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8), + conv.CBlockTransferDesc(1, 1, [1, 64, 1, 4], 8), + + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8), + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8), + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8), + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8), + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8), + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8), + conv.CBlockTransferDesc(1, 1, [1, 64, 1, 2], 8), + + conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8), + conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8), + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 2], 8), + + conv.CBlockTransferDesc(1, 1, [1, 16, 1, 2], 8), + conv.CBlockTransferDesc(1, 1, [1, 16, 1, 2], 8), + ] + else: + tile_descriptions = [ + conv.GroupTileDesc(1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2), + conv.GroupTileDesc(1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4), + conv.GroupTileDesc(1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2), + conv.GroupTileDesc(1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2), + conv.GroupTileDesc(1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2), + conv.GroupTileDesc(1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2), + conv.GroupTileDesc(1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2), + conv.GroupTileDesc(1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1), + conv.GroupTileDesc(1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2), + conv.GroupTileDesc(1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2), + conv.GroupTileDesc(1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1), + conv.GroupTileDesc(1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2), + ] - c_block_descriptions = [ - conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8), - conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8), - conv.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 8), - conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8), - conv.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8), - conv.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 8), - 
conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8), - conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8), - conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8), - conv.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 8), - conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8), - conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8), - ] + c_block_descriptions = [ + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8), + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8), + conv.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 8), + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8), + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8), + conv.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 8), + conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8), + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8), + conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8), + conv.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 8), + conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8), + conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8), + ] block_descriptions = [] for t in tile_descriptions: block_transfer = -1 if t.block_size == 256: - block_transfer = [4, 64, 1] + if t.n_per_block % 80 == 0: + block_transfer = [8, 32, 1] + else: + block_transfer = [4, 64, 1] if t.block_size == 128: - block_transfer = [4, 32, 1] + if t.n_per_block % 80 == 0: + block_transfer = [8, 16, 1] + else: + block_transfer = [4, 32, 1] if t.block_size == 64: block_transfer = [4, 16, 1] + if t.block_size == 32: + block_transfer = [2, 16, 1] assert ( block_transfer != -1 and "Cannot determine block_transfer_size with block_size " @@ -94,10 +156,16 @@ def CreateConv2dFwdOperator(manifest, operation_kind, out_element_op, out_data_o conv.Conv2DSpecialization.ConvFwd1x1S1P0, ] - gemm_specialization = [ - conv.Conv2DSpecialization.GemmDefault, - conv.Conv2DSpecialization.MNKPadding, - ] + if Target.current().get_device_name() == "gfx1100": + gemm_specialization = [ + # conv.Conv2DSpecialization.GemmDefault, Have unknown issue with OddC + conv.Conv2DSpecialization.MNKPadding, + ] + else: + 
gemm_specialization = [ + conv.Conv2DSpecialization.GemmDefault, + conv.Conv2DSpecialization.MNKPadding, + ] operations = [] for conv2d_spec in conv2d_specialization: @@ -108,7 +176,7 @@ def CreateConv2dFwdOperator(manifest, operation_kind, out_element_op, out_data_o new_operation = conv.Conv2DOperation( operation_kind=operation_kind, extra_kind=out_element_op, - xdl_op_type=conv.XdlOpType(operation_kind.value), + op_type=conv.OpType(operation_kind.value), A=a_element_desc, B=b_element_desc, C=c_element_desc, @@ -128,22 +196,32 @@ def CreateConv2dFwdOperator(manifest, operation_kind, out_element_op, out_data_o conv2d_specialization = [conv.Conv2DSpecialization.ConvFwdOddC] - tile_descriptions += [ - conv.GroupTileDesc(1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1), - conv.GroupTileDesc(1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1), - conv.GroupTileDesc(1, 256, 256, 64, 32, 8, 8, 32, 32, 4, 1), - conv.GroupTileDesc(1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2), - conv.GroupTileDesc(1, 128, 64, 64, 32, 8, 8, 32, 32, 1, 2), - conv.GroupTileDesc(1, 256, 256, 16, 32, 8, 8, 16, 16, 4, 1), # c_out=1 - ] + if Target.current().get_device_name() == "gfx1100": + tile_descriptions = [ + conv.GroupTileDesc(1, 256, 128, 64, 32, 8, 0, 16, 16, 4, 1), + conv.GroupTileDesc(1, 256, 128, 64, 64, 8, 0, 16, 16, 4, 1), + conv.GroupTileDesc(1, 256, 256, 64, 32, 8, 0, 16, 16, 8, 1), + # conv.GroupTileDesc(1, 128, 128, 64, 32, 8, 0, 16, 16, 4, 2), + # conv.GroupTileDesc(1, 128, 64, 64, 32, 8, 0, 16, 16, 2, 2), failed + conv.GroupTileDesc(1, 256, 128, 16, 32, 8, 0, 16, 16, 1, 1), # c_out=1 + ] + else: + tile_descriptions += [ + conv.GroupTileDesc(1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1), + conv.GroupTileDesc(1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1), + conv.GroupTileDesc(1, 256, 256, 64, 32, 8, 8, 32, 32, 4, 1), + conv.GroupTileDesc(1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2), + conv.GroupTileDesc(1, 128, 64, 64, 32, 8, 8, 32, 32, 1, 2), + conv.GroupTileDesc(1, 256, 256, 16, 32, 8, 8, 16, 16, 4, 1), # c_out=1 + 
] block_descriptions = [ conv.BlockTransferDesc([4, 8, 8], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), conv.BlockTransferDesc([4, 8, 8], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), - conv.BlockTransferDesc([4, 4, 8], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), conv.BlockTransferDesc([4, 8, 8], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), - conv.BlockTransferDesc([4, 4, 8], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), - conv.BlockTransferDesc([4, 4, 8], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), + # conv.BlockTransferDesc([4, 4, 8], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), + # conv.BlockTransferDesc([4, 4, 8], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), + conv.BlockTransferDesc([4, 8, 8], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), conv.BlockTransferDesc([4, 2, 8], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), conv.BlockTransferDesc([4, 8, 8], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), conv.BlockTransferDesc([4, 8, 8], [1, 0, 2], [1, 0, 2], 2, 1, 1, 1), @@ -158,13 +236,13 @@ def CreateConv2dFwdOperator(manifest, operation_kind, out_element_op, out_data_o conv.BlockTransferDesc([4, 16, 4], [1, 0, 2], [1, 0, 2], 2, 2, 2, 1), # c_out=1 ] - c_block_descriptions += [ + c_block_descriptions = [ conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8), conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8), conv.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8), - conv.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8), - conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8), - conv.CBlockTransferDesc(4, 1, [1, 256, 1, 1], 1), # c_out=1 + # conv.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8), + # conv.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8), + conv.CBlockTransferDesc(1, 1, [1, 16, 1, 16], 1), # c_out=1 ] for conv2d_spec in conv2d_specialization: for gemm_spec in gemm_specialization: @@ -174,7 +252,7 @@ def CreateConv2dFwdOperator(manifest, operation_kind, out_element_op, out_data_o new_operation = conv.Conv2DOperation( operation_kind=operation_kind, extra_kind=out_element_op, - xdl_op_type=conv.XdlOpType(operation_kind.value), + op_type=conv.OpType(operation_kind.value), A=a_element_desc, B=b_element_desc, 
C=c_element_desc, @@ -273,7 +351,7 @@ def CreateConv2dBwdOperator(manifest, operation_kind, out_element_op, out_data_o new_operation = conv.Conv2DOperation( operation_kind=operation_kind, extra_kind=out_element_op, - xdl_op_type=conv.XdlOpType(operation_kind.value), + op_type=conv.OpType(operation_kind.value), A=a_element_desc, B=b_element_desc, C=c_element_desc, @@ -383,7 +461,7 @@ def CreateConv2dBwdBiasOperator( new_operation = conv.Conv2DOperation( operation_kind=operation_kind, extra_kind=out_element_op, - xdl_op_type=conv.XdlOpType(operation_kind.value), + op_type=conv.OpType(operation_kind.value), A=a_element_desc, B=b_element_desc, C=c_element_desc, @@ -418,24 +496,39 @@ def CreateGemmRRROperator(manifest): ) element_op = library.TensorOperation.PassThrough - tile_descriptions = [ - gemm.TileDesc(256, 256, 128, 32, 8, 2, 32, 32, 4, 2), - gemm.TileDesc(256, 256, 128, 32, 8, 8, 32, 32, 4, 2), - gemm.TileDesc(256, 128, 256, 32, 8, 2, 32, 32, 2, 4), - gemm.TileDesc(256, 128, 256, 32, 8, 8, 32, 32, 2, 4), - gemm.TileDesc(128, 128, 128, 32, 8, 2, 32, 32, 4, 2), - gemm.TileDesc(128, 128, 128, 32, 8, 8, 32, 32, 4, 2), - gemm.TileDesc(256, 128, 128, 32, 8, 2, 32, 32, 2, 2), - gemm.TileDesc(256, 128, 128, 32, 8, 8, 32, 32, 2, 2), - gemm.TileDesc(128, 128, 64, 32, 8, 2, 32, 32, 2, 2), - gemm.TileDesc(128, 128, 64, 32, 8, 8, 32, 32, 2, 2), - gemm.TileDesc(128, 64, 128, 32, 8, 2, 32, 32, 2, 2), - gemm.TileDesc(128, 64, 128, 32, 8, 8, 32, 32, 2, 2), - gemm.TileDesc(256, 128, 64, 32, 8, 2, 32, 32, 2, 1), - gemm.TileDesc(256, 128, 64, 32, 8, 8, 32, 32, 2, 1), - gemm.TileDesc(256, 64, 128, 32, 8, 2, 32, 32, 1, 2), - gemm.TileDesc(256, 64, 128, 32, 8, 8, 32, 32, 1, 2), - ] + if Target.current().get_device_name() == "gfx1100": + tile_descriptions = [ + gemm.TileDesc(256, 128, 256, 64, 8, 0, 16, 16, 4, 4), + gemm.TileDesc(256, 256, 128, 64, 8, 0, 16, 16, 8, 2), + gemm.TileDesc(256, 128, 256, 32, 8, 0, 16, 16, 4, 4), + gemm.TileDesc(256, 256, 128, 32, 8, 0, 16, 16, 8, 2), + 
gemm.TileDesc(256, 128, 128, 64, 8, 0, 16, 16, 4, 2), + gemm.TileDesc(256, 128, 128, 32, 8, 0, 16, 16, 4, 2), + gemm.TileDesc(256, 256, 64, 64, 8, 0, 16, 16, 8, 1), + gemm.TileDesc(256, 64, 256, 64, 8, 0, 16, 16, 2, 4), + gemm.TileDesc(128, 128, 128, 64, 8, 0, 16, 16, 8, 2), + gemm.TileDesc(128, 128, 64, 64, 8, 0, 16, 16, 4, 2), + gemm.TileDesc(128, 64, 128, 64, 8, 0, 16, 16, 4, 2), + ] + else: + tile_descriptions = [ + gemm.TileDesc(256, 256, 128, 32, 8, 2, 32, 32, 4, 2), + gemm.TileDesc(256, 256, 128, 32, 8, 8, 32, 32, 4, 2), + gemm.TileDesc(256, 128, 256, 32, 8, 2, 32, 32, 2, 4), + gemm.TileDesc(256, 128, 256, 32, 8, 8, 32, 32, 2, 4), + gemm.TileDesc(128, 128, 128, 32, 8, 2, 32, 32, 4, 2), + gemm.TileDesc(128, 128, 128, 32, 8, 8, 32, 32, 4, 2), + gemm.TileDesc(256, 128, 128, 32, 8, 2, 32, 32, 2, 2), + gemm.TileDesc(256, 128, 128, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(128, 128, 64, 32, 8, 2, 32, 32, 2, 2), + gemm.TileDesc(128, 128, 64, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(128, 64, 128, 32, 8, 2, 32, 32, 2, 2), + gemm.TileDesc(128, 64, 128, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(256, 128, 64, 32, 8, 2, 32, 32, 2, 1), + gemm.TileDesc(256, 128, 64, 32, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(256, 64, 128, 32, 8, 2, 32, 32, 1, 2), + gemm.TileDesc(256, 64, 128, 32, 8, 8, 32, 32, 1, 2), + ] b_block_descriptions = [ gemm.BlockTransferDesc([8, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0), @@ -486,6 +579,12 @@ def CreateGemmRRROperator(manifest): gemm.GemmSpecialization.MNKPadding, ] operations = [] + + if Target.current().get_device_name() == "gfx1100": + op_type = gemm.OpType.DeviceGemmWmma_CShuffle + else: + op_type = gemm.OpType.DeviceGemmXdl_CShuffle + for gemm_spec in gemm_specialization: for tile_desc, a_block_desc, b_block_desc, c_block_desc in zip( tile_descriptions, @@ -496,7 +595,7 @@ def CreateGemmRRROperator(manifest): new_operation = gemm.GemmOperation( operation_kind=operation_kind, extra_kind=element_op, - 
xdl_op_type=gemm.XdlOpType.DeviceGemmXdl_CShuffle, + op_type=op_type, A=a_element_desc, B=b_element_desc, C=c_element_desc, @@ -527,21 +626,40 @@ def CreateGemmRCROperator(manifest): ) element_op = library.TensorOperation.PassThrough - tile_descriptions = [ - gemm.TileDesc(256, 256, 128, 32, 8, 8, 32, 32, 4, 2), - gemm.TileDesc(256, 128, 256, 32, 8, 8, 32, 32, 2, 4), - gemm.TileDesc(128, 128, 128, 32, 8, 8, 32, 32, 4, 2), - gemm.TileDesc(256, 128, 128, 32, 8, 8, 32, 32, 2, 2), - gemm.TileDesc(128, 128, 64, 32, 8, 8, 32, 32, 2, 2), - gemm.TileDesc(128, 64, 128, 32, 8, 8, 32, 32, 2, 2), - gemm.TileDesc(64, 64, 64, 32, 8, 8, 32, 32, 2, 2), - gemm.TileDesc(256, 128, 64, 32, 8, 8, 32, 32, 2, 1), - gemm.TileDesc(256, 64, 128, 32, 8, 8, 32, 32, 1, 2), - gemm.TileDesc(128, 128, 32, 32, 8, 8, 32, 32, 2, 1), - gemm.TileDesc(128, 32, 128, 32, 8, 8, 32, 32, 1, 2), - gemm.TileDesc(64, 64, 32, 32, 8, 8, 32, 32, 2, 1), - gemm.TileDesc(64, 32, 64, 32, 8, 8, 32, 32, 1, 2), - ] + if Target.current().get_device_name() == "gfx1100": + tile_descriptions = [ + gemm.TileDesc(256, 128, 256, 64, 8, 0, 16, 16, 4, 4), + gemm.TileDesc(256, 256, 128, 64, 8, 0, 16, 16, 8, 2), + gemm.TileDesc(256, 128, 256, 32, 8, 0, 16, 16, 4, 4), + gemm.TileDesc(256, 256, 128, 32, 8, 0, 16, 16, 8, 2), + gemm.TileDesc(256, 128, 128, 64, 8, 0, 16, 16, 4, 2), + gemm.TileDesc(256, 128, 128, 32, 8, 0, 16, 16, 4, 2), + gemm.TileDesc(256, 256, 64, 64, 8, 0, 16, 16, 8, 1), + gemm.TileDesc(256, 64, 256, 64, 8, 0, 16, 16, 2, 4), + gemm.TileDesc(128, 128, 128, 64, 8, 0, 16, 16, 8, 2), # failed + gemm.TileDesc(128, 128, 64, 64, 8, 0, 16, 16, 4, 2), + gemm.TileDesc(128, 64, 128, 64, 8, 0, 16, 16, 4, 2), + gemm.TileDesc(64, 16, 64, 64, 8, 0, 16, 16, 1, 2), + gemm.TileDesc(64, 16, 128, 64, 8, 0, 16, 16, 1, 4), + gemm.TileDesc(64, 32, 64, 32, 8, 0, 16, 16, 2, 2), + gemm.TileDesc(64, 64, 32, 64, 8, 0, 16, 16, 4, 1), + ] + else: + tile_descriptions = [ + gemm.TileDesc(256, 256, 128, 32, 8, 8, 32, 32, 4, 2), + gemm.TileDesc(256, 
128, 256, 32, 8, 8, 32, 32, 2, 4), + gemm.TileDesc(128, 128, 128, 32, 8, 8, 32, 32, 4, 2), + gemm.TileDesc(256, 128, 128, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(128, 128, 64, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(128, 64, 128, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(64, 64, 64, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(256, 128, 64, 32, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(256, 64, 128, 32, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(128, 128, 32, 32, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(128, 32, 128, 32, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(64, 64, 32, 32, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(64, 32, 64, 32, 8, 8, 32, 32, 1, 2), + ] block_descriptions = [] c_block_descriptions = [] @@ -576,6 +694,12 @@ def CreateGemmRCROperator(manifest): gemm.GemmSpecialization.MNKPadding, ] operations = [] + + if Target.current().get_device_name() == "gfx1100": + op_type = gemm.OpType.DeviceGemmWmma_CShuffle + else: + op_type = gemm.OpType.DeviceGemmXdl_CShuffle + for gemm_spec in gemm_specialization: for tile_desc, block_desc, c_block_desc in zip( tile_descriptions, block_descriptions, c_block_descriptions @@ -583,7 +707,7 @@ def CreateGemmRCROperator(manifest): new_operation = gemm.GemmOperation( operation_kind=operation_kind, extra_kind=element_op, - xdl_op_type=gemm.XdlOpType.DeviceGemmXdl_CShuffle, + op_type=op_type, A=a_element_desc, B=b_element_desc, C=c_element_desc, @@ -601,7 +725,7 @@ def CreateGemmRCROperator(manifest): return operations -def CreateGemmRCRBilinearOperator(manifest, c_element_op): +def CreateGemmRCRBillinearOperator(manifest, c_element_op): operation_kind = library.GemmKind.Gemm a_element_desc = library.TensorDesc( library.DataType.f16, library.LayoutType.RowMajor @@ -639,21 +763,47 @@ def CreateGemmRCRBilinearOperator(manifest, c_element_op): e_dtype = library.DataType.f16 element_op = library.TensorOperation.PassThrough # 0 indicates not print - tile_descriptions = [ - gemm.TileDesc(256, 256, 128, 32, 8, 8, 32, 32, 4, 2), - gemm.TileDesc(256, 128, 256, 
32, 8, 8, 32, 32, 2, 4), - gemm.TileDesc(128, 128, 128, 32, 8, 8, 32, 32, 4, 2), - gemm.TileDesc(256, 128, 128, 32, 8, 8, 32, 32, 2, 2), - gemm.TileDesc(128, 128, 64, 32, 8, 8, 32, 32, 2, 2), - gemm.TileDesc(128, 64, 128, 32, 8, 8, 32, 32, 2, 2), - gemm.TileDesc(64, 64, 64, 32, 8, 8, 32, 32, 2, 2), - gemm.TileDesc(256, 128, 64, 32, 8, 8, 32, 32, 2, 1), - gemm.TileDesc(256, 64, 128, 32, 8, 8, 32, 32, 1, 2), - gemm.TileDesc(128, 128, 32, 32, 8, 8, 32, 32, 2, 1), - gemm.TileDesc(128, 32, 128, 32, 8, 8, 32, 32, 1, 2), - gemm.TileDesc(64, 64, 32, 32, 8, 8, 32, 32, 2, 1), - gemm.TileDesc(64, 32, 64, 32, 8, 8, 32, 32, 1, 2), - ] + if Target.current().get_device_name() == "gfx1100": + tile_descriptions = [ + gemm.TileDesc(256, 128, 256, 64, 8, 0, 16, 16, 4, 4), + gemm.TileDesc(256, 256, 128, 32, 8, 0, 16, 16, 8, 2), + gemm.TileDesc(256, 128, 128, 64, 8, 0, 16, 16, 4, 2), + gemm.TileDesc(256, 128, 128, 32, 8, 0, 16, 16, 4, 2), + + gemm.TileDesc(256, 128, 160, 64, 8, 0, 16, 16, 2, 5), + + gemm.TileDesc(128, 128, 128, 32, 8, 0, 16, 16, 8, 2), + gemm.TileDesc(128, 256, 64, 64, 8, 0, 16, 16, 8, 2), + gemm.TileDesc(128, 64, 256, 64, 8, 0, 16, 16, 2, 8), + + gemm.TileDesc(128, 64, 80, 64, 8, 0, 16, 16, 1, 5), + + gemm.TileDesc(64, 16, 64, 64, 8, 0, 16, 16, 1, 2), + gemm.TileDesc(64, 16, 128, 64, 8, 0, 16, 16, 1, 4), + gemm.TileDesc(64, 64, 32, 64, 8, 0, 16, 16, 4, 1), + gemm.TileDesc(64, 32, 64, 64, 8, 0, 16, 16, 2, 2), + + gemm.TileDesc(64, 32, 80, 64, 8, 0, 16, 16, 1, 5), + + gemm.TileDesc(32, 16, 32, 64, 8, 0, 16, 16, 1, 2), + gemm.TileDesc(32, 16, 16, 64, 8, 0, 16, 16, 1, 1), + ] + else: + tile_descriptions = [ + gemm.TileDesc(256, 256, 128, 32, 8, 8, 32, 32, 4, 2), + gemm.TileDesc(256, 128, 256, 32, 8, 8, 32, 32, 2, 4), + gemm.TileDesc(128, 128, 128, 32, 8, 8, 32, 32, 4, 2), + gemm.TileDesc(256, 128, 128, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(128, 128, 64, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(128, 64, 128, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(64, 64, 64, 32, 8, 8, 
32, 32, 2, 2), + gemm.TileDesc(256, 128, 64, 32, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(256, 64, 128, 32, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(128, 128, 32, 32, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(128, 32, 128, 32, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(64, 64, 32, 32, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(64, 32, 64, 32, 8, 8, 32, 32, 1, 2), + ] block_descriptions = [] c_block_descriptions = [] @@ -661,17 +811,33 @@ def CreateGemmRCRBilinearOperator(manifest, c_element_op): block_transfer = -1 c_block_transfer = -1 if t.block_size == 256: - block_transfer = [4, 64, 1] - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8) + if t.n_per_block % 80 == 0: + block_transfer = [8, 32, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 64, 1, 4], 8) + else: + block_transfer = [4, 64, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8) if t.block_size == 128: - block_transfer = [4, 32, 1] - if t.n_per_block == 128: - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 8) + if t.n_per_block % 80 == 0: + block_transfer = [8, 16, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 64, 1, 2], 8) else: - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8) + block_transfer = [4, 32, 1] + if t.n_per_block == 128: + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 8) + else: + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8) if t.block_size == 64: - block_transfer = [4, 16, 1] - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8) + if t.n_per_block % 80 == 0: + block_transfer = [4, 16, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 2], 8) + else: + block_transfer = [4, 16, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8) + + if t.block_size == 32: + block_transfer = [2, 16, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 2], 8) assert ( block_transfer != -1 @@ -688,6 +854,12 @@ def 
CreateGemmRCRBilinearOperator(manifest, c_element_op): gemm.GemmSpecialization.MNKPadding, ] operations = [] + + if Target.current().get_device_name() == "gfx1100": + op_type = gemm.OpType.DeviceGemmMultipleD_Wmma_CShuffle + else: + op_type = gemm.OpType.DeviceGemmMultipleD_Xdl_CShuffle + for gemm_spec in gemm_specialization: for tile_desc, block_desc, c_block_desc in zip( tile_descriptions, block_descriptions, c_block_descriptions @@ -695,7 +867,7 @@ def CreateGemmRCRBilinearOperator(manifest, c_element_op): new_operation = gemm.GemmOperation( operation_kind=operation_kind, extra_kind=c_element_op, - xdl_op_type=gemm.XdlOpType.DeviceGemmMultipleD_Xdl_CShuffle, + op_type=op_type, A=a_element_desc, B=b_element_desc, C=c_element_desc, @@ -730,7 +902,7 @@ def CreateGemmRCRBilinearOperator(manifest, c_element_op): new_operation = gemm.GemmOperation( operation_kind=operation_kind, extra_kind=c_element_op, - xdl_op_type=gemm.XdlOpType.DeviceGemmMultipleD_Xdl_CShuffle, + op_type=op_type, A=a_element_desc, B=b_element_desc, C=c_element_desc, @@ -763,7 +935,7 @@ def CreateGemmRCRBilinearOperator(manifest, c_element_op): new_operation = gemm.GemmOperation( operation_kind=operation_kind, extra_kind=c_element_op, - xdl_op_type=gemm.XdlOpType.DeviceGemmMultipleD_Xdl_CShuffle, + op_type=op_type, A=a_element_desc, B=b_element_desc, C=c_element_desc, @@ -844,7 +1016,7 @@ def CreateBmmRCROperator(manifest): new_operation = gemm.GemmOperation( operation_kind=operation_kind, extra_kind=element_op, - xdl_op_type=gemm.XdlOpType.DeviceBatchedGemmXdl, + op_type=gemm.OpType.DeviceBatchedGemmXdl, A=a_element_desc, B=b_element_desc, C=c_element_desc, @@ -933,7 +1105,7 @@ def CreateGemmRCRPermOperator(manifest, c_element_op): new_operation = gemm.GemmOperation( operation_kind=operation_kind, extra_kind=c_element_op, - xdl_op_type=gemm.XdlOpType.DeviceGemmBiasCPermute_Xdl, + op_type=gemm.OpType.DeviceGemmBiasCPermute_Xdl, A=a_element_desc, B=b_element_desc, C=c_element_desc, @@ -1049,7 
+1221,7 @@ def CreateGemmRRRPermOperator(manifest, c_element_op): new_operation = gemm.GemmOperation( operation_kind=operation_kind, extra_kind=c_element_op, - xdl_op_type=gemm.XdlOpType.DeviceGemmBiasCPermute_Xdl, + op_type=gemm.OpType.DeviceGemmBiasCPermute_Xdl, A=a_element_desc, B=b_element_desc, C=c_element_desc, @@ -1084,21 +1256,51 @@ def CreateGemmRCRm2n3PermOperator(manifest, c_element_op): e_dtype = library.DataType.f16 element_op = library.TensorOperation.PassThrough # 0 indicates not print - tile_descriptions = [ - gemm.TileDesc(256, 256, 128, 32, 8, 8, 32, 32, 4, 2), - gemm.TileDesc(256, 128, 256, 32, 8, 8, 32, 32, 2, 4), - gemm.TileDesc(128, 128, 128, 32, 8, 8, 32, 32, 4, 2), - gemm.TileDesc(256, 128, 128, 32, 8, 8, 32, 32, 2, 2), - # gemm.TileDesc(128, 128, 64, 32, 8, 8, 32, 32, 2, 2), - gemm.TileDesc(128, 64, 128, 32, 8, 8, 32, 32, 2, 2), - gemm.TileDesc(64, 64, 64, 32, 8, 8, 32, 32, 2, 2), - gemm.TileDesc(256, 128, 64, 32, 8, 8, 32, 32, 2, 1), - gemm.TileDesc(256, 64, 128, 32, 8, 8, 32, 32, 1, 2), - gemm.TileDesc(128, 128, 32, 32, 8, 8, 32, 32, 2, 1), - gemm.TileDesc(128, 32, 128, 32, 8, 8, 32, 32, 1, 2), - gemm.TileDesc(64, 64, 32, 32, 8, 8, 32, 32, 2, 1), - gemm.TileDesc(64, 32, 64, 32, 8, 8, 32, 32, 1, 2), - ] + if Target.current().get_device_name() == "gfx1100": + tile_descriptions = [ + gemm.TileDesc(256, 128, 256, 64, 8, 0, 16, 16, 4, 4), + gemm.TileDesc(256, 256, 128, 64, 8, 0, 16, 16, 8, 2), + gemm.TileDesc(256, 128, 256, 32, 8, 0, 16, 16, 4, 4), + gemm.TileDesc(256, 256, 128, 32, 8, 0, 16, 16, 8, 2), + gemm.TileDesc(256, 128, 128, 64, 8, 0, 16, 16, 4, 2), + gemm.TileDesc(256, 128, 128, 32, 8, 0, 16, 16, 4, 2), + gemm.TileDesc(256, 256, 64, 64, 8, 0, 16, 16, 8, 1), + gemm.TileDesc(256, 64, 256, 64, 8, 0, 16, 16, 2, 4), + + gemm.TileDesc(256, 128, 160, 64, 8, 0, 16, 16, 2, 5), + + gemm.TileDesc(128, 128, 128, 64, 8, 0, 16, 16, 8, 2), + gemm.TileDesc(128, 128, 64, 64, 8, 0, 16, 16, 4, 2), + gemm.TileDesc(128, 64, 128, 64, 8, 0, 16, 16, 4, 2), 
+ + gemm.TileDesc(128, 64, 80, 64, 8, 0, 16, 16, 1, 5), + + gemm.TileDesc(64, 16, 64, 64, 8, 0, 16, 16, 1, 2), + gemm.TileDesc(64, 16, 128, 64, 8, 0, 16, 16, 1, 4), + gemm.TileDesc(64, 32, 64, 32, 8, 0, 16, 16, 2, 2), + gemm.TileDesc(64, 64, 32, 64, 8, 0, 16, 16, 4, 1), + + gemm.TileDesc(64, 32, 80, 64, 8, 0, 16, 16, 1, 5), + + gemm.TileDesc(32, 16, 32, 64, 8, 0, 16, 16, 1, 2), + gemm.TileDesc(32, 16, 16, 64, 8, 0, 16, 16, 1, 1), + ] + else: + tile_descriptions = [ + gemm.TileDesc(256, 256, 128, 32, 8, 8, 32, 32, 4, 2), + gemm.TileDesc(256, 128, 256, 32, 8, 8, 32, 32, 2, 4), + gemm.TileDesc(128, 128, 128, 32, 8, 8, 32, 32, 4, 2), + gemm.TileDesc(256, 128, 128, 32, 8, 8, 32, 32, 2, 2), + # gemm.TileDesc(128, 128, 64, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(128, 64, 128, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(64, 64, 64, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(256, 128, 64, 32, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(256, 64, 128, 32, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(128, 128, 32, 32, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(128, 32, 128, 32, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(64, 64, 32, 32, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(64, 32, 64, 32, 8, 8, 32, 32, 1, 2), + ] block_descriptions = [] c_block_descriptions = [] @@ -1106,19 +1308,32 @@ def CreateGemmRCRm2n3PermOperator(manifest, c_element_op): block_transfer = -1 c_block_transfer = -1 if t.block_size == 256: - block_transfer = [4, 64, 1] - # TODO:figure out the last dimension - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8) - if t.block_size == 128: - block_transfer = [4, 32, 1] - if t.n_per_block == 128: - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 8) + if t.n_per_block % 80 == 0: + block_transfer = [8, 32, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 64, 1, 4], 8) else: - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8) + block_transfer = [4, 64, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8) + if t.block_size 
== 128: + if t.n_per_block % 80 == 0: + block_transfer = [8, 16, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 64, 1, 2], 8) + else: + block_transfer = [4, 32, 1] + if t.n_per_block == 128: + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 8) + else: + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8) if t.block_size == 64: - block_transfer = [4, 16, 1] - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8) - + if t.n_per_block % 80 == 0: + block_transfer = [4, 16, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 2], 8) + else: + block_transfer = [4, 16, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8) + if t.block_size == 32: + block_transfer = [2, 16, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 2], 8) assert ( block_transfer != -1 and c_block_transfer != -1 @@ -1134,6 +1349,12 @@ def CreateGemmRCRm2n3PermOperator(manifest, c_element_op): gemm.GemmSpecialization.MNKPadding, ] operations = [] + + if Target.current().get_device_name() == "gfx1100": + op_type = gemm.OpType.DeviceBatchedContractionMultipleD_Wmma_CShuffle + else: + op_type = gemm.OpType.DeviceBatchedContractionMultipleD_Xdl_CShuffle + for gemm_spec in gemm_specialization: for tile_desc, block_desc, c_block_desc in zip( tile_descriptions, block_descriptions, c_block_descriptions @@ -1141,7 +1362,7 @@ def CreateGemmRCRm2n3PermOperator(manifest, c_element_op): new_operation = gemm.GemmOperation( operation_kind=operation_kind, extra_kind=c_element_op, - xdl_op_type=gemm.XdlOpType.DeviceBatchedContractionMultipleD_Xdl_CShuffle, + op_type=op_type, A=a_element_desc, B=b_element_desc, C=c_element_desc, @@ -1176,21 +1397,51 @@ def CreateGemmRCRm3n2PermOperator(manifest, c_element_op): e_dtype = library.DataType.f16 element_op = library.TensorOperation.PassThrough # 0 indicates not print - tile_descriptions = [ - gemm.TileDesc(256, 256, 128, 32, 8, 8, 32, 32, 4, 2), - gemm.TileDesc(256, 128, 
256, 32, 8, 8, 32, 32, 2, 4), - gemm.TileDesc(128, 128, 128, 32, 8, 8, 32, 32, 4, 2), - gemm.TileDesc(256, 128, 128, 32, 8, 8, 32, 32, 2, 2), - # gemm.TileDesc(128, 128, 64, 32, 8, 8, 32, 32, 2, 2), - gemm.TileDesc(128, 64, 128, 32, 8, 8, 32, 32, 2, 2), - gemm.TileDesc(64, 64, 64, 32, 8, 8, 32, 32, 2, 2), - gemm.TileDesc(256, 128, 64, 32, 8, 8, 32, 32, 2, 1), - gemm.TileDesc(256, 64, 128, 32, 8, 8, 32, 32, 1, 2), - gemm.TileDesc(128, 128, 32, 32, 8, 8, 32, 32, 2, 1), - gemm.TileDesc(128, 32, 128, 32, 8, 8, 32, 32, 1, 2), - gemm.TileDesc(64, 64, 32, 32, 8, 8, 32, 32, 2, 1), - gemm.TileDesc(64, 32, 64, 32, 8, 8, 32, 32, 1, 2), - ] + if Target.current().get_device_name() == "gfx1100": + tile_descriptions = [ + gemm.TileDesc(256, 128, 256, 64, 8, 0, 16, 16, 4, 4), + gemm.TileDesc(256, 256, 128, 64, 8, 0, 16, 16, 8, 2), + gemm.TileDesc(256, 128, 256, 32, 8, 0, 16, 16, 4, 4), + gemm.TileDesc(256, 256, 128, 32, 8, 0, 16, 16, 8, 2), + gemm.TileDesc(256, 128, 128, 64, 8, 0, 16, 16, 4, 2), + gemm.TileDesc(256, 128, 128, 32, 8, 0, 16, 16, 4, 2), + gemm.TileDesc(256, 256, 64, 64, 8, 0, 16, 16, 8, 1), + gemm.TileDesc(256, 64, 256, 64, 8, 0, 16, 16, 2, 4), + + gemm.TileDesc(256, 128, 160, 64, 8, 0, 16, 16, 2, 5), + + gemm.TileDesc(128, 128, 128, 64, 8, 0, 16, 16, 8, 2), + gemm.TileDesc(128, 128, 64, 64, 8, 0, 16, 16, 4, 2), + gemm.TileDesc(128, 64, 128, 64, 8, 0, 16, 16, 4, 2), + + gemm.TileDesc(128, 64, 80, 64, 8, 0, 16, 16, 1, 5), + + gemm.TileDesc(64, 16, 64, 64, 8, 0, 16, 16, 1, 2), + gemm.TileDesc(64, 16, 128, 64, 8, 0, 16, 16, 1, 4), + gemm.TileDesc(64, 32, 64, 32, 8, 0, 16, 16, 2, 2), + gemm.TileDesc(64, 64, 32, 64, 8, 0, 16, 16, 4, 1), + + gemm.TileDesc(64, 32, 80, 64, 8, 0, 16, 16, 1, 5), + + gemm.TileDesc(32, 16, 32, 64, 8, 0, 16, 16, 1, 2), + gemm.TileDesc(32, 16, 16, 64, 8, 0, 16, 16, 1, 1), + ] + else: + tile_descriptions = [ + gemm.TileDesc(256, 256, 128, 32, 8, 8, 32, 32, 4, 2), + gemm.TileDesc(256, 128, 256, 32, 8, 8, 32, 32, 2, 4), + gemm.TileDesc(128, 128, 128, 
32, 8, 8, 32, 32, 4, 2), + gemm.TileDesc(256, 128, 128, 32, 8, 8, 32, 32, 2, 2), + # gemm.TileDesc(128, 128, 64, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(128, 64, 128, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(64, 64, 64, 32, 8, 8, 32, 32, 2, 2), + gemm.TileDesc(256, 128, 64, 32, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(256, 64, 128, 32, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(128, 128, 32, 32, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(128, 32, 128, 32, 8, 8, 32, 32, 1, 2), + gemm.TileDesc(64, 64, 32, 32, 8, 8, 32, 32, 2, 1), + gemm.TileDesc(64, 32, 64, 32, 8, 8, 32, 32, 1, 2), + ] block_descriptions = [] c_block_descriptions = [] @@ -1198,19 +1449,32 @@ def CreateGemmRCRm3n2PermOperator(manifest, c_element_op): block_transfer = -1 c_block_transfer = -1 if t.block_size == 256: - block_transfer = [4, 64, 1] - # TODO:figure out the last dimension - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 1) - if t.block_size == 128: - block_transfer = [4, 32, 1] - if t.n_per_block == 128: - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 1) + if t.n_per_block % 80 == 0: + block_transfer = [8, 32, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 64, 1, 4], 8) else: - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 1) + block_transfer = [4, 64, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8) + if t.block_size == 128: + if t.n_per_block % 80 == 0: + block_transfer = [8, 16, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 64, 1, 2], 8) + else: + block_transfer = [4, 32, 1] + if t.n_per_block == 128: + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 8) + else: + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8) if t.block_size == 64: - block_transfer = [4, 16, 1] - c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 1) - + if t.n_per_block % 80 == 0: + block_transfer = [4, 16, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 2], 8) + else: + 
block_transfer = [4, 16, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 4], 8) + if t.block_size == 32: + block_transfer = [2, 16, 1] + c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 2], 8) assert ( block_transfer != -1 and c_block_transfer != -1 @@ -1226,6 +1490,12 @@ def CreateGemmRCRm3n2PermOperator(manifest, c_element_op): gemm.GemmSpecialization.MNKPadding, ] operations = [] + + if Target.current().get_device_name() == "gfx1100": + op_type = gemm.OpType.DeviceBatchedContractionMultipleD_Wmma_CShuffle + else: + op_type = gemm.OpType.DeviceBatchedContractionMultipleD_Xdl_CShuffle + for gemm_spec in gemm_specialization: for tile_desc, block_desc, c_block_desc in zip( tile_descriptions, block_descriptions, c_block_descriptions @@ -1233,7 +1503,7 @@ def CreateGemmRCRm3n2PermOperator(manifest, c_element_op): new_operation = gemm.GemmOperation( operation_kind=operation_kind, extra_kind=c_element_op, - xdl_op_type=gemm.XdlOpType.DeviceBatchedContractionMultipleD_Xdl_CShuffle, + op_type=op_type, A=a_element_desc, B=b_element_desc, C=c_element_desc, @@ -1323,7 +1593,7 @@ def CreateBmmRCRPermOperator(manifest): new_operation = gemm.GemmOperation( operation_kind=operation_kind, extra_kind=element_op, - xdl_op_type=gemm.XdlOpType.DeviceBatchedGemmCPermuteXdl, + op_type=gemm.OpType.DeviceBatchedGemmCPermuteXdl, A=a_element_desc, B=b_element_desc, C=c_element_desc, @@ -1344,7 +1614,7 @@ def CreateBmmRCRPermOperator(manifest): def CreateBmmSoftmaxBmmOperator( manifest, operation_kind=library.GemmKind.BatchGemmSoftmaxGemm, - xdl_op_type=gemm.XdlOpType.DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle, + op_type=gemm.OpType.DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle, ): a_element_desc = library.TensorDesc( library.DataType.f16, library.LayoutType.RowMajor @@ -1425,7 +1695,7 @@ def CreateBmmSoftmaxBmmOperator( new_operation = gemm.GemmOperation( operation_kind=operation_kind, extra_kind=element_op, - xdl_op_type=xdl_op_type, + op_type=op_type, 
A=a_element_desc, B=b_element_desc, C=c_element_desc, @@ -1447,7 +1717,7 @@ def CreateBmmSoftmaxBmmOperator( def CreateBmmSoftmaxBmmPermOperator( manifest, operation_kind=library.GemmKind.BatchGemmSoftmaxGemmPermute, - xdl_op_type=gemm.XdlOpType.DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle, + op_type=gemm.OpType.DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle, causal_mask=None, ): a_element_desc = library.TensorDesc( @@ -1460,75 +1730,134 @@ def CreateBmmSoftmaxBmmPermOperator( library.DataType.f16, library.LayoutType.RowMajor ) element_op = library.TensorOperation.PassThrough - tile_descriptions = [ - gemm.AttnTileDesc(256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2), - gemm.AttnTileDesc(256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4), - gemm.AttnTileDesc(256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2), - gemm.AttnTileDesc(256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4), - gemm.AttnTileDesc(256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2), - gemm.AttnTileDesc(256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2), - gemm.AttnTileDesc(256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4), - gemm.AttnTileDesc(256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4), - gemm.AttnTileDesc(256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8), - gemm.AttnTileDesc(256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4), - gemm.AttnTileDesc(256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8), - gemm.AttnTileDesc(256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4), - gemm.AttnTileDesc(256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4), - gemm.AttnTileDesc(256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4), - # for MNKOPadding - gemm.AttnTileDesc(256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4), - gemm.AttnTileDesc(256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4), - ] + if Target.current().get_device_name() == "gfx1100": + op_type = gemm.OpType.DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle + + tile_descriptions = [ + 
gemm.AttnTileDesc(256, 128, 64, 32, 8, 64, 32, 8, 16, 16, 16, 1, 4, 4), + gemm.AttnTileDesc(256, 128, 64, 64, 8, 64, 64, 8, 16, 16, 16, 1, 4, 4), + gemm.AttnTileDesc(128, 64, 64, 64, 8, 64, 64, 8, 16, 16, 16, 1, 4, 4), + gemm.AttnTileDesc(128, 64, 64, 32, 8, 64, 32, 8, 16, 16, 16, 1, 4, 4), + gemm.AttnTileDesc(128, 64, 64, 64, 8, 32, 32, 8, 16, 16, 16, 1, 4, 2), + gemm.AttnTileDesc(128, 64, 32, 32, 8, 32, 32, 8, 16, 16, 16, 1, 2, 2), + gemm.AttnTileDesc( 64, 32, 64, 64, 8, 64, 64, 8, 16, 16, 16, 1, 4, 4), + gemm.AttnTileDesc( 64, 32, 64, 32, 8, 64, 32, 8, 16, 16, 16, 1, 4, 4), + gemm.AttnTileDesc( 64, 32, 64, 64, 8, 32, 32, 8, 16, 16, 16, 1, 4, 2), + gemm.AttnTileDesc( 64, 32, 32, 32, 8, 32, 32, 8, 16, 16, 16, 1, 2, 2), + ] - block_descriptions = [ - gemm.BlockTransferDesc([4, 64, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), - gemm.BlockTransferDesc([4, 64, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), - gemm.BlockTransferDesc([4, 64, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), - gemm.BlockTransferDesc([4, 64, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), - gemm.BlockTransferDesc([8, 32, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 0), - gemm.BlockTransferDesc([4, 64, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), - gemm.BlockTransferDesc([8, 32, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 0), - gemm.BlockTransferDesc([4, 64, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), - gemm.BlockTransferDesc([4, 64, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), - gemm.BlockTransferDesc([4, 64, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), - gemm.BlockTransferDesc([8, 32, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), - gemm.BlockTransferDesc([8, 32, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), - gemm.BlockTransferDesc([8, 32, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 0), - gemm.BlockTransferDesc([4, 64, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), - # for MNKOPadding - gemm.BlockTransferDesc([8, 32, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 0), - gemm.BlockTransferDesc([4, 64, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), - ] + block_descriptions = [ + gemm.BlockTransferDesc([4, 64, 1], 
[1, 0, 2], [1, 0, 2], 2, 8, 8, 1), + gemm.BlockTransferDesc([4, 64, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), + gemm.BlockTransferDesc([4, 32, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), + gemm.BlockTransferDesc([4, 32, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), + gemm.BlockTransferDesc([4, 32, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), + gemm.BlockTransferDesc([4, 32, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), + gemm.BlockTransferDesc([4, 16, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), + gemm.BlockTransferDesc([4, 16, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), + gemm.BlockTransferDesc([4, 16, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), + gemm.BlockTransferDesc([4, 16, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), + ] + + else: + tile_descriptions = [ + gemm.AttnTileDesc(256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2), + gemm.AttnTileDesc(256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4), + gemm.AttnTileDesc(256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2), + gemm.AttnTileDesc(256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4), + gemm.AttnTileDesc(256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2), + gemm.AttnTileDesc(256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2), + gemm.AttnTileDesc(256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4), + gemm.AttnTileDesc(256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4), + gemm.AttnTileDesc(256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8), + gemm.AttnTileDesc(256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4), + gemm.AttnTileDesc(256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8), + gemm.AttnTileDesc(256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4), + gemm.AttnTileDesc(256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4), + gemm.AttnTileDesc(256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4), + # for MNKOPadding + gemm.AttnTileDesc(256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4), + gemm.AttnTileDesc(256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4), + ] + + block_descriptions = [ + gemm.BlockTransferDesc([4, 64, 
1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), + gemm.BlockTransferDesc([4, 64, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), + gemm.BlockTransferDesc([4, 64, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), + gemm.BlockTransferDesc([4, 64, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), + gemm.BlockTransferDesc([8, 32, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 0), + gemm.BlockTransferDesc([4, 64, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), + gemm.BlockTransferDesc([8, 32, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 0), + gemm.BlockTransferDesc([4, 64, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), + gemm.BlockTransferDesc([4, 64, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), + gemm.BlockTransferDesc([4, 64, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), + gemm.BlockTransferDesc([8, 32, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), + gemm.BlockTransferDesc([8, 32, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), + gemm.BlockTransferDesc([8, 32, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 0), + gemm.BlockTransferDesc([4, 64, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), + # for MNKOPadding + gemm.BlockTransferDesc([8, 32, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 0), + gemm.BlockTransferDesc([4, 64, 1], [1, 0, 2], [1, 0, 2], 2, 8, 8, 1), + ] causal_mask_flag = 0 if causal_mask is not None: causal_mask_flag = 1 if library.TensorOperationTag[causal_mask] == "True" else 0 c_block_descriptions, b1_block_descriptions = [], [] for i in range(len(tile_descriptions)): - if i in [0, 2, 4, 5, 9, 11]: - block_transfer = [16, 16, 1] - else: - block_transfer = [8, 32, 1] - b1_block_descriptions.append( - gemm.BlockTransferDesc(block_transfer, [0, 2, 1], [0, 2, 1], 1, 4, 2, 0) - ) - - if i in [8, 10]: - c_block_transfer = gemm.MaskedCBlockTransferDesc( - 1, 8, [1, 16, 1, 16], 8, causal_mask_flag + if Target.current().get_device_name() == "gfx1100": + if i <2: + block_transfer = [4, 8, 8] + c_block_transfer = gemm.MaskedCBlockTransferDesc( + 1, 2, [1, 64, 1, 4], 8, causal_mask_flag + ) + elif i <6: + block_transfer = [4, 4, 8] + c_block_transfer = gemm.MaskedCBlockTransferDesc( + 1, 2, [1, 32, 1, 
4], 8, causal_mask_flag + ) + else: + block_transfer = [4, 2, 8] + c_block_transfer = gemm.MaskedCBlockTransferDesc( + 1, 2, [1, 16, 1, 4], 8, causal_mask_flag + ) + b1_block_descriptions.append( + gemm.BlockTransferDesc(block_transfer, [0, 2, 1], [0, 2, 1], 1, 8, 1, 0) ) + c_block_descriptions.append(c_block_transfer) else: - c_shuffle = 4 if i in [9, 11] else 2 - c_block_transfer = gemm.MaskedCBlockTransferDesc( - 1, c_shuffle, [1, 32, 1, 8], 8, causal_mask_flag + if i in [0, 2, 4, 5, 9, 11]: + block_transfer = [4, 8, 8] + # block_transfer = [16, 16, 1] + else: + block_transfer = [8, 32, 1] + b1_block_descriptions.append( + gemm.BlockTransferDesc(block_transfer, [0, 2, 1], [0, 2, 1], 1, 8, 1, 0) + # gemm.BlockTransferDesc(block_transfer, [0, 2, 1], [0, 2, 1], 1, 4, 2, 0) ) - c_block_descriptions.append(c_block_transfer) + if i in [8, 10]: + c_block_transfer = gemm.MaskedCBlockTransferDesc( + 1, 8, [1, 16, 1, 16], 8, causal_mask_flag + ) + else: + c_shuffle = 4 if i in [9, 11] else 2 + c_block_transfer = gemm.MaskedCBlockTransferDesc( + 1, c_shuffle, [1, 64, 1, 4], 8, causal_mask_flag + ) + # c_block_transfer = gemm.MaskedCBlockTransferDesc( + # 1, c_shuffle, [1, 32, 1, 8], 8, causal_mask_flag + # ) + + c_block_descriptions.append(c_block_transfer) gemm_specialization = [] for i in range(len(tile_descriptions)): if i < 12: - gemm_specialization.append(gemm.GemmSpecialization.GemmDefault) + if Target.current().get_device_name() == "gfx1100": + gemm_specialization.append(gemm.GemmSpecialization.MNKOPadding) + else: + gemm_specialization.append(gemm.GemmSpecialization.GemmDefault) elif i in [12, 13]: gemm_specialization.append(gemm.GemmSpecialization.MNOPadding) else: @@ -1546,7 +1875,7 @@ def CreateBmmSoftmaxBmmPermOperator( new_operation = gemm.GemmOperation( operation_kind=operation_kind, extra_kind=extra_op, - xdl_op_type=xdl_op_type, + op_type=op_type, A=a_element_desc, B=b_element_desc, C=c_element_desc, @@ -1649,7 +1978,7 @@ def CreateBmmRRROperator(manifest): 
new_operation = gemm.GemmOperation( operation_kind=operation_kind, extra_kind=element_op, - xdl_op_type=gemm.XdlOpType.DeviceBatchedGemmXdl, + op_type=gemm.OpType.DeviceBatchedGemmXdl, A=a_element_desc, B=b_element_desc, C=c_element_desc, @@ -1761,7 +2090,7 @@ def CreateBmmRRRBillinearOperator(manifest, c_element_op): new_operation = gemm.GemmOperation( operation_kind=operation_kind, extra_kind=c_element_op, - xdl_op_type=gemm.XdlOpType.DeviceBatchedGemmMultiD_Xdl, + op_type=gemm.OpType.DeviceBatchedGemmMultiD_Xdl, A=a_element_desc, B=b_element_desc, C=c_element_desc, @@ -1877,7 +2206,7 @@ def CreateBmmCCRBillinearOperator(manifest, c_element_op): new_operation = gemm.GemmOperation( operation_kind=operation_kind, extra_kind=c_element_op, - xdl_op_type=gemm.XdlOpType.DeviceBatchedGemmMultiD_Xdl, + op_type=gemm.OpType.DeviceBatchedGemmMultiD_Xdl, A=a_element_desc, B=b_element_desc, C=c_element_desc, @@ -1993,7 +2322,7 @@ def CreateBmmCRRBillinearOperator(manifest, c_element_op): new_operation = gemm.GemmOperation( operation_kind=operation_kind, extra_kind=c_element_op, - xdl_op_type=gemm.XdlOpType.DeviceBatchedGemmMultiD_Xdl, + op_type=gemm.OpType.DeviceBatchedGemmMultiD_Xdl, A=a_element_desc, B=b_element_desc, C=c_element_desc, @@ -2105,7 +2434,7 @@ def CreateBmmRRRPermOperator(manifest): new_operation = gemm.GemmOperation( operation_kind=operation_kind, extra_kind=element_op, - xdl_op_type=gemm.XdlOpType.DeviceBatchedGemmCPermuteXdl, + op_type=gemm.OpType.DeviceBatchedGemmCPermuteXdl, A=a_element_desc, B=b_element_desc, C=c_element_desc, @@ -2187,7 +2516,7 @@ def CreateBmmCCROperator(manifest): new_operation = gemm.GemmOperation( operation_kind=operation_kind, extra_kind=element_op, - xdl_op_type=gemm.XdlOpType.DeviceBatchedGemmXdl, + op_type=gemm.OpType.DeviceBatchedGemmXdl, A=a_element_desc, B=b_element_desc, C=c_element_desc, @@ -2260,7 +2589,7 @@ def CreateBmmCRROperator(manifest): new_operation = gemm.GemmOperation( operation_kind=operation_kind, 
extra_kind=element_op, - xdl_op_type=gemm.XdlOpType.DeviceBatchedGemmXdl, + op_type=gemm.OpType.DeviceBatchedGemmXdl, A=a_element_desc, B=b_element_desc, C=c_element_desc, @@ -2386,50 +2715,111 @@ def CreateGroupNormOperator(manifest, rank=5): def GenerateTensorOp(manifest): # Conv2d - CreateConv2dFwdOperator( - manifest, - library.Conv2dKind.GroupConv2dBiasRelu, - library.TensorOperation.PassThrough, - ) - # Conv2dBias - CreateConv2dFwdOperator( - manifest, - library.Conv2dKind.GroupConv2dBiasRelu, - library.TensorOperation.Add, - library.MemoryDataOperation.MemorySet, - ) - # Conv2dBiasRelu - CreateConv2dFwdOperator( - manifest, - library.Conv2dKind.GroupConv2dBiasRelu, - library.TensorOperation.AddRelu, - library.MemoryDataOperation.MemorySet, - ) - # Conv2dBiasAdd - CreateConv2dFwdOperator( - manifest, - library.Conv2dKind.GroupConv2dBiasRelu, - library.TensorOperation.AddAdd, - ) - # Conv2dBiasReluAdd - CreateConv2dFwdOperator( - manifest, - library.Conv2dKind.GroupConv2dBiasRelu, - library.TensorOperation.AddReluAdd, - ) - # Conv2dBiasAddRelu - CreateConv2dFwdOperator( - manifest, - library.Conv2dKind.GroupConv2dBiasRelu, - library.TensorOperation.AddAddRelu, - ) - # Conv2dBiasSigmoid - CreateConv2dFwdOperator( - manifest, - library.Conv2dKind.GroupConv2dBiasRelu, - library.TensorOperation.AddSigmoid, - library.MemoryDataOperation.MemorySet, - ) + if Target.current().get_device_name() == "gfx1100": + CreateConv2dFwdOperator( + manifest, + library.Conv2dKind.GroupConv2dBiasReluWmma, + library.TensorOperation.PassThrough, + ) + # Conv2dBias + CreateConv2dFwdOperator( + manifest, + library.Conv2dKind.GroupConv2dBiasReluWmma, + library.TensorOperation.Add, + library.MemoryDataOperation.MemorySet, + ) + # Conv2dBiasRelu + CreateConv2dFwdOperator( + manifest, + library.Conv2dKind.GroupConv2dBiasReluWmma, + library.TensorOperation.AddRelu, + library.MemoryDataOperation.MemorySet, + ) + # Conv2dBiasAdd + CreateConv2dFwdOperator( + manifest, + 
library.Conv2dKind.GroupConv2dBiasReluWmma, + library.TensorOperation.AddAdd, + ) + # Conv2dBiasReluAdd + CreateConv2dFwdOperator( + manifest, + library.Conv2dKind.GroupConv2dBiasReluWmma, + library.TensorOperation.AddReluAdd, + ) + # Conv2dBiasAddRelu + CreateConv2dFwdOperator( + manifest, + library.Conv2dKind.GroupConv2dBiasReluWmma, + library.TensorOperation.AddAddRelu, + ) + # Conv2dBiasSigmoid + CreateConv2dFwdOperator( + manifest, + library.Conv2dKind.GroupConv2dBiasReluWmma, + library.TensorOperation.AddSigmoid, + library.MemoryDataOperation.MemorySet, + ) + else: + CreateConv2dFwdOperator( + manifest, + library.Conv2dKind.GroupConv2dBiasReluXdlops, + library.TensorOperation.PassThrough, + ) + # Conv2dBias + CreateConv2dFwdOperator( + manifest, + library.Conv2dKind.GroupConv2dBiasReluXdlops, + library.TensorOperation.Add, + library.MemoryDataOperation.MemorySet, + ) + # Conv2dBiasRelu + CreateConv2dFwdOperator( + manifest, + library.Conv2dKind.GroupConv2dBiasReluXdlops, + library.TensorOperation.AddRelu, + library.MemoryDataOperation.MemorySet, + ) + # Conv2dBiasAdd + CreateConv2dFwdOperator( + manifest, + library.Conv2dKind.GroupConv2dBiasReluXdlops, + library.TensorOperation.AddAdd, + ) + # Conv2dBiasReluAdd + CreateConv2dFwdOperator( + manifest, + library.Conv2dKind.GroupConv2dBiasReluXdlops, + library.TensorOperation.AddReluAdd, + ) + # Conv2dBiasAddRelu + CreateConv2dFwdOperator( + manifest, + library.Conv2dKind.GroupConv2dBiasReluXdlops, + library.TensorOperation.AddAddRelu, + ) + # Conv2dBiasSigmoid + CreateConv2dFwdOperator( + manifest, + library.Conv2dKind.GroupConv2dBiasReluXdlops, + library.TensorOperation.AddSigmoid, + library.MemoryDataOperation.MemorySet, + ) + # TranposedConv2d + CreateConv2dBwdOperator( + manifest, + library.Conv2dKind.TransposedConv2d, + library.TensorOperation.PassThrough, + library.MemoryDataOperation.MemorySet, + ) + # TranposedConv2dBiasRelu + CreateConv2dBwdBiasOperator( + manifest, + 
library.Conv2dKind.TransposedConv2dBiasRelu, + library.TensorOperation.AddRelu, + library.MemoryDataOperation.MemorySet, + ) + # TransposedConv2d CreateConv2dBwdOperator( manifest, @@ -2449,37 +2839,37 @@ def GenerateTensorOp(manifest): # GemmRCR CreateGemmRCROperator(manifest) # GemmRCRBias - CreateGemmRCRBilinearOperator(manifest, library.TensorOperation.Add) + CreateGemmRCRBillinearOperator(manifest, library.TensorOperation.Add) # GemmRCRBiasRelu - CreateGemmRCRBilinearOperator(manifest, library.TensorOperation.AddRelu) + CreateGemmRCRBillinearOperator(manifest, library.TensorOperation.AddRelu) # GemmRCRBiasTanh - CreateGemmRCRBilinearOperator(manifest, library.TensorOperation.AddTanh) + CreateGemmRCRBillinearOperator(manifest, library.TensorOperation.AddTanh) # GemmRCRBiasTanh - CreateGemmRCRBilinearOperator(manifest, library.TensorOperation.AddFastGelu) + CreateGemmRCRBillinearOperator(manifest, library.TensorOperation.AddFastGelu) # GemmRCRBiasHardswish - CreateGemmRCRBilinearOperator(manifest, library.TensorOperation.AddHardswish) + CreateGemmRCRBillinearOperator(manifest, library.TensorOperation.AddHardswish) # GemmRCRBiasSwish - CreateGemmRCRBilinearOperator(manifest, library.TensorOperation.AddSwish) + CreateGemmRCRBillinearOperator(manifest, library.TensorOperation.AddSwish) # GemmRCRBiasSigmoid - CreateGemmRCRBilinearOperator(manifest, library.TensorOperation.AddSigmoid) + CreateGemmRCRBillinearOperator(manifest, library.TensorOperation.AddSigmoid) # GemmRCRBiasAdd - CreateGemmRCRBilinearOperator(manifest, library.TensorOperation.AddAdd) + CreateGemmRCRBillinearOperator(manifest, library.TensorOperation.AddAdd) # GemmRCRBiasMul - CreateGemmRCRBilinearOperator(manifest, library.TensorOperation.AddMul) + CreateGemmRCRBillinearOperator(manifest, library.TensorOperation.AddMul) # GemmRCRBiasMul - CreateGemmRCRBilinearOperator(manifest, library.TensorOperation.AddMulTanh) + CreateGemmRCRBillinearOperator(manifest, library.TensorOperation.AddMulTanh) # 
GemmRCRBiasAddRelu - CreateGemmRCRBilinearOperator(manifest, library.TensorOperation.AddAddRelu) + CreateGemmRCRBillinearOperator(manifest, library.TensorOperation.AddAddRelu) # GemmRCRBiasAddAddRelu - CreateGemmRCRBilinearOperator(manifest, library.TensorOperation.AddAddAdd) + CreateGemmRCRBillinearOperator(manifest, library.TensorOperation.AddAddAdd) # GemmRCRBiasAddAddRelu - CreateGemmRCRBilinearOperator(manifest, library.TensorOperation.AddAddAddRelu) + CreateGemmRCRBillinearOperator(manifest, library.TensorOperation.AddAddAddRelu) # GemmRCRBiasSigmoidMul - CreateGemmRCRBilinearOperator(manifest, library.TensorOperation.AddSigmoidMul) + CreateGemmRCRBillinearOperator(manifest, library.TensorOperation.AddSigmoidMul) # GemmRCRBiasSigmoidMulTanh - CreateGemmRCRBilinearOperator(manifest, library.TensorOperation.AddSigmoidMulTanh) + CreateGemmRCRBillinearOperator(manifest, library.TensorOperation.AddSigmoidMulTanh) # GemmRCRBiasMulAdd - CreateGemmRCRBilinearOperator(manifest, library.TensorOperation.AddMulAdd) + CreateGemmRCRBillinearOperator(manifest, library.TensorOperation.AddMulAdd) # BmmRCR CreateBmmRCROperator(manifest) # BmmRRR @@ -2527,3 +2917,7 @@ def GenerateGFX908(manifest, rocm_version): def GenerateGFX90A(manifest, rocm_version): GenerateTensorOp(manifest) + + +def GenerateGFX1100(manifest, rocm_version): + GenerateTensorOp(manifest) diff --git a/python/aitemplate/utils/mk_ck_lib/library.py b/python/aitemplate/utils/mk_ck_lib/library.py index 4b6a357b9..935554152 100644 --- a/python/aitemplate/utils/mk_ck_lib/library.py +++ b/python/aitemplate/utils/mk_ck_lib/library.py @@ -229,7 +229,8 @@ class Conv2dKind(enum.Enum): Conv2dBiasRelu = auto() Conv2dBiasReluAdd = auto() Conv2dBiasSigmoid = auto() - GroupConv2dBiasRelu = auto() + GroupConv2dBiasReluXdlops = auto() + GroupConv2dBiasReluWmma = auto() TransposedConv2d = auto() TransposedConv2dBiasRelu = auto() @@ -239,7 +240,8 @@ class Conv2dKind(enum.Enum): Conv2dKind.Conv2dBiasRelu: "conv2d_bias_relu", 
Conv2dKind.Conv2dBiasReluAdd: "conv2d_bias_relu_add", Conv2dKind.Conv2dBiasSigmoid: "conv2d_bias_sigmoid", - Conv2dKind.GroupConv2dBiasRelu: "group_conv2d_bias_relu", + Conv2dKind.GroupConv2dBiasReluXdlops: "group_conv2d_bias_relu", + Conv2dKind.GroupConv2dBiasReluWmma: "group_conv2d_bias_relu", Conv2dKind.TransposedConv2d: "transposed_conv2d", Conv2dKind.TransposedConv2dBiasRelu: "transposed_conv2d_bias_relu", } diff --git a/python/rebuild.sh b/python/rebuild.sh new file mode 100644 index 000000000..39624f95e --- /dev/null +++ b/python/rebuild.sh @@ -0,0 +1,4 @@ +rm -rf /root/.aitemplate/rocm.db +rm -rf tmp +python3 setup.py bdist_wheel +pip install dist/*.whl --force-reinstall \ No newline at end of file diff --git a/static/include/cuda_device_functions.h b/static/include/cuda_device_functions.h index 4d2c3f463..a535b7e40 100644 --- a/static/include/cuda_device_functions.h +++ b/static/include/cuda_device_functions.h @@ -17,14 +17,7 @@ #include #include -#include "cutlass/conv/conv2d_problem_size.h" -#include "cutlass/conv/device/implicit_gemm_convolution.h" -#include "cutlass/conv/kernel/default_conv2d_fprop.h" -#include "cutlass/cutlass.h" -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/reference/host/tensor_fill.h" - -#include +#include namespace ait { diff --git a/static/include/cuda_includes.h b/static/include/cuda_includes.h new file mode 100644 index 000000000..33e3711b3 --- /dev/null +++ b/static/include/cuda_includes.h @@ -0,0 +1,6 @@ +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/cutlass.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" \ No newline at end of file diff --git a/static/include/rocm_device_functions.h b/static/include/rocm_device_functions.h index 18d3aa297..3348f5e4a 100644 --- a/static/include/rocm_device_functions.h +++ 
b/static/include/rocm_device_functions.h @@ -20,11 +20,7 @@ #include #include #include -#include "include/ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "include/ck/utility/print.hpp" -#include "library/include/ck/library/utility/device_memory.hpp" -#include "library/include/ck/library/utility/host_tensor.hpp" -#include "library/include/ck/library/utility/host_tensor_generator.hpp" +#include namespace ait { diff --git a/static/include/rocm_includes.h b/static/include/rocm_includes.h new file mode 100644 index 000000000..84dc26d73 --- /dev/null +++ b/static/include/rocm_includes.h @@ -0,0 +1,5 @@ +#include "include/ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "include/ck/utility/print.hpp" +#include "library/include/ck/library/utility/device_memory.hpp" +#include "library/include/ck/library/utility/host_tensor.hpp" +#include "library/include/ck/library/utility/host_tensor_generator.hpp" \ No newline at end of file diff --git a/tests/unittest/amd_sd_navi/test_unet_conv.py b/tests/unittest/amd_sd_navi/test_unet_conv.py new file mode 100644 index 000000000..49dcedadd --- /dev/null +++ b/tests/unittest/amd_sd_navi/test_unet_conv.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import unittest + +import torch + +from aitemplate.compiler import compile_model, ops +from aitemplate.frontend import IntImm, Tensor +from aitemplate.testing import detect_target +from aitemplate.testing.test_utils import ( + filter_test_cases_by_params, + get_random_torch_tensor, + TestEnv, +) + +from parameterized import parameterized + + +class ConvBiasTestCase(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + torch.manual_seed(1) + + def _test_conv_bias( + self, + batch=1, + input_dim_x=64, + input_dim_y=64, + weight_dim_x=3, + weight_dim_y=3, + input_channels=320, + output_channels=4, + copy_op=False, + test_name="conv2d_bias", + dtype="float16", + ): + target = detect_target() + X = Tensor( + shape=[IntImm(batch), IntImm(input_dim_x), IntImm(input_dim_y), IntImm(input_channels)], + dtype=dtype, + name="input_0", + is_input=True, + ) + W = Tensor( + shape=[IntImm(output_channels), IntImm(weight_dim_x), IntImm(weight_dim_y), IntImm(input_channels)], + dtype=dtype, + name="input_1", + is_input=True, + ) + B = Tensor( + shape=[IntImm(output_channels)], + dtype=dtype, + name="input_2", + is_input=True, + ) + OP = ops.conv2d_bias(stride=1, pad=1, dilate=1) + if copy_op: + OP = ops.conv2d_bias(**OP._get_op_attributes()) + Y = OP(X, W, B) + Y._attrs["name"] = "output_0" + Y._attrs["is_output"] = True + module = compile_model(Y, target, "./tmp", test_name) + + X_pt = get_random_torch_tensor([batch, input_channels, input_dim_x, input_dim_y], dtype=dtype) + W_pt = get_random_torch_tensor([output_channels, input_channels, weight_dim_x, weight_dim_y], dtype=dtype) + B_pt = get_random_torch_tensor([1, output_channels, 1, 1], dtype=dtype) + Y_pt = torch.nn.functional.conv2d(X_pt.float(), W_pt.float(), padding=1).to( + dtype=X_pt.dtype + ) + Y_pt = Y_pt + B_pt + x = X_pt.permute((0, 2, 3, 1)).contiguous() + w = W_pt.permute((0, 2, 3, 1)).contiguous() + inputs = {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()} + y = 
torch.empty_like(Y_pt).permute((0, 2, 3, 1)).contiguous() + module.run_with_tensors(inputs, [y]) + y_transpose = y.permute((0, 3, 1, 2)) + if target.name() == "cuda": + if dtype == "float32": + torch.testing.assert_close(Y_pt, y_transpose, atol=5e-2, rtol=1e-2) + else: + torch.testing.assert_close(Y_pt, y_transpose, atol=1e-2, rtol=1e-2) + else: + print(test_name, " Running with ROCm") + torch.testing.assert_close(Y_pt, y_transpose, atol=1.25e-1, rtol=1e-1) + + @parameterized.expand( + filter_test_cases_by_params( + { + # TestEnv.CUDA_LESS_THAN_SM80: [("float16")], + # TestEnv.CUDA_SM80: [("float32"), ("bfloat16")], + TestEnv.ROCM: [("float16")], + } + ) + ) + def test_conv2d_bias(self, dtype): + # unet model test + unet_model_conv = [ + [2, 64 ,64 ,3, 3, 4, 320], + [2, 64 ,64 ,3, 3, 320, 320], + [2, 32 ,32 ,3, 3, 320, 640], + [2, 32 ,32 ,3, 3, 640, 640], + [2, 16 ,16 ,3, 3, 640, 1280], + [2, 16 ,16 ,3, 3, 1280, 1280], + [2, 8 ,8 ,3, 3, 1280, 1280], + [2, 8 ,8 ,3, 3, 2560, 1280], + [2, 16 ,16 ,3, 3, 2560, 1280], + [2, 16 ,16 ,3, 3, 1920, 1280], + [2, 32 ,32 ,3, 3, 1280, 1280], + [2, 32 ,32 ,3, 3, 1920, 640], + [2, 32 ,32 ,3, 3, 1280, 640], + [2, 32 ,32 ,3, 3, 960, 640], + [2, 64 ,64 ,3, 3, 640, 320], + [2, 64 ,64 ,3, 3, 640, 640], + [2, 64 ,64 ,3, 3, 960, 320], + [2, 64 ,64 ,3, 3, 320, 4], + ] + test_unet_conv_cnt = 0 + for configs in unet_model_conv: + test_unet_conv_cnt +=1 + self._test_conv_bias( + batch=configs[0], + input_dim_x=configs[1], + input_dim_y=configs[2], + weight_dim_x=configs[3], + weight_dim_y=configs[4], + input_channels=configs[5], + output_channels=configs[6], + copy_op=False, + test_name="static", + dtype=dtype, + ) + + # self._test_conv_bias( + # batch=configs[0], + # input_dim_x=configs[1], + # input_dim_y=configs[2], + # weight_dim_x=configs[3], + # weight_dim_y=configs[4], + # input_channels=configs[5], + # output_channels=configs[6], + # copy_op=True, + # test_name=f"conv2d_bias_{dtype}_{test_unet_conv_cnt}_copy_op", + # dtype=dtype, + # 
) + + +if __name__ == "__main__": + torch.manual_seed(0) + unittest.main() diff --git a/tests/unittest/amd_sd_navi/test_unet_gemm.py b/tests/unittest/amd_sd_navi/test_unet_gemm.py new file mode 100644 index 000000000..952910fc3 --- /dev/null +++ b/tests/unittest/amd_sd_navi/test_unet_gemm.py @@ -0,0 +1,106 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import unittest + +import torch + +from aitemplate.compiler import compile_model, ops +from aitemplate.compiler.base import IntImm +from aitemplate.frontend import Tensor +from aitemplate.testing import detect_target +from aitemplate.testing.test_utils import ( + filter_test_cases_by_test_env, + get_random_torch_tensor, + get_torch_empty_tensor, +) +from aitemplate.utils import shape_utils + + +_TOLERANCE_LIMITS = { + "float16": {"atol": 1e-1, "rtol": 1e-1}, + "float32": {"atol": 1e-1, "rtol": 1e-1}, + "bfloat16": {"atol": 3e-1, "rtol": 3e-1}, +} + + +class GEMMBiasTestCase(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(GEMMBiasTestCase, self).__init__(*args, **kwargs) + self._test_id = 0 + + def _test_rcr(self, Ms, N, K, test_name, dtype="float16"): + target = detect_target() + tolerance_limits = _TOLERANCE_LIMITS[dtype] + MDim = shape_utils.gen_int_var_min_max(Ms, name="m") + X = Tensor(shape=[MDim, IntImm(K)], dtype=dtype, name="input_0", is_input=True) + W = Tensor( + shape=[IntImm(N), IntImm(K)], dtype=dtype, name="input_1", 
is_input=True + ) + B = Tensor(shape=[IntImm(N)], dtype=dtype, name="input_2", is_input=True) + OP = ops.gemm_rcr_bias() + Y = OP(X, W, B) + Y._attrs["name"] = "output_0" + Y._attrs["is_output"] = True + module = compile_model( + Y, target, "./tmp", f"gemm_rcr_bias_{test_name}_{self._test_id}" + ) + self._test_id += 1 + + for M in Ms: + X_pt = get_random_torch_tensor([M, K], dtype) + W_pt = get_random_torch_tensor([N, K], dtype) + B_pt = get_random_torch_tensor([N], dtype) + Y_pt = torch.nn.functional.linear(X_pt, W_pt, bias=B_pt) + + y = get_torch_empty_tensor([M, N], dtype) + module.run_with_tensors( + {"input_0": X_pt, "input_1": W_pt, "input_2": B_pt}, + [y], + ) + if X_pt.nelement() == 0 or W_pt.nelement() == 0: + pass + else: + torch.testing.assert_close(Y_pt, y, **tolerance_limits) + + def test_rcr_static_rocm(self): + self._test_rcr([2], N=1280, K=1280, test_name="static") + self._test_rcr([2], N=640, K=1280, test_name="static") + self._test_rcr([2], N=320, K=1280, test_name="static") + + self._test_rcr([64], N=320, K=320, test_name="static") + self._test_rcr([8192], N=320, K=1280, test_name="static") + self._test_rcr([8192], N=1280, K=320, test_name="static") + + self._test_rcr([8192], N=320, K=320, test_name="static") + self._test_rcr([32], N=640, K=640, test_name="static") + self._test_rcr([2048], N=640, K=640, test_name="static") + + self._test_rcr([2048], N=640, K=2560, test_name="static") + self._test_rcr([16], N=1280, K=1280, test_name="static") + self._test_rcr([512], N=1280, K=1280, test_name="static") + + self._test_rcr([512], N=1280, K=5120, test_name="static") + + self._test_rcr([8], N=1280, K=1280, test_name="static") + self._test_rcr([128], N=1280, K=1280, test_name="static") + self._test_rcr([128], N=5120, K=5120, test_name="static") + + +filter_test_cases_by_test_env(GEMMBiasTestCase) + + +if __name__ == "__main__": + torch.manual_seed(0) + unittest.main() diff --git a/tests/unittest/amd_sd_navi/test_unet_groupnorm.py 
b/tests/unittest/amd_sd_navi/test_unet_groupnorm.py new file mode 100644 index 000000000..238753a34 --- /dev/null +++ b/tests/unittest/amd_sd_navi/test_unet_groupnorm.py @@ -0,0 +1,186 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Unittests for group norm Operator. +""" +import logging +import unittest + +import torch + +from aitemplate.compiler import compile_model, ops +from aitemplate.frontend import Tensor +from aitemplate.testing import detect_target +from aitemplate.testing.test_utils import get_random_torch_tensor + + +_LOGGER = logging.getLogger(__name__) + + +@unittest.skipIf(detect_target()._arch == "75", "Skip GN on sm75.") +class GroupnormTestCase(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(GroupnormTestCase, self).__init__(*args, **kwargs) + self.test_count = 0 + + def _test_groupnorm( + self, + x_shape=(4, 14, 14, 1024), + num_groups=32, + gamma_is_none=False, + beta_is_none=False, + use_size_op=False, + eps=1e-5, + use_swish=False, + copy_op=False, + atol=1e-2, + rtol=1e-2, + dtype="float16", + ): + test_name = "group_norm_swish" if use_swish else "group_norm" + _LOGGER.info(f"Testing {test_name}: {x_shape}, num_groups: {num_groups}") + num_channels = x_shape[-1] + X1 = Tensor( + shape=x_shape, + dtype=dtype, + name="X", + is_input=True, + ) + X2 = Tensor( + shape=[num_channels], + dtype=dtype, + name="gamma", + is_input=True, + ) + X3 = Tensor( + 
shape=[num_channels], + dtype=dtype, + name="beta", + is_input=True, + ) + + op_name = "group_norm_swish" if use_swish else "group_norm" + OP = getattr(ops, op_name)(num_groups, num_channels) + if copy_op: + OP = getattr(ops, op_name)(**OP._get_op_attributes()) + X4 = OP(X1, X2, X3, eps) + X4._attrs["is_output"] = True + X4._attrs["name"] = "output" + + target = detect_target() + dll_name = f"test_{self.test_count}.so" + module = compile_model(X4, target, "./tmp", op_name, dll_name=dll_name) + + x1_nhwc_pt = get_random_torch_tensor(x_shape, dtype) + x1_nchw_pt = x1_nhwc_pt.permute(0, 3, 1, 2).contiguous() + gamma_pt = get_random_torch_tensor((num_channels,), dtype) + beta_pt = torch.randn_like(gamma_pt) + + x4_pt = torch.nn.functional.group_norm( + x1_nchw_pt, num_groups, gamma_pt, beta_pt, eps=eps + ) + if use_swish: + x4_pt = torch.nn.SiLU()(x4_pt) + + inputs = {"X": x1_nhwc_pt} + inputs["gamma"] = gamma_pt + inputs["beta"] = beta_pt + x4 = torch.empty_like(x1_nhwc_pt) + module.run_with_tensors(inputs, [x4]) + + torch.testing.assert_close( + x4, x4_pt.permute(0, 2, 3, 1).contiguous(), atol=atol, rtol=rtol + ) + self.test_count += 1 + + def test_groupnorm_float16(self): + self._test_groupnorm() + self._test_groupnorm(x_shape=[7, 13, 9, 12], num_groups=4, eps=1e-5) + self._test_groupnorm(x_shape=[1, 16, 16, 8192], num_groups=32, eps=1e-3) + self._test_groupnorm(x_shape=[3, 64, 64, 128], num_groups=16, eps=1e-5) + self._test_groupnorm(x_shape=[3, 33, 64, 120], num_groups=10, eps=1e-5) + self._test_groupnorm(x_shape=[8, 34, 10, 72], num_groups=6, eps=1e-5) + self._test_groupnorm(x_shape=[1, 8, 1, 64], num_groups=32, eps=1e-5) + self._test_groupnorm(x_shape=[1, 8, 1, 4], num_groups=2, eps=1e-5, copy_op=True) + + def test_groupnorm_swish(self): + # unet model groupnorm+swish + unet_shapes = [ + (2, 64, 64, 320), + (2, 32, 32, 320), + (2, 32, 32, 640), + (2, 16, 16, 640), + (2, 16, 16, 1280), + (2, 8, 8, 1280), + (2, 8, 8, 2560), + (2, 16, 16, 2560), + (2, 32, 32, 
1920), + (2, 32, 32, 1280), + (2, 32, 32, 960), + (2, 64, 64, 960), + (2, 64, 64, 640), + ] + for shape in unet_shapes: + self._test_groupnorm(x_shape=shape, num_groups=32, eps=1e-5, use_swish=True) + + @unittest.skipIf(detect_target().name() == "rocm", "fp32 not supported in ROCm") + def test_groupnorm_float32(self): + # H % 8 != 0 + self._test_groupnorm( + x_shape=[7, 13, 9, 12], + num_groups=4, + eps=1e-5, + dtype="float32", + use_swish=True, + ) + # H % 8 == 0 + self._test_groupnorm( + x_shape=[2, 16, 16, 640], + num_groups=32, + eps=1e-5, + dtype="float32", + use_swish=True, + ) + + @unittest.skipIf( + detect_target().name() == "cuda" and int(detect_target()._arch) < 80, + "bf16 is supported with CUDA sm80+", + ) + @unittest.skipIf(detect_target().name() == "rocm", "bf16 not supported in ROCm") + def test_groupnorm_bfloat16(self): + # H % 8 != 0 + self._test_groupnorm( + x_shape=[7, 13, 9, 12], + num_groups=4, + eps=1e-5, + atol=1e-1, + rtol=1e-1, + dtype="bfloat16", + use_swish=True, + ) + # H % 8 == 0 + self._test_groupnorm( + x_shape=[2, 16, 16, 640], + num_groups=32, + eps=1e-5, + atol=1e-1, + rtol=1e-1, + dtype="bfloat16", + use_swish=True, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unittest/amd_sd_navi/test_unet_mha.py b/tests/unittest/amd_sd_navi/test_unet_mha.py new file mode 100644 index 000000000..3cf8b6b86 --- /dev/null +++ b/tests/unittest/amd_sd_navi/test_unet_mha.py @@ -0,0 +1,192 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# BMM + Softmax + BMM +# (B, M, K) * (B, N, K) = (B, M, N) #RCR +# softmax on dim N (B, M, N) +# (B, M, N) * (B, N, O) = (B, M, O) #RRR +import itertools +import unittest + +import torch + +from aitemplate.compiler import compile_model, ops +from aitemplate.frontend import Tensor +from aitemplate.testing import detect_target +from aitemplate.testing.test_utils import filter_test_cases_by_test_env +from aitemplate.utils import shape_utils + + +def build_causal_attention_mask(bsz, seq_len, dtype): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype) + mask.fill_(torch.tensor(torch.finfo(dtype).min)) + mask.triu_(1) # zero out the lower diagonal + mask = mask.unsqueeze(1) # expand mask + return mask + + +class BMMSoftmaxBMMTestCase(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(BMMSoftmaxBMMTestCase, self).__init__(*args, **kwargs) + self.test_count = 0 + + def _test_bmm_permute( + self, + bs, + ms, + N, + K, + D, + head_dim=64, + num_heads=12, + causal=False, + test_name="ck_attn", + copy_op=False, + ): + target = detect_target() + batch_dim = shape_utils.gen_int_var_min_max(bs, name="batch_size") + m_dim = shape_utils.gen_int_var_min_max(ms, name="m") + X = Tensor( + shape=[batch_dim, m_dim, K], dtype="float16", name="input_0", is_input=True + ) + B0 = Tensor( + shape=[batch_dim, N, K], dtype="float16", name="input_1", is_input=True + ) + B1 = Tensor( + shape=[batch_dim, N, D], dtype="float16", name="input_2", is_input=True + ) + + scale = head_dim**-0.5 + + OP = ops.bmm_softmax_bmm_permute(shape=(num_heads,), scale=scale, causal=causal) + if copy_op: + OP = ops.bmm_softmax_bmm_permute(**OP._get_op_attributes()) + Y = OP(X, B0, B1) + + Y._attrs["name"] = "output_0" + Y._attrs["is_output"] = 
True + dll_name = f"test_{self.test_count}.so" + module = compile_model( + Y, target, "./tmp", f"bmm_{test_name}_permute", dll_name=dll_name + ) + + for b, m in itertools.product(bs, ms): + X_pt = torch.randn(b, m, K).cuda().half() # Q + W_pt = torch.randn(b, N, K).cuda().half() # K + B1_pt = torch.randn(b, N, D).cuda().half() # V + + attn = (X_pt @ W_pt.transpose(-2, -1)) * scale + + if causal: + bsz = 1 + tgt_len = m + src_len = N + causal_attention_mask = build_causal_attention_mask( + bsz, m, attn.dtype + ).to(attn.device) + attn_weights = attn.reshape(bsz, num_heads, m, N) + if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {causal_attention_mask.size()}" + ) + attn_weights = ( + attn_weights.view(bsz, num_heads, tgt_len, src_len) + + causal_attention_mask + ) + attn = attn_weights.view(bsz * num_heads, tgt_len, src_len) + + attn = attn.softmax(dim=-1) + Y_l = attn @ B1_pt + Y_r = Y_l.reshape(b // num_heads, num_heads, m, D) + Y2_pt = torch.permute(Y_r, [0, 2, 1, 3]) + + y = torch.empty([b // num_heads, m, num_heads, D]).cuda().half() + module.run_with_tensors([X_pt, W_pt, B1_pt], [y]) + if X_pt.nelement() == 0 or Y2_pt.nelement() == 0: + pass + else: + self.assertTrue(torch.allclose(Y2_pt, y, atol=1e-1, rtol=1e-1)) + + # benchmark + # time_per_iter_ms, time_std, _ = module.benchmark_with_tensors( + # [X_pt, W_pt, B1_pt], [y], count=200, repeat=2 + # ) + + def _test_b2b( + self, bs, ms, N, K, D, head_dim=64, test_name="ck_attn", copy_op=False + ): + target = detect_target() + batch_dim = shape_utils.gen_int_var_min_max(bs, name="batch_size") + m_dim = shape_utils.gen_int_var_min_max(ms, name="m") + X = Tensor( + shape=[batch_dim, m_dim, K], dtype="float16", name="input_0", is_input=True + ) + B0 = Tensor( + shape=[batch_dim, N, K], dtype="float16", name="input_1", is_input=True + ) + B1 = Tensor( + shape=[batch_dim, N, D], dtype="float16", 
name="input_2", is_input=True + ) + + scale = head_dim**-0.5 + + OP = ops.bmm_softmax_bmm(scale=scale) + if copy_op: + OP = ops.bmm_softmax_bmm(OP._get_op_attributes()) + Y = OP(X, B0, B1) + + Y._attrs["name"] = "output_0" + Y._attrs["is_output"] = True + dll_name = f"test_{self.test_count}.so" + module = compile_model( + Y, target, "./tmp", f"bmm_{test_name}_permute", dll_name=dll_name + ) + + for b, m in itertools.product(bs, ms): + X_pt = torch.randn(b, m, K).cuda().half() # Q + W_pt = torch.randn(b, N, K).cuda().half() # K + B1_pt = torch.randn(b, N, D).cuda().half() # V + + attn = (X_pt @ W_pt.transpose(-2, -1)) * scale + attn = attn.softmax(dim=-1) + Y2_pt = attn @ B1_pt + + y = torch.empty([b, m, D]).cuda().half() + module.run_with_tensors([X_pt, W_pt, B1_pt], [y]) + if X_pt.nelement() == 0 or Y2_pt.nelement() == 0: + pass + else: + self.assertTrue(torch.allclose(Y2_pt, y, atol=1e-1, rtol=1e-1)) + + def test_rcr_rocm(self): + # FIXME: re-enable it after we fix the missing parameter for bmm_softmax_bmm + self._test_bmm_permute([10], [4096], N=4096, K=64, D=64, head_dim=64, num_heads=5,test_name="static") + self._test_bmm_permute([10], [4096], N=64, K=64, D=64, head_dim=64, num_heads=5,test_name="static") + self._test_bmm_permute([20], [1024], N=1024, K=64, D=64, head_dim=64, num_heads=5,test_name="static") + self._test_bmm_permute([20], [1024], N=64, K=64, D=64, head_dim=64, num_heads=5,test_name="static") + self._test_bmm_permute([40], [256], N=256, K=64, D=64, head_dim=64, num_heads=5,test_name="static") + self._test_bmm_permute([40], [256], N=64, K=64, D=64, head_dim=64, num_heads=5,test_name="static") + self._test_bmm_permute([40], [64], N=64, K=64, D=64, head_dim=64, num_heads=5,test_name="static") + + + + +filter_test_cases_by_test_env(BMMSoftmaxBMMTestCase) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unittest/amd_sd_navi/test_vae_conv.py b/tests/unittest/amd_sd_navi/test_vae_conv.py new file mode 100644 index 
000000000..88af798eb --- /dev/null +++ b/tests/unittest/amd_sd_navi/test_vae_conv.py @@ -0,0 +1,151 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import unittest + +import torch + +from aitemplate.compiler import compile_model, ops +from aitemplate.frontend import IntImm, Tensor +from aitemplate.testing import detect_target +from aitemplate.testing.test_utils import ( + filter_test_cases_by_params, + get_random_torch_tensor, + TestEnv, +) + +from parameterized import parameterized + + +class ConvBiasTestCase(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + torch.manual_seed(1) + + def _test_conv_bias( + self, + batch=1, + input_dim_x=64, + input_dim_y=64, + weight_dim_x=3, + weight_dim_y=3, + input_channels=320, + output_channels=4, + copy_op=False, + test_name="conv2d_bias", + dtype="float16", + ): + target = detect_target() + X = Tensor( + shape=[IntImm(batch), IntImm(input_dim_x), IntImm(input_dim_y), IntImm(input_channels)], + dtype=dtype, + name="input_0", + is_input=True, + ) + W = Tensor( + shape=[IntImm(output_channels), IntImm(weight_dim_x), IntImm(weight_dim_y), IntImm(input_channels)], + dtype=dtype, + name="input_1", + is_input=True, + ) + B = Tensor( + shape=[IntImm(output_channels)], + dtype=dtype, + name="input_2", + is_input=True, + ) + OP = ops.conv2d_bias(stride=1, pad=1, dilate=1) + if copy_op: + OP = ops.conv2d_bias(**OP._get_op_attributes()) + Y = OP(X, W, B) + 
Y._attrs["name"] = "output_0" + Y._attrs["is_output"] = True + module = compile_model(Y, target, "./tmp", test_name) + + X_pt = get_random_torch_tensor([batch, input_channels, input_dim_x, input_dim_y], dtype=dtype) + W_pt = get_random_torch_tensor([output_channels, input_channels, weight_dim_x, weight_dim_y], dtype=dtype) + B_pt = get_random_torch_tensor([1, output_channels, 1, 1], dtype=dtype) + Y_pt = torch.nn.functional.conv2d(X_pt.float(), W_pt.float(), padding=1).to( + dtype=X_pt.dtype + ) + Y_pt = Y_pt + B_pt + x = X_pt.permute((0, 2, 3, 1)).contiguous() + w = W_pt.permute((0, 2, 3, 1)).contiguous() + inputs = {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()} + y = torch.empty_like(Y_pt).permute((0, 2, 3, 1)).contiguous() + module.run_with_tensors(inputs, [y]) + y_transpose = y.permute((0, 3, 1, 2)) + if target.name() == "cuda": + if dtype == "float32": + torch.testing.assert_close(Y_pt, y_transpose, atol=5e-2, rtol=1e-2) + else: + torch.testing.assert_close(Y_pt, y_transpose, atol=1e-2, rtol=1e-2) + else: + torch.testing.assert_close(Y_pt, y_transpose, atol=1.25e-1, rtol=1e-1) + + @parameterized.expand( + filter_test_cases_by_params( + { + # TestEnv.CUDA_LESS_THAN_SM80: [("float16")], + # TestEnv.CUDA_SM80: [("float32"), ("bfloat16")], + TestEnv.ROCM: [("float16")], + } + ) + ) + def test_conv2d_bias(self, dtype): + # vae model test + vae_model_conv = [ + [64 ,64 ,1, 1, 4, 4], + [64 ,64 ,3, 3, 512, 4], + [64 ,64 ,3, 3, 512, 512], + [128 ,128 ,3, 3, 512, 512], + [256 ,256 ,3, 3, 512, 512], + [256 ,256 ,3, 3, 256, 512], + [256 ,256 ,3, 3, 256, 256], + [256 ,256 ,3, 3, 512, 256], + [512 ,512 ,3, 3, 256, 256], + [512 ,512 ,3, 3, 128, 256], + [512 ,512 ,3, 3, 128, 128], + [512 ,512 ,3, 3, 128, 3], + ] + test_conv_cnt = 0 + for configs in vae_model_conv: + self._test_conv_bias( + input_dim_x=configs[0], + input_dim_y=configs[1], + weight_dim_x=configs[2], + weight_dim_y=configs[3], + input_channels=configs[4], + output_channels=configs[5], + copy_op=False, 
+ test_name=f"conv2d_bias_{dtype}_{test_conv_cnt}", + dtype=dtype, + ) + + self._test_conv_bias( + input_dim_x=configs[0], + input_dim_y=configs[1], + weight_dim_x=configs[2], + weight_dim_y=configs[3], + input_channels=configs[4], + output_channels=configs[5], + copy_op=True, + test_name=f"conv2d_bias_{dtype}_{test_conv_cnt}_copy_op", + dtype=dtype, + ) + + +if __name__ == "__main__": + torch.manual_seed(0) + unittest.main() diff --git a/tests/unittest/amd_sd_navi/test_vae_gemm.py b/tests/unittest/amd_sd_navi/test_vae_gemm.py new file mode 100644 index 000000000..58be7572f --- /dev/null +++ b/tests/unittest/amd_sd_navi/test_vae_gemm.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import unittest + +import torch + +from aitemplate.compiler import compile_model, ops +from aitemplate.compiler.base import IntImm +from aitemplate.frontend import Tensor +from aitemplate.testing import detect_target +from aitemplate.testing.test_utils import ( + filter_test_cases_by_test_env, + get_random_torch_tensor, + get_torch_empty_tensor, +) +from aitemplate.utils import shape_utils + + +_TOLERANCE_LIMITS = { + "float16": {"atol": 1e-1, "rtol": 1e-1}, + "float32": {"atol": 1e-1, "rtol": 1e-1}, + "bfloat16": {"atol": 3e-1, "rtol": 3e-1}, +} + + +class GEMMBiasTestCase(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(GEMMBiasTestCase, self).__init__(*args, **kwargs) + self._test_id = 0 + + def _test_rcr(self, Ms, N, K, test_name, dtype="float16"): + target = detect_target() + tolerance_limits = _TOLERANCE_LIMITS[dtype] + MDim = shape_utils.gen_int_var_min_max(Ms, name="m") + X = Tensor(shape=[MDim, IntImm(K)], dtype=dtype, name="input_0", is_input=True) + W = Tensor( + shape=[IntImm(N), IntImm(K)], dtype=dtype, name="input_1", is_input=True + ) + B = Tensor(shape=[IntImm(N)], dtype=dtype, name="input_2", is_input=True) + OP = ops.gemm_rcr_bias() + Y = OP(X, W, B) + Y._attrs["name"] = "output_0" + Y._attrs["is_output"] = True + module = compile_model( + Y, target, "./tmp", f"gemm_rcr_bias_{test_name}_{self._test_id}" + ) + self._test_id += 1 + + for M in Ms: + X_pt = get_random_torch_tensor([M, K], dtype) + W_pt = get_random_torch_tensor([N, K], dtype) + B_pt = get_random_torch_tensor([N], dtype) + Y_pt = torch.nn.functional.linear(X_pt, W_pt, bias=B_pt) + + y = get_torch_empty_tensor([M, N], dtype) + module.run_with_tensors( + {"input_0": X_pt, "input_1": W_pt, "input_2": B_pt}, + [y], + ) + if X_pt.nelement() == 0 or W_pt.nelement() == 0: + pass + else: + torch.testing.assert_close(Y_pt, y, **tolerance_limits) + + def test_rcr_static_rocm(self): + self._test_rcr([64], N=1536, K=512, test_name="static") + self._test_rcr([64], N=512, 
K=512, test_name="static") + + +filter_test_cases_by_test_env(GEMMBiasTestCase) + + +if __name__ == "__main__": + torch.manual_seed(0) + unittest.main() diff --git a/tests/unittest/amd_sd_navi/test_vae_groupnorm.py b/tests/unittest/amd_sd_navi/test_vae_groupnorm.py new file mode 100644 index 000000000..0a44d8555 --- /dev/null +++ b/tests/unittest/amd_sd_navi/test_vae_groupnorm.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Unittests for group norm Operator. 
+""" +import logging +import unittest + +import torch + +from aitemplate.compiler import compile_model, ops +from aitemplate.frontend import Tensor +from aitemplate.testing import detect_target +from aitemplate.testing.test_utils import get_random_torch_tensor + + +_LOGGER = logging.getLogger(__name__) + + +@unittest.skipIf(detect_target()._arch == "75", "Skip GN on sm75.") +class GroupnormTestCase(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(GroupnormTestCase, self).__init__(*args, **kwargs) + self.test_count = 0 + + def _test_groupnorm( + self, + x_shape=(4, 14, 14, 1024), + num_groups=32, + gamma_is_none=False, + beta_is_none=False, + use_size_op=False, + eps=1e-5, + use_swish=False, + copy_op=False, + atol=1e-2, + rtol=1e-2, + dtype="float16", + ): + test_name = "group_norm_swish" if use_swish else "group_norm" + _LOGGER.info(f"Testing {test_name}: {x_shape}, num_groups: {num_groups}") + num_channels = x_shape[-1] + X1 = Tensor( + shape=x_shape, + dtype=dtype, + name="X", + is_input=True, + ) + X2 = Tensor( + shape=[num_channels], + dtype=dtype, + name="gamma", + is_input=True, + ) + X3 = Tensor( + shape=[num_channels], + dtype=dtype, + name="beta", + is_input=True, + ) + + op_name = "group_norm_swish" if use_swish else "group_norm" + OP = getattr(ops, op_name)(num_groups, num_channels) + if copy_op: + OP = getattr(ops, op_name)(**OP._get_op_attributes()) + X4 = OP(X1, X2, X3, eps) + X4._attrs["is_output"] = True + X4._attrs["name"] = "output" + + target = detect_target() + dll_name = f"test_{self.test_count}.so" + module = compile_model(X4, target, "./tmp", op_name, dll_name=dll_name) + + x1_nhwc_pt = get_random_torch_tensor(x_shape, dtype) + x1_nchw_pt = x1_nhwc_pt.permute(0, 3, 1, 2).contiguous() + gamma_pt = get_random_torch_tensor((num_channels,), dtype) + beta_pt = torch.randn_like(gamma_pt) + + x4_pt = torch.nn.functional.group_norm( + x1_nchw_pt, num_groups, gamma_pt, beta_pt, eps=eps + ) + if use_swish: + x4_pt = 
torch.nn.SiLU()(x4_pt) + + inputs = {"X": x1_nhwc_pt} + inputs["gamma"] = gamma_pt + inputs["beta"] = beta_pt + x4 = torch.empty_like(x1_nhwc_pt) + module.run_with_tensors(inputs, [x4]) + + torch.testing.assert_close( + x4, x4_pt.permute(0, 2, 3, 1).contiguous(), atol=atol, rtol=rtol + ) + self.test_count += 1 + + def test_groupnorm_float16(self): + self._test_groupnorm() + self._test_groupnorm(x_shape=[7, 13, 9, 12], num_groups=4, eps=1e-5) + self._test_groupnorm(x_shape=[1, 16, 16, 8192], num_groups=32, eps=1e-3) + self._test_groupnorm(x_shape=[3, 64, 64, 128], num_groups=16, eps=1e-5) + self._test_groupnorm(x_shape=[3, 33, 64, 120], num_groups=10, eps=1e-5) + self._test_groupnorm(x_shape=[8, 34, 10, 72], num_groups=6, eps=1e-5) + self._test_groupnorm(x_shape=[1, 8, 1, 64], num_groups=32, eps=1e-5) + self._test_groupnorm(x_shape=[1, 8, 1, 4], num_groups=2, eps=1e-5, copy_op=True) + + def test_groupnorm_swish(self): + # vae model groupnorm+swish + vae_shapes = [ + (1, 64, 64, 512), + (1, 128, 128, 512), + (1, 256, 256, 512), + (1, 256, 256, 256), + (1, 512, 512, 256), + (1, 512, 512, 128), + ] + for shape in vae_shapes: + self._test_groupnorm(x_shape=shape, num_groups=32, eps=1e-5, use_swish=True) + + @unittest.skipIf(detect_target().name() == "rocm", "fp32 not supported in ROCm") + def test_groupnorm_float32(self): + # H % 8 != 0 + self._test_groupnorm( + x_shape=[7, 13, 9, 12], + num_groups=4, + eps=1e-5, + dtype="float32", + use_swish=True, + ) + # H % 8 == 0 + self._test_groupnorm( + x_shape=[2, 16, 16, 640], + num_groups=32, + eps=1e-5, + dtype="float32", + use_swish=True, + ) + + @unittest.skipIf( + detect_target().name() == "cuda" and int(detect_target()._arch) < 80, + "bf16 is supported with CUDA sm80+", + ) + @unittest.skipIf(detect_target().name() == "rocm", "bf16 not supported in ROCm") + def test_groupnorm_bfloat16(self): + # H % 8 != 0 + self._test_groupnorm( + x_shape=[7, 13, 9, 12], + num_groups=4, + eps=1e-5, + atol=1e-1, + rtol=1e-1, + 
dtype="bfloat16", + use_swish=True, + ) + # H % 8 == 0 + self._test_groupnorm( + x_shape=[2, 16, 16, 640], + num_groups=32, + eps=1e-5, + atol=1e-1, + rtol=1e-1, + dtype="bfloat16", + use_swish=True, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unittest/backend/test_profiler.py b/tests/unittest/backend/test_profiler.py index 438df946d..d6564d827 100644 --- a/tests/unittest/backend/test_profiler.py +++ b/tests/unittest/backend/test_profiler.py @@ -63,7 +63,6 @@ def test_profiler_runner(self): cmds=["sleep", f"{sleep_for}"], process_result_callback=delegate_cb_wrapper(i, sleep_for), ) - pr.join() diff --git a/tests/unittest/ops/test_bmm_softmax_bmm.py b/tests/unittest/ops/test_bmm_softmax_bmm.py index 077480c51..935ee868d 100644 --- a/tests/unittest/ops/test_bmm_softmax_bmm.py +++ b/tests/unittest/ops/test_bmm_softmax_bmm.py @@ -190,20 +190,20 @@ def test_rcr_rocm(self): self._test_bmm_permute( [16], [4096], N=64, K=40, D=40, num_heads=8, test_name="static" ) - self._test_bmm_permute( - [12], [64], N=64, K=64, D=64, num_heads=12, causal=True, test_name="static" - ) - self._test_bmm_permute( - [12], - [64], - N=64, - K=64, - D=64, - num_heads=12, - causal=True, - test_name="static_copy_op", - copy_op=True, - ) + # self._test_bmm_permute( + # [12], [64], N=64, K=64, D=64, num_heads=12, causal=True, test_name="static" + # ) + # self._test_bmm_permute( + # [12], + # [64], + # N=64, + # K=64, + # D=64, + # num_heads=12, + # causal=True, + # test_name="static_copy_op", + # copy_op=True, + # ) filter_test_cases_by_test_env(BMMSoftmaxBMMTestCase) diff --git a/tests/unittest/ops/test_conv.py b/tests/unittest/ops/test_conv.py index db5621174..fcc1d033d 100644 --- a/tests/unittest/ops/test_conv.py +++ b/tests/unittest/ops/test_conv.py @@ -17,37 +17,21 @@ import torch from aitemplate.compiler import compile_model, ops -from aitemplate.frontend import IntImm, nn, Tensor +from aitemplate.frontend import IntImm, Tensor from aitemplate.testing import 
detect_target -from aitemplate.testing.test_utils import ( - filter_test_cases_by_params, - get_random_torch_tensor, - TestEnv, -) - -from parameterized import parameterized class ConvTestCase(unittest.TestCase): - def _test_conv( - self, - batch=4, - copy_op=False, - test_name="conv2d", - dtype="float16", - ): + def _test_fp16(self, batch=1, copy_op=False): target = detect_target() X = Tensor( - shape=[IntImm(batch), 28, 28, 128], - dtype=dtype, + shape=[IntImm(batch), 224, 224, 3], + dtype="float16", name="input_0", is_input=True, ) W = Tensor( - shape=[256, 3, 3, 128], - dtype=dtype, - name="input_1", - is_input=True, + shape=[256, 3, 3, 3], dtype="float16", name="input_1", is_input=True ) OP = ops.conv2d(stride=1, pad=1, dilate=1) if copy_op: @@ -55,24 +39,20 @@ def _test_conv( Y = OP(X, W) Y._attrs["name"] = "output_0" Y._attrs["is_output"] = True - module = compile_model(Y, target, "./tmp", test_name) + module = compile_model(Y, target, "./tmp", f"conv2d_{copy_op}") - X_pt = get_random_torch_tensor([batch, 128, 28, 28], dtype=dtype) - W_pt = get_random_torch_tensor([256, 128, 3, 3], dtype=dtype) - Y_pt = torch.nn.functional.conv2d(X_pt.float(), W_pt.float(), padding=1).to( - dtype=X_pt.dtype - ) + X_pt = torch.randn(batch, 3, 224, 224).cuda().half() + W_pt = torch.randn(256, 3, 3, 3).cuda().half() + Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=1) x = X_pt.permute((0, 2, 3, 1)).contiguous() w = W_pt.permute((0, 2, 3, 1)).contiguous() y = torch.empty_like(Y_pt).permute((0, 2, 3, 1)).contiguous() module.run_with_tensors({"input_0": x, "input_1": w}, [y]) y_transpose = y.permute((0, 3, 1, 2)) if target.name() == "cuda": - if dtype == "float32": - torch.testing.assert_close(Y_pt, y_transpose, atol=1e-1, rtol=1e-1) - else: - torch.testing.assert_close(Y_pt, y_transpose, atol=1e-2, rtol=1e-2) + self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-2, rtol=1e-2)) else: +<<<<<<< HEAD torch.testing.assert_close(Y_pt, y_transpose, atol=1.25e-1, rtol=1e-1) 
@parameterized.expand( @@ -134,57 +114,15 @@ def _test_conv1d(self, dtype, bias): X_pt = get_random_torch_tensor([batch, C_in, L], dtype=dtype) W_pt = get_random_torch_tensor([C_out, C_in, K], dtype=dtype) bias_pt = get_random_torch_tensor([C_out], dtype=dtype) if bias else None +======= + self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1.25e-1, rtol=1e-1)) +>>>>>>> origin/navi3_rel_ver_1.0 - X = Tensor( - shape=[IntImm(batch), L, C_in], - dtype=dtype, - name="input_0", - is_input=True, - ) - mod = nn.Conv1d( - in_channels=C_in, - out_channels=C_out, - kernel_size=K, - stride=stride, - padding=padding, - dilation=dilation, - dtype=dtype, - bias=bias, - ) - - Y = mod(X) - - Y._attrs["name"] = "output_0" - Y._attrs["is_output"] = True - module = compile_model(Y, target, "./tmp", test_name) - module.set_constant_with_tensor( - "conv1d_weight", W_pt.permute((0, 2, 1)).contiguous() - ) - if bias: - module.set_constant_with_tensor("conv1d_bias", bias_pt) - Y_pt = torch.nn.functional.conv1d( - X_pt.float(), - W_pt.float(), - bias=bias_pt.float() if bias else None, - padding=padding, - stride=stride, - dilation=dilation, - ).to(dtype=X_pt.dtype) - - x = X_pt.permute((0, 2, 1)).contiguous() - - y = torch.empty_like(Y_pt).permute((0, 2, 1)).contiguous() - module.run_with_tensors({"input_0": x}, [y]) - y_transpose = y.permute((0, 2, 1)) - if target.name() == "cuda": - if dtype == "float32": - torch.testing.assert_close(Y_pt, y_transpose, atol=1.5e-1, rtol=1e-1) - else: - torch.testing.assert_close(Y_pt, y_transpose, atol=1e-2, rtol=1e-2) - else: - torch.testing.assert_close(Y_pt, y_transpose, atol=1.25e-1, rtol=1e-1) + def test_fp16(self): + self._test_fp16() + self._test_fp16(copy_op=True) if __name__ == "__main__": torch.manual_seed(0) - unittest.main() + unittest.main() \ No newline at end of file diff --git a/tests/unittest/ops/test_conv_bias.py b/tests/unittest/ops/test_conv_bias.py index bb0e774c4..a64c058bc 100644 --- a/tests/unittest/ops/test_conv_bias.py +++ 
b/tests/unittest/ops/test_conv_bias.py @@ -35,26 +35,32 @@ def setUpClass(cls) -> None: def _test_conv_bias( self, - batch=4, + batch=1, + input_dim_x=64, + input_dim_y=64, + weight_dim_x=3, + weight_dim_y=3, + input_channels=320, + output_channels=4, copy_op=False, test_name="conv2d_bias", dtype="float16", ): target = detect_target() X = Tensor( - shape=[IntImm(batch), 28, 28, 128], + shape=[IntImm(batch), IntImm(input_dim_x), IntImm(input_dim_y), IntImm(input_channels)], dtype=dtype, name="input_0", is_input=True, ) W = Tensor( - shape=[256, 3, 3, 128], + shape=[IntImm(output_channels), IntImm(weight_dim_x), IntImm(weight_dim_y), IntImm(input_channels)], dtype=dtype, name="input_1", is_input=True, ) B = Tensor( - shape=[256], + shape=[IntImm(output_channels)], dtype=dtype, name="input_2", is_input=True, @@ -67,9 +73,9 @@ def _test_conv_bias( Y._attrs["is_output"] = True module = compile_model(Y, target, "./tmp", test_name) - X_pt = get_random_torch_tensor([batch, 128, 28, 28], dtype=dtype) - W_pt = get_random_torch_tensor([256, 128, 3, 3], dtype=dtype) - B_pt = get_random_torch_tensor([1, 256, 1, 1], dtype=dtype) + X_pt = get_random_torch_tensor([batch, input_channels, input_dim_x, input_dim_y], dtype=dtype) + W_pt = get_random_torch_tensor([output_channels, input_channels, weight_dim_x, weight_dim_y], dtype=dtype) + B_pt = get_random_torch_tensor([1, output_channels, 1, 1], dtype=dtype) Y_pt = torch.nn.functional.conv2d(X_pt.float(), W_pt.float(), padding=1).to( dtype=X_pt.dtype ) @@ -91,13 +97,14 @@ def _test_conv_bias( @parameterized.expand( **filter_test_cases_by_params( { - TestEnv.CUDA_LESS_THAN_SM80: [("float16")], - TestEnv.CUDA_SM80: [("float32"), ("bfloat16")], + # TestEnv.CUDA_LESS_THAN_SM80: [("float16")], + # TestEnv.CUDA_SM80: [("float32"), ("bfloat16")], TestEnv.ROCM: [("float16")], } ) ) def test_conv2d_bias(self, dtype): + # default self._test_conv_bias( test_name=f"conv2d_bias_{dtype}", dtype=dtype, @@ -107,6 +114,88 @@ def 
test_conv2d_bias(self, dtype): test_name=f"conv2d_bias_{dtype}_copy_op", dtype=dtype, ) + # unet model test + # Not implemented yet + # vae_model_conv = [ + # [64 ,64 ,1, 1, 4, 4], + # [64 ,64 ,3, 3, 512, 4], + # [64 ,64 ,3, 3, 512, 512], + # [128 ,128 ,3, 3, 512, 512], + # [256 ,256 ,3, 3, 512, 512], + # [256 ,256 ,3, 3, 256, 512], + # [256 ,256 ,3, 3, 256, 256], + # [256 ,256 ,3, 3, 512, 256], + # [512 ,512 ,3, 3, 256, 256], + # [512 ,512 ,3, 3, 128, 256], + # [512 ,512 ,3, 3, 128, 128], + # [512 ,512 ,3, 3, 128, 3], + # ] + # test_conv_cnt = 0 + # for configs in vae_model_conv: + # self._test_conv_bias( + # input_dim_x=configs[0], + # input_dim_y=configs[1], + # weight_dim_x=configs[2], + # weight_dim_y=configs[3], + # input_channels=configs[4], + # output_channels=configs[5], + # copy_op=False, + # test_name=f"conv2d_bias_{dtype}_{test_conv_cnt}", + # dtype=dtype, + # ) + + # self._test_conv_bias( + # input_dim_x=configs[0], + # input_dim_y=configs[1], + # weight_dim_x=configs[2], + # weight_dim_y=configs[3], + # input_channels=configs[4], + # output_channels=configs[5], + # copy_op=True, + # test_name=f"conv2d_bias_{dtype}_{test_conv_cnt}_copy_op", + # dtype=dtype, + # ) + + # vae model test + vae_model_conv = [ + [64 ,64 ,1, 1, 4, 4], + [64 ,64 ,3, 3, 512, 4], + [64 ,64 ,3, 3, 512, 512], + [128 ,128 ,3, 3, 512, 512], + [256 ,256 ,3, 3, 512, 512], + [256 ,256 ,3, 3, 256, 512], + [256 ,256 ,3, 3, 256, 256], + [256 ,256 ,3, 3, 512, 256], + [512 ,512 ,3, 3, 256, 256], + [512 ,512 ,3, 3, 128, 256], + [512 ,512 ,3, 3, 128, 128], + [512 ,512 ,3, 3, 128, 3], + ] + test_conv_cnt = 0 + for configs in vae_model_conv: + self._test_conv_bias( + input_dim_x=configs[0], + input_dim_y=configs[1], + weight_dim_x=configs[2], + weight_dim_y=configs[3], + input_channels=configs[4], + output_channels=configs[5], + copy_op=False, + test_name=f"conv2d_bias_{dtype}_{test_conv_cnt}", + dtype=dtype, + ) + + self._test_conv_bias( + input_dim_x=configs[0], + input_dim_y=configs[1], + 
weight_dim_x=configs[2], + weight_dim_y=configs[3], + input_channels=configs[4], + output_channels=configs[5], + copy_op=True, + test_name=f"conv2d_bias_{dtype}_{test_conv_cnt}_copy_op", + dtype=dtype, + ) if __name__ == "__main__": diff --git a/tests/unittest/ops/test_gemm.py b/tests/unittest/ops/test_gemm.py index d9585c249..97d1f0800 100644 --- a/tests/unittest/ops/test_gemm.py +++ b/tests/unittest/ops/test_gemm.py @@ -81,6 +81,7 @@ def _test_rcr(self, ms, k, n, test_name, dtype="float16"): def test_rcr_simple_static(self) -> None: self._test_rcr([1024], 256, 512, "static") +<<<<<<< HEAD def test_rcr_simple_static_rocm(self) -> None: self._test_rcr([1024], 256, 512, "static_rocm") @@ -464,6 +465,276 @@ def test_rrr_sm90(self) -> None: filter_test_cases_by_test_env(GEMMTestCase) +======= +# def test_rcr_simple_static_rocm(self) -> None: +# self._test_rcr([1024], 256, 512, "static") + +# @parameterized.expand( +# [ +# ("dynamic1", [1, 1024], 256, 512), +# # TODO/FIXME: Fix the issue below. +# # There is some bug with floating point rounding, +# # e.g. the list of batch sizes like this [1, 99, 84, 987, 1024] +# # is not handled properly. 
+# ("dynamic2", [1, 99, 84, 1024], 128, 8), +# ("zero_k", [8], 0, 4), +# ("zero_m", [0], 8, 4), +# ] +# ) +# @unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.") +# def test_rcr_simple_dynamic(self, name, ms, k, n) -> None: +# self._test_rcr(ms, k, n, name) + +# def _test_rcr_dynamic_n(self, ms, k, ns, test_name, dtype="float16"): +# target = detect_target() +# tolerance_limits = _TOLERANCE_LIMITS[dtype] +# X = Tensor( +# shape=[shape_utils.gen_int_var_min_max(ms), k], +# dtype=dtype, +# name="input_0", +# is_input=True, +# ) +# W = Tensor( +# shape=[shape_utils.gen_int_var_min_max(ns), k], +# dtype=dtype, +# name="input_1", +# is_input=True, +# ) +# OP = ops.gemm_rcr() +# Y = OP(X, W) +# Y._attrs["name"] = "output_0" +# Y._attrs["is_output"] = True +# module = compile_model( +# Y, target, "./tmp", f"gemm_rcr_{test_name}_{self._test_id}" +# ) +# self._test_id += 1 + +# for m in ms: +# for n in ns: +# X_pt = get_random_torch_tensor([m, k], dtype) +# W_pt = get_random_torch_tensor([n, k], dtype) +# Y_pt = torch.nn.functional.linear(X_pt, W_pt) + +# inputs = {"input_0": X_pt, "input_1": W_pt} +# y = get_torch_empty_tensor([m, n], dtype) +# module.run_with_tensors(inputs, [y]) + +# if X_pt.nelement() == 0 or W_pt.nelement() == 0: +# pass +# else: +# torch.testing.assert_close(Y_pt, y, **tolerance_limits) + +# @unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.") +# def test_rcr_dynamic_n(self): +# self._test_rcr([16, 1 * 29, 64], 256, 300000, "einsum_1") +# self._test_rcr_dynamic_n( +# [16, 1 * 29, 64], 256, [100000, 300000], "einsum_dynamic_n" +# ) + +# def _test_3d_2d_rcr(self, m0s, m1s, k, n, test_name, dtype="float16"): +# target = detect_target() +# tolerance_limits = _TOLERANCE_LIMITS[dtype] +# if dtype == "float16": +# tolerance_limits["atol"] = 2e-2 +# tolerance_limits["rtol"] = 2e-2 +# X = Tensor( +# shape=[ +# shape_utils.gen_int_var_min_max(m0s), +# shape_utils.gen_int_var_min_max(m1s), +# k, +# ], +# 
dtype=dtype, +# name="input_0", +# is_input=True, +# ) +# X._attrs["is_input"] = True +# W = Tensor(shape=[n, k], dtype=dtype, name="input_1", is_input=True) +# OP = ops.gemm_rcr() +# Y = OP(X, W) +# Y._attrs["name"] = "output_0" +# Y._attrs["is_output"] = True +# module = compile_model( +# Y, target, "./tmp", f"gemm_3d_2d_rcr_{test_name}_{self._test_id}" +# ) +# self._test_id += 1 + +# for m0, m1 in itertools.product(m0s, m1s): +# X_pt = get_random_torch_tensor([m0, m1, k], dtype) +# W_pt = get_random_torch_tensor([n, k], dtype) +# Y_pt = torch.nn.functional.linear(X_pt, W_pt) + +# inputs = {"input_0": X_pt, "input_1": W_pt} +# y = get_torch_empty_tensor([m0, m1, n], dtype) +# module.run_with_tensors(inputs, [y]) +# torch.testing.assert_close(Y_pt, y, **tolerance_limits) + +# @unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.") +# def test_3d_2d_rcr(self): +# self._test_3d_2d_rcr([1024], [2], 256, 512, "static") +# self._test_3d_2d_rcr([1, 1024], [2], 256, 512, "dynamic1") +# self._test_3d_2d_rcr([3], [128, 256], 256, 512, "dynamic2") +# self._test_3d_2d_rcr([1, 99, 1024], [1, 2], 128, 8, "dynamic3") + +# def _test_rrr(self, ms, k, n, test_name, dtype="float16"): +# target = detect_target() +# tolerance_limits = _TOLERANCE_LIMITS[dtype] +# if dtype == "float16": +# tolerance_limits["atol"] = 2e-2 +# tolerance_limits["rtol"] = 2e-2 +# X = Tensor( +# shape=[shape_utils.gen_int_var_min_max(ms), k], +# dtype=dtype, +# name="input_0", +# is_input=True, +# ) +# W = Tensor(shape=[k, n], dtype=dtype, name="input_1", is_input=True) +# OP = ops.gemm_rrr() +# Y = OP(X, W) +# Y._attrs["name"] = "output_0" +# Y._attrs["is_output"] = True +# module = compile_model( +# Y, target, "./tmp", f"gemm_rrr_{test_name}_{self._test_id}" +# ) +# self._test_id += 1 + +# for m in ms: +# X_pt = get_random_torch_tensor([m, k], dtype) +# W_pt = get_random_torch_tensor([k, n], dtype) +# Y_pt = torch.matmul(X_pt, W_pt) +# inputs = {"input_0": X_pt, "input_1": W_pt} +# y = 
get_torch_empty_tensor([m, n], dtype) +# module.run_with_tensors(inputs, [y]) +# torch.testing.assert_close(Y_pt, y, **tolerance_limits) + +# @unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.") +# def test_rrr(self): +# self._test_rrr([256], 128, 32, "static") +# self._test_rrr([1, 99, 1024, 2048], 256, 16, "dynamic") + +# def test_rrr_rocm(self): +# self._test_rrr([256], 128, 32, "static") + +# def _test_3d_2d_rrr(self, m0s, m1s, k, n, test_name, dtype="float16"): +# target = detect_target() +# tolerance_limits = {"atol": 2e-1, "rtol": 2e-1} +# X = Tensor( +# shape=[ +# shape_utils.gen_int_var_min_max(m0s), +# shape_utils.gen_int_var_min_max(m1s), +# k, +# ], +# dtype=dtype, +# name="input_0", +# is_input=True, +# ) +# W = Tensor(shape=[k, n], dtype=dtype, name="input_1", is_input=True) +# OP = ops.gemm_rrr() +# Y = OP(X, W) +# Y._attrs["name"] = "output_0" +# Y._attrs["is_output"] = True +# module = compile_model( +# Y, target, "./tmp", f"gemm_rrr_{test_name}_{self._test_id}" +# ) +# self._test_id += 1 + +# for m0, m1 in itertools.product(m0s, m1s): +# X_pt = get_random_torch_tensor([m0, m1, k], dtype) +# W_pt = get_random_torch_tensor([k, n], dtype) +# Y_pt = torch.matmul(X_pt, W_pt) + +# inputs = {"input_0": X_pt, "input_1": W_pt} +# y = get_torch_empty_tensor([m0, m1, n], dtype) +# module.run_with_tensors(inputs, [y]) +# torch.testing.assert_close(Y_pt, y, **tolerance_limits) + +# @unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.") +# def test_3d_2d_rrr(self): +# self._test_3d_2d_rrr([256], [2], 128, 32, "static") +# self._test_3d_2d_rrr([1, 128], [3], 256, 16, "dynamic1") +# self._test_3d_2d_rrr([2], [24, 36], 256, 16, "dynamic2") +# self._test_3d_2d_rrr([2, 34, 48], [1, 3, 5], 256, 16, "dynamic3") + +# def _test_h_rcr(self, ait_dtype): +# M = 256 +# K = 256 +# N = 512 +# target = detect_target(use_fp16_acc=(ait_dtype == "float16")) +# X = Tensor(shape=[M, K], dtype=ait_dtype, name="input_0", is_input=True) +# 
W = Tensor(shape=[N, K], dtype=ait_dtype, name="input_1", is_input=True) +# OP = ops.gemm_rcr() +# Y = OP(X, W) +# Y._attrs["name"] = "output_0" +# Y._attrs["is_output"] = True +# module = compile_model( +# Y, target, "./tmp", f"hgemm_rcr_{ait_dtype}_{self._test_id}" +# ) +# self._test_id += 1 +# X_pt = get_random_torch_tensor((M, K), ait_dtype) +# W_pt = get_random_torch_tensor((N, K), ait_dtype) +# Y_pt = torch.nn.functional.linear(X_pt, W_pt) + +# inputs = {"input_0": X_pt, "input_1": W_pt} +# y = get_torch_empty_tensor((M, N), ait_dtype) +# module.run_with_tensors(inputs, [y]) +# torch.testing.assert_close(Y_pt, y, atol=1e-1, rtol=1e-1) + +# @unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.") +# def test_h_rcr_float16(self): +# self._test_h_rcr(ait_dtype="float16") + +# @unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.") +# def test_h_rcr_float32_sm80(self): +# self._test_h_rcr(ait_dtype="float32") + +# @unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.") +# def test_h_rcr_bfloat16_bf16(self): +# self._test_h_rcr(ait_dtype="bfloat16") + +# @unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.") +# def test_gemm_float32_sm80(self): +# self._test_rcr([1024], 256, 512, "static_float", dtype="float32") +# self._test_rcr([1, 1024], 256, 512, "dynamic1_float", dtype="float32") +# self._test_rcr([16, 1 * 29, 64], 256, 300000, "einsum_1_float", dtype="float32") + +# self._test_3d_2d_rcr([1024], [2], 256, 512, "static_float", dtype="float32") +# self._test_3d_2d_rcr( +# [1, 99, 1024], [1, 2], 128, 8, "dynamic3_float", dtype="float32" +# ) + +# self._test_rrr([256], 128, 32, "static_float", dtype="float32") +# self._test_rrr([1, 99, 1024, 2048], 256, 16, "dynamic_float", dtype="float32") + +# self._test_3d_2d_rrr([256], [2], 128, 32, "static_float", dtype="float32") +# self._test_3d_2d_rrr( +# [2, 34, 48], [1, 3, 5], 256, 16, "dynamic3_float", dtype="float32" +# ) + +# 
@unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.") +# def test_gemm_bfloat16_bf16(self): +# self._test_rcr([1024], 256, 512, "static_bfloat16", dtype="bfloat16") +# self._test_rcr([1, 1024], 256, 512, "dynamic1_bfloat16", dtype="bfloat16") +# self._test_rcr( +# [16, 1 * 29, 64], 256, 300000, "einsum_1_bfloat16", dtype="bfloat16" +# ) + +# self._test_3d_2d_rcr([1024], [2], 256, 512, "static_bfloat16", dtype="bfloat16") +# self._test_3d_2d_rcr( +# [1, 99, 1024], [1, 2], 128, 8, "dynamic3_bfloat16", dtype="bfloat16" +# ) + +# self._test_rrr([256], 128, 32, "static_bfloat16", dtype="bfloat16") +# self._test_rrr( +# [1, 99, 1024, 2048], 256, 16, "dynamic_bfloat16", dtype="bfloat16" +# ) + +# self._test_3d_2d_rrr([256], [2], 128, 32, "static_bfloat16", dtype="bfloat16") +# self._test_3d_2d_rrr( +# [2, 34, 48], [1, 3, 5], 256, 16, "dynamic3_bfloat16", dtype="bfloat16" +# ) + + +# filter_test_cases_by_test_env(GEMMTestCase) +>>>>>>> origin/navi3_rel_ver_1.0 if __name__ == "__main__": diff --git a/tests/unittest/ops/test_gemm_bias.py b/tests/unittest/ops/test_gemm_bias.py index cd276b739..3f626772f 100644 --- a/tests/unittest/ops/test_gemm_bias.py +++ b/tests/unittest/ops/test_gemm_bias.py @@ -77,14 +77,16 @@ def _test_rcr(self, Ms, N, K, test_name, dtype="float16"): def test_rcr_zero_size(self): target = detect_target() - # This test triggered a c10 assertion failure internally - # caffe2/c10/util/SmallVector.h:338: - # Assertion `idx < size()' failed - if type(target).__name__ != "FBCUDA": - self._test_rcr([2], N=64, K=0, test_name="zero_k") - self._test_rcr([2], N=0, K=4, test_name="zero_n") - self._test_rcr([0], N=4, K=4, test_name="zero_m") - + if target.name() == "cuda": + # This test triggered a c10 assertion failure internally + # caffe2/c10/util/SmallVector.h:338: + # Assertion `idx < size()' failed + if type(target).__name__ != "FBCUDA": + self._test_rcr([2], N=64, K=0, test_name="zero_k") + self._test_rcr([2], N=0, K=4, 
test_name="zero_n") + self._test_rcr([0], N=4, K=4, test_name="zero_m") + + @unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.") def test_rcr_static(self): self._test_rcr([4096], N=4, K=4, test_name="static") self._test_rcr([1000], N=81, K=1024, test_name="static") @@ -95,6 +97,7 @@ def test_rcr_static_rocm(self): self._test_rcr([1000], N=81, K=1024, test_name="static") self._test_rcr([67200], N=3, K=256, test_name="static") + @unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.") def test_rcr_bfloat16_bf16(self): dtype = "bfloat16" self._test_rcr([4], N=2, K=11, test_name=f"static_{dtype}", dtype=dtype) diff --git a/tests/unittest/ops/test_gemm_bias_swish.py b/tests/unittest/ops/test_gemm_bias_swish.py index c51c76d78..f83619f91 100644 --- a/tests/unittest/ops/test_gemm_bias_swish.py +++ b/tests/unittest/ops/test_gemm_bias_swish.py @@ -72,6 +72,7 @@ def _test_gemm_rcr_bias_swish( inputs = {"input_0": X_pt, "input_1": W_pt, "input_2": B_pt} y = get_torch_empty_tensor([M, N], dtype) module.run_with_tensors(inputs, [y]) +<<<<<<< HEAD torch.testing.assert_close(Y_pt, y, **_TOLERANCE_LIMITS[dtype]) def test_gemm_rcr_bias_swish_fp16(self): @@ -119,6 +120,18 @@ def test_gemm_rcr_bias_swish_sm90(self): dtype="bfloat16", test_suffix="bfloat16_force_sm90", ) +======= + self.assertTrue(torch.allclose(Y_pt, y, **_TOLERANCE_LIMITS[dtype])) + + def test_rcr_float16(self): + self._test_rcr(dtype="float16") + + # def test_rcr_float32_sm80(self): + # self._test_rcr(dtype="float32") + + # def test_rcr_bfloat16_bf16(self): + # self._test_rcr(dtype="bfloat16") +>>>>>>> origin/navi3_rel_ver_1.0 filter_test_cases_by_test_env(GEMMBiasSwishTestCase) diff --git a/tests/unittest/ops/test_groupnorm.py b/tests/unittest/ops/test_groupnorm.py index fb65ddfd1..6bec17935 100644 --- a/tests/unittest/ops/test_groupnorm.py +++ b/tests/unittest/ops/test_groupnorm.py @@ -145,6 +145,17 @@ def test_groupnorm_swish(self): self._test_groupnorm( x_shape=[1, 
512, 512, 256], num_groups=32, eps=1e-5, use_swish=True ) + # vae model groupnorm+swish + vae_shapes = [ + (1, 64, 64, 512), + (1, 128, 128, 512), + (1, 256, 256, 512), + (1, 256, 256, 256), + (1, 512, 512, 256), + (1, 512, 512, 128), + ] + for shape in vae_shapes: + self._test_groupnorm(x_shape=shape, num_groups=32, eps=1e-5, use_swish=True) # For benchmark only. # shapes = [