From 784206fde23f9f5153d328891125457c27739186 Mon Sep 17 00:00:00 2001 From: Ivan Mikhnenkov <39604625+ivanmikhnenkov@users.noreply.github.com> Date: Thu, 27 Oct 2022 01:20:57 +0300 Subject: [PATCH 001/638] updated to 5th stable diffusion checkpoint (#57) * updated to 5th stable diffusion checkpoint * updated all stable diffusion example files to checkpoint v1.5 --- examples/05_stable_diffusion/benchmark.py | 2 +- examples/05_stable_diffusion/benchmark_pt.py | 2 +- examples/05_stable_diffusion/compile.py | 2 +- examples/05_stable_diffusion/demo.py | 2 +- examples/05_stable_diffusion/demo_img2img.py | 2 +- examples/05_stable_diffusion/pipeline_stable_diffusion_ait.py | 2 +- .../pipeline_stable_diffusion_img2img_ait.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/05_stable_diffusion/benchmark.py b/examples/05_stable_diffusion/benchmark.py index bda7da289..811743da9 100644 --- a/examples/05_stable_diffusion/benchmark.py +++ b/examples/05_stable_diffusion/benchmark.py @@ -288,7 +288,7 @@ def benchmark_diffusers(token, batch_size, verify, benchmark_pt): access_token = token pipe = StableDiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", + "runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16, use_auth_token=access_token, diff --git a/examples/05_stable_diffusion/benchmark_pt.py b/examples/05_stable_diffusion/benchmark_pt.py index 3534eaf62..13b8738cc 100644 --- a/examples/05_stable_diffusion/benchmark_pt.py +++ b/examples/05_stable_diffusion/benchmark_pt.py @@ -27,7 +27,7 @@ ) def run(token, prompt, benchmark): pipe = StableDiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", + "runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16, use_auth_token=token, diff --git a/examples/05_stable_diffusion/compile.py b/examples/05_stable_diffusion/compile.py index 4c6288a84..513df5b9b 100644 --- a/examples/05_stable_diffusion/compile.py +++ 
b/examples/05_stable_diffusion/compile.py @@ -333,7 +333,7 @@ def compile_diffusers(token, batch_size, img2img=False, use_fp16_acc=True, conve access_token = token pipe = StableDiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", + "runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16, use_auth_token=access_token, diff --git a/examples/05_stable_diffusion/demo.py b/examples/05_stable_diffusion/demo.py index 5a7b8b79e..cef5c7aaa 100644 --- a/examples/05_stable_diffusion/demo.py +++ b/examples/05_stable_diffusion/demo.py @@ -27,7 +27,7 @@ ) def run(token, prompt, benchmark): pipe = StableDiffusionAITPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", + "runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16, use_auth_token=token, diff --git a/examples/05_stable_diffusion/demo_img2img.py b/examples/05_stable_diffusion/demo_img2img.py index 5a9f8d0d6..65bdaa874 100644 --- a/examples/05_stable_diffusion/demo_img2img.py +++ b/examples/05_stable_diffusion/demo_img2img.py @@ -35,7 +35,7 @@ def run(token, prompt, benchmark): # load the pipeline device = "cuda" - model_id_or_path = "CompVis/stable-diffusion-v1-4" + model_id_or_path = "runwayml/stable-diffusion-v1-5" pipe = StableDiffusionImg2ImgAITPipeline.from_pretrained( model_id_or_path, revision="fp16", diff --git a/examples/05_stable_diffusion/pipeline_stable_diffusion_ait.py b/examples/05_stable_diffusion/pipeline_stable_diffusion_ait.py index bf4450e22..5234117b1 100644 --- a/examples/05_stable_diffusion/pipeline_stable_diffusion_ait.py +++ b/examples/05_stable_diffusion/pipeline_stable_diffusion_ait.py @@ -60,7 +60,7 @@ class StableDiffusionAITPipeline(StableDiffusionPipeline): [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offsensive or harmful. 
- Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ diff --git a/examples/05_stable_diffusion/pipeline_stable_diffusion_img2img_ait.py b/examples/05_stable_diffusion/pipeline_stable_diffusion_img2img_ait.py index 9d18a7d32..d6c75ab05 100644 --- a/examples/05_stable_diffusion/pipeline_stable_diffusion_img2img_ait.py +++ b/examples/05_stable_diffusion/pipeline_stable_diffusion_img2img_ait.py @@ -70,7 +70,7 @@ class StableDiffusionImg2ImgAITPipeline(StableDiffusionImg2ImgPipeline): [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offsensive or harmful. - Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ From 0965ed07fff0cc36a6b3933f2d4aa4e187892ad8 Mon Sep 17 00:00:00 2001 From: Chris Kitching Date: Mon, 7 Nov 2022 02:22:42 +0000 Subject: [PATCH 002/638] Support different sizes via recompilation (StableDiff demo) (#71) Mostly, this commit is just re-establishing the relationship between various previously-hardcoded constants and the target image size (since the latent size is 1/8 of the image size, hardcoding the latent sizes is inconvenient). This adds `--width` and `--height` options to both compile.py and demo.py, and provided these both match you can process different sizes. 
For img2img mode, the size options passed at compile time must match the size of the actual input image. Consequently, the `--img2img` flag for `compile.py` no longer exists: all this ever did was change the hardcoded size to match the default input image used by `demo_img2img.py`. Yikes. So it's slightly more flexible than before, but still has no support for a single binary to handle different image sizes. It isn't super clear that compiling a generic binary is useful: the upstream project can do that just fine. Isn't the whole point of AITemplate to achieve performance gains via aggressive constant propagation and benchmarking to select the optimal kernels? --- examples/05_stable_diffusion/compile.py | 13 ++++++++----- examples/05_stable_diffusion/demo.py | 6 ++++-- examples/05_stable_diffusion/demo_img2img.py | 6 ++++-- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/examples/05_stable_diffusion/compile.py b/examples/05_stable_diffusion/compile.py index 513df5b9b..e91af3bb9 100644 --- a/examples/05_stable_diffusion/compile.py +++ b/examples/05_stable_diffusion/compile.py @@ -316,11 +316,12 @@ def compile_vae( @click.command() @click.option("--token", default="", help="access token") +@click.option("--width", default=512, help="Width of generated image") +@click.option("--height", default=512, help="Height of generated image") @click.option("--batch-size", default=1, help="batch size") -@click.option("--img2img", default=False, help="compile img2img models") @click.option("--use-fp16-acc", default=True, help="use fp16 accumulation") @click.option("--convert-conv-to-gemm", default=True, help="convert 1x1 conv to gemm") -def compile_diffusers(token, batch_size, img2img=False, use_fp16_acc=True, convert_conv_to_gemm=True): +def compile_diffusers(token, width, height, batch_size, use_fp16_acc=True, convert_conv_to_gemm=True): logging.getLogger().setLevel(logging.INFO) np.random.seed(0) torch.manual_seed(4896) @@ -339,19 +340,21 @@ def 
compile_diffusers(token, batch_size, img2img=False, use_fp16_acc=True, conve use_auth_token=access_token, ).to("cuda") - width = 96 if img2img else 64 + ww = width // 8 + hh = height // 8 # CLIP compile_clip(batch_size=batch_size, use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm) # UNet compile_unet( batch_size=batch_size * 2, - ww=width, + ww=ww, + hh=hh, use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm, ) # VAE - compile_vae(batch_size=batch_size, width=width, use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm) + compile_vae(batch_size=batch_size, width=ww, height=hh, use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm) if __name__ == "__main__": diff --git a/examples/05_stable_diffusion/demo.py b/examples/05_stable_diffusion/demo.py index cef5c7aaa..51859e886 100644 --- a/examples/05_stable_diffusion/demo.py +++ b/examples/05_stable_diffusion/demo.py @@ -21,11 +21,13 @@ @click.command() @click.option("--token", default="", help="access token") +@click.option("--width", default=512, help="Width of generated image") +@click.option("--height", default=512, help="Height of generated image") @click.option("--prompt", default="A vision of paradise, Unreal Engine", help="prompt") @click.option( "--benchmark", type=bool, default=False, help="run stable diffusion e2e benchmark" ) -def run(token, prompt, benchmark): +def run(token, width, height, prompt, benchmark): pipe = StableDiffusionAITPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", revision="fp16", @@ -34,7 +36,7 @@ def run(token, prompt, benchmark): ).to("cuda") with torch.autocast("cuda"): - image = pipe(prompt).images[0] + image = pipe(prompt, height, width).images[0] if benchmark: t = benchmark_torch_function(10, pipe, prompt) print(f"sd e2e: {t} ms") diff --git a/examples/05_stable_diffusion/demo_img2img.py b/examples/05_stable_diffusion/demo_img2img.py index 65bdaa874..844aac726 100644 --- 
a/examples/05_stable_diffusion/demo_img2img.py +++ b/examples/05_stable_diffusion/demo_img2img.py @@ -25,13 +25,15 @@ @click.command() @click.option("--token", default="", help="access token") +@click.option("--width", default=512, help="Width of generated image") +@click.option("--height", default=512, help="Height of generated image") @click.option( "--prompt", default="A fantasy landscape, trending on artstation", help="prompt" ) @click.option( "--benchmark", type=bool, default=False, help="run stable diffusion e2e benchmark" ) -def run(token, prompt, benchmark): +def run(token, width, height, prompt, benchmark): # load the pipeline device = "cuda" @@ -49,7 +51,7 @@ def run(token, prompt, benchmark): response = requests.get(url) init_image = Image.open(BytesIO(response.content)).convert("RGB") - init_image = init_image.resize((768, 512)) + init_image = init_image.resize((height, width)) with torch.autocast("cuda"): images = pipe( From f7878c907167b41423d63e5355fc2685cda58f8e Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Wed, 9 Nov 2022 12:55:16 -0800 Subject: [PATCH 003/638] v0.1.1 (#74) * v0.11 * update cutlass * fix * add missing files * patch cutlass Co-authored-by: Bing Xu --- .gitmodules | 2 +- 3rdparty/cutlass | 2 +- examples/03_bert/benchmark_mi250.sh | 6 +- examples/05_stable_diffusion/benchmark_pt.py | 1 + examples/05_stable_diffusion/compile.py | 18 +- examples/05_stable_diffusion/demo_img2img.py | 2 +- .../05_stable_diffusion/modeling/attention.py | 1 + examples/05_stable_diffusion/modeling/clip.py | 12 +- .../pipeline_stable_diffusion_img2img_ait.py | 4 +- python/aitemplate/__init__.py | 18 +- python/aitemplate/backend/backend_spec.py | 50 +- python/aitemplate/backend/builder.py | 102 +- python/aitemplate/backend/codegen.py | 92 +- .../backend/common/concatenate_common.py | 58 +- .../backend/common/elementwise_common.py | 70 +- .../aitemplate/backend/common/split_common.py | 51 +- .../backend/common/tensor/argmax_common.py | 10 +- 
.../common/tensor/batch_gather_common.py | 26 +- .../common/tensor/permute021_common.py | 39 +- .../common/tensor/permute102_common.py | 37 +- .../common/tensor/permute210_common.py | 39 +- .../backend/common/tensor/slice_common.py | 48 +- .../backend/common/tensor/topk_common.py | 27 +- .../backend/common/tensor_accessor_codegen.py | 52 +- .../backend/common/upsampling2d_common.py | 45 +- .../common/vision_ops/efficient_nms_common.py | 63 +- .../common/vision_ops/efficient_nms_kernel.py | 2 +- .../backend/common/vision_ops/nms_common.py | 48 +- .../common/vision_ops/roi_align_common.py | 45 +- python/aitemplate/backend/cuda/__init__.py | 1 + .../backend/cuda/attention/__init__.py | 4 +- .../backend/cuda/attention/flash_attention.py | 20 +- .../cuda/attention/mem_eff_attention.py | 262 ++ .../backend/cuda/attention/src/fmha.h | 14 - .../backend/cuda/attention/src/fmha/gemm.h | 14 - .../cuda/attention/src/fmha/gmem_tile.h | 14 - .../cuda/attention/src/fmha/kernel_traits.h | 14 - .../backend/cuda/attention/src/fmha/mask.h | 14 - .../cuda/attention/src/fmha/smem_tile.h | 14 - .../backend/cuda/attention/src/fmha/softmax.h | 14 - .../backend/cuda/attention/src/fmha/utils.h | 14 - .../src/fmha_block_fprop_fp16_kernel.sm80.cu | 14 - .../src/fmha_block_fprop_kernel_1xN.h | 14 - .../cuda/attention/src/fmha_blockmask.h | 14 - .../src/fmha_fprop_fp16_kernel.sm80.cu | 14 - .../attention/src/fmha_fprop_kernel_1xN.h | 14 - .../backend/cuda/attention/src/fmha_kernel.h | 14 - .../backend/cuda/attention/src/fmha_utils.h | 14 - .../backend/cuda/attention/src/philox.cuh | 14 - .../aitemplate/backend/cuda/conv2d/common.py | 18 +- .../conv2d/common_conv2d_bias_activation.py | 59 +- .../common_conv2d_bias_add_activation.py | 65 +- .../aitemplate/backend/cuda/conv2d/conv2d.py | 47 +- .../backend/cuda/conv2d/conv2d_bias.py | 2 +- .../backend/cuda/conv2d/conv2d_bias_add.py | 2 +- .../cuda/conv2d/conv2d_bias_add_hardswish.py | 2 +- .../cuda/conv2d/conv2d_bias_add_relu.py | 2 +- 
.../cuda/conv2d/conv2d_bias_few_channels.py | 2 +- .../cuda/conv2d/conv2d_bias_hardswish.py | 2 +- .../conv2d_bias_hardswish_few_channels.py | 2 +- .../backend/cuda/conv2d/conv2d_bias_relu.py | 2 +- .../conv2d/conv2d_bias_relu_few_channels.py | 2 +- .../cuda/conv2d/conv2d_bias_sigmoid.py | 2 +- .../backend/cuda/conv2d/transposed_conv2d.py | 16 +- .../cuda/conv2d/transposed_conv2d_bias.py | 17 +- .../backend/cuda/conv3d/__init__.py | 20 + .../aitemplate/backend/cuda/conv3d/common.py | 364 +++ .../aitemplate/backend/cuda/conv3d/conv3d.py | 496 ++++ .../backend/cuda/conv3d/depthwise_conv3d.py | 331 +++ .../backend/cuda/elementwise/__init__.py | 4 +- .../backend/cuda/elementwise/custom_math.cuh | 158 ++ .../cuda/elementwise/fused_elementwise.py | 1 + .../cuda/elementwise/int_elementwise.py | 67 + .../backend/cuda/embedding/bert_embeddings.py | 125 +- .../cuda/gemm_epilogue_vistor/__init__.py | 16 +- .../bmm_common_softmax.py | 4 +- .../gemm_epilogue_vistor/bmm_rcr_softmax.py | 2 +- .../gemm_epilogue_vistor/common_dual_gemm.py | 458 ++++ .../gemm_epilogue_vistor/common_softmax.py | 8 +- .../dual_gemm_rcr_fast_gelu.py | 348 +++ .../dual_gemm_rcr_silu.py | 220 ++ .../gemm_epilogue_vistor/gemm_rcr_softmax.py | 2 +- .../include/gemm_with_softmax.h | 14 - .../backend/cuda/gemm_special/bmm_rcr_n1.py | 33 +- .../cuda/gemm_special/bmm_rrr_k1_tanh.py | 83 +- .../cuda/gemm_special/gemm_rrr_small_nk.py | 81 +- .../backend/cuda/gemm_universal/__init__.py | 1 + .../backend/cuda/gemm_universal/bmm_ccr.py | 35 +- .../cuda/gemm_universal/bmm_ccr_add.py | 19 +- .../backend/cuda/gemm_universal/bmm_common.py | 182 +- .../backend/cuda/gemm_universal/bmm_crr.py | 35 +- .../cuda/gemm_universal/bmm_crr_add.py | 21 +- .../cuda/gemm_universal/bmm_permute_common.py | 153 +- .../backend/cuda/gemm_universal/bmm_rcr.py | 36 +- .../cuda/gemm_universal/bmm_rcr_permute.py | 36 +- .../backend/cuda/gemm_universal/bmm_rrr.py | 35 +- .../cuda/gemm_universal/bmm_rrr_add.py | 21 +- 
.../cuda/gemm_universal/bmm_rrr_permute.py | 40 +- .../backend/cuda/gemm_universal/common.py | 433 ++- .../cuda/gemm_universal/common_bias.py | 22 +- .../gemm_universal/common_bias_activation.py | 21 +- .../gemm_universal/common_bias_broadcast.py | 248 +- .../cuda/gemm_universal/common_no_bias.py | 105 + .../cuda/gemm_universal/common_permute.py | 215 +- .../backend/cuda/gemm_universal/gemm_rcr.py | 39 +- .../cuda/gemm_universal/gemm_rcr_bias.py | 37 +- .../cuda/gemm_universal/gemm_rcr_bias_add.py | 5 +- .../gemm_universal/gemm_rcr_bias_add_add.py | 5 +- .../gemm_rcr_bias_add_add_relu.py | 5 +- .../gemm_universal/gemm_rcr_bias_add_relu.py | 5 +- .../gemm_universal/gemm_rcr_bias_fast_gelu.py | 11 +- .../cuda/gemm_universal/gemm_rcr_bias_gelu.py | 11 +- .../gemm_universal/gemm_rcr_bias_hardswish.py | 11 +- .../cuda/gemm_universal/gemm_rcr_bias_mul.py | 5 +- .../gemm_universal/gemm_rcr_bias_mul_add.py | 5 +- .../gemm_universal/gemm_rcr_bias_mul_tanh.py | 5 +- .../gemm_universal/gemm_rcr_bias_permute.py | 23 +- .../cuda/gemm_universal/gemm_rcr_bias_relu.py | 11 +- .../gemm_universal/gemm_rcr_bias_sigmoid.py | 11 +- .../gemm_rcr_bias_sigmoid_mul.py | 5 +- .../gemm_rcr_bias_sigmoid_mul_tanh.py | 5 +- .../gemm_universal/gemm_rcr_bias_swish.py | 11 +- .../cuda/gemm_universal/gemm_rcr_bias_tanh.py | 11 +- .../cuda/gemm_universal/gemm_rcr_fast_gelu.py | 170 ++ .../cuda/gemm_universal/gemm_rcr_permute.py | 51 +- .../backend/cuda/gemm_universal/gemm_rrr.py | 41 +- .../cuda/gemm_universal/gemm_rrr_permute.py | 51 +- .../cuda/gemm_universal/group_common.py | 272 +- .../cuda/gemm_universal/group_common_bias.py | 18 +- .../cuda/gemm_universal/group_gemm_rcr.py | 16 +- .../gemm_universal/group_gemm_rcr_bias.py | 6 +- .../group_gemm_rcr_bias_relu.py | 6 +- .../group_gemm_rcr_bias_sigmoid.py | 6 +- .../cuda/gemm_universal/perm021fc_ccr.py | 35 +- .../cuda/gemm_universal/perm021fc_ccr_bias.py | 23 +- .../perm021fc_ccr_bias_permute.py | 29 +- .../cuda/gemm_universal/perm021fc_crc.py | 
29 +- .../cuda/gemm_universal/perm021fc_crc_bias.py | 15 +- .../cuda/gemm_universal/perm102_bmm_rcr.py | 41 +- .../gemm_universal/perm102_bmm_rcr_bias.py | 27 +- .../cuda/gemm_universal/perm102_bmm_rrr.py | 43 +- .../gemm_universal/perm102_bmm_rrr_bias.py | 26 +- .../cuda/groupnorm/groupnorm_common.py | 51 +- .../cuda/groupnorm/groupnorm_kernel.cuh | 388 ++- .../backend/cuda/groupnorm/layer_norm.cuh | 2404 +++++++++++++++++ .../batch_layernorm_sigmoid_mul.py | 25 +- .../group_layernorm_sigmoid_mul.py | 97 +- .../layernorm_sigmoid_mul/layernorm_common.py | 9 +- .../layernorm_sigmoid_mul.py | 30 +- .../layernorm_sigmoid_mul_kernel.cuh | 163 +- .../aitemplate/backend/cuda/lib_template.py | 18 +- .../backend/cuda/padding/nhwc3to4.py | 35 +- .../backend/cuda/padding/nhwc3to8.py | 69 +- .../backend/cuda/padding/pad_last_dim.py | 51 +- .../backend/cuda/pool2d/avg_pool2d.py | 18 +- .../backend/cuda/pool2d/max_pool2d.py | 21 +- .../aitemplate/backend/cuda/pool2d/pool2d.py | 4 +- .../backend/cuda/reduce/reduce_3d.py | 49 +- .../backend/cuda/reduce/reduce_common.py | 71 +- .../backend/cuda/reduce/reduce_small_axis.py | 47 +- python/aitemplate/backend/cuda/reduce/var.py | 7 +- .../backend/cuda/softmax/softmax.py | 32 +- python/aitemplate/backend/cuda/target_def.py | 213 +- .../backend/cuda/tensor/__init__.py | 2 + .../aitemplate/backend/cuda/tensor/gather.py | 43 +- .../backend/cuda/tensor/permute.cuh | 369 +++ .../aitemplate/backend/cuda/tensor/permute.py | 183 ++ .../backend/cuda/upsample/upsampling2d.py | 6 +- .../backend/cuda/upsample/upsampling2d_add.py | 8 +- .../backend/cuda/view_ops/view_ops.py | 26 +- .../cuda/vision_ops/nms/batched_nms.py | 21 +- .../vision_ops/nms/batched_nms_kernel.cuh | 14 - .../cuda/vision_ops/roi_ops/roi_align.py | 7 +- .../cuda/vision_ops/roi_ops/roi_ops.py | 6 +- python/aitemplate/backend/main_templates.py | 197 +- python/aitemplate/backend/profiler_cache.py | 309 ++- python/aitemplate/backend/profiler_runner.py | 166 +- 
.../aitemplate/backend/rocm/conv2d/common.py | 14 +- .../aitemplate/backend/rocm/conv2d/conv2d.py | 2 +- .../backend/rocm/conv2d/conv2d_bias.py | 2 +- .../rocm/conv2d/conv2d_bias_add_relu.py | 2 +- .../backend/rocm/conv2d/conv2d_bias_relu.py | 2 +- .../rocm/conv2d/conv2d_bias_sigmoid.py | 2 +- .../backend/rocm/conv2d/transposed_conv2d.py | 2 +- .../conv2d/transposed_conv2d_bias_relu.py | 2 +- .../aitemplate/backend/rocm/gemm/bmm_ccr.py | 2 +- .../backend/rocm/gemm/bmm_common.py | 2 +- .../aitemplate/backend/rocm/gemm/bmm_crr.py | 2 +- .../aitemplate/backend/rocm/gemm/bmm_rcr.py | 2 +- .../backend/rocm/gemm/bmm_rcr_permute.py | 2 +- .../aitemplate/backend/rocm/gemm/bmm_rrr.py | 2 +- .../backend/rocm/gemm/bmm_rrr_permute.py | 2 +- .../backend/rocm/gemm/bmm_softmax_bmm.py | 2 +- .../rocm/gemm/bmm_softmax_bmm_permute.py | 2 +- python/aitemplate/backend/rocm/gemm/common.py | 11 +- .../aitemplate/backend/rocm/gemm/gemm_rcr.py | 2 +- .../backend/rocm/gemm/gemm_rcr_bias.py | 2 +- .../backend/rocm/gemm/gemm_rcr_bias_add.py | 2 +- .../rocm/gemm/gemm_rcr_bias_add_add.py | 2 +- .../rocm/gemm/gemm_rcr_bias_add_add_relu.py | 2 +- .../rocm/gemm/gemm_rcr_bias_add_relu.py | 2 +- .../rocm/gemm/gemm_rcr_bias_fast_gelu.py | 2 +- .../backend/rocm/gemm/gemm_rcr_bias_mul.py | 2 +- .../rocm/gemm/gemm_rcr_bias_mul_add.py | 2 +- .../rocm/gemm/gemm_rcr_bias_mul_tanh.py | 2 +- .../rocm/gemm/gemm_rcr_bias_permute.py | 2 +- .../rocm/gemm/gemm_rcr_bias_permute_m2n3.py | 2 +- .../rocm/gemm/gemm_rcr_bias_permute_m3n2.py | 2 +- .../backend/rocm/gemm/gemm_rcr_bias_relu.py | 2 +- .../rocm/gemm/gemm_rcr_bias_sigmoid.py | 2 +- .../rocm/gemm/gemm_rcr_bias_sigmoid_mul.py | 2 +- .../gemm/gemm_rcr_bias_sigmoid_mul_tanh.py | 2 +- .../backend/rocm/gemm/gemm_rcr_bias_swish.py | 2 +- .../backend/rocm/gemm/gemm_rcr_bias_tanh.py | 2 +- .../rocm/gemm/gemm_rcr_permute_m2n3.py | 2 +- .../aitemplate/backend/rocm/gemm/gemm_rrr.py | 2 +- .../rocm/gemm/gemm_rrr_bias_permute.py | 2 +- 
.../aitemplate/backend/rocm/lib_template.py | 8 +- .../backend/rocm/normalization/groupnorm.py | 2 +- .../rocm/normalization/groupnorm_swish.py | 2 +- .../backend/rocm/normalization/layernorm.py | 2 +- .../backend/rocm/normalization/norm_common.py | 11 +- .../backend/rocm/normalization/softmax.py | 2 +- python/aitemplate/backend/rocm/target_def.py | 146 +- python/aitemplate/backend/target.py | 42 +- python/aitemplate/compiler/__init__.py | 3 +- python/aitemplate/compiler/base.py | 92 +- python/aitemplate/compiler/compiler.py | 56 +- python/aitemplate/compiler/dtype.py | 136 + python/aitemplate/compiler/model.py | 73 +- .../compiler/ops/attention/__init__.py | 3 +- .../compiler/ops/attention/flash_attention.py | 10 + .../ops/attention/mem_eff_attention.py | 179 ++ .../compiler/ops/common/__init__.py | 1 + .../compiler/ops/common/elementwise.py | 6 + .../compiler/ops/common/epilogue.py | 4 + .../compiler/ops/common/fused_elementwise.py | 5 +- .../compiler/ops/common/int_elementwise.py | 142 + python/aitemplate/compiler/ops/common/math.py | 16 + .../compiler/ops/common/view_ops.py | 135 +- .../aitemplate/compiler/ops/conv/__init__.py | 2 + .../compiler/ops/conv/cache_entry.py | 65 + .../ops/conv/common_conv2d_bias_activation.py | 6 + .../conv/common_conv2d_bias_add_activation.py | 6 + python/aitemplate/compiler/ops/conv/conv2d.py | 18 +- .../compiler/ops/conv/conv2d_bias.py | 10 + .../compiler/ops/conv/conv2d_bias_add.py | 6 + .../ops/conv/conv2d_bias_add_hardswish.py | 6 + .../compiler/ops/conv/conv2d_bias_add_relu.py | 6 + .../ops/conv/conv2d_bias_few_channels.py | 8 +- .../ops/conv/conv2d_bias_hardswish.py | 6 + .../conv2d_bias_hardswish_few_channels.py | 6 + .../compiler/ops/conv/conv2d_bias_relu.py | 6 + .../ops/conv/conv2d_bias_relu_few_channels.py | 8 +- .../compiler/ops/conv/conv2d_bias_sigmoid.py | 6 + python/aitemplate/compiler/ops/conv/conv3d.py | 623 +++++ .../compiler/ops/conv/depthwise_conv3d.py | 290 ++ .../conv/special_conv2d_bias_activation.py | 13 + 
.../ops/gemm_epilogue_vistor/__init__.py | 10 +- .../dual_gemm_rcr_fast_gelu.py | 77 + .../dual_gemm_rcr_silu.py | 77 + .../compiler/ops/gemm_special/bmm_rcr_n1.py | 1 - .../compiler/ops/gemm_universal/__init__.py | 1 + .../ops/gemm_universal/bmm_rcr_permute.py | 10 +- .../ops/gemm_universal/bmm_rrr_permute.py | 9 +- .../ops/gemm_universal/bmm_softmax_bmm.py | 3 + .../gemm_universal/bmm_softmax_bmm_permute.py | 11 +- .../ops/gemm_universal/gemm_common.py | 291 +- .../compiler/ops/gemm_universal/gemm_rcr.py | 1 - .../ops/gemm_universal/gemm_rcr_bias.py | 1 - .../gemm_universal/gemm_rcr_bias_fast_gelu.py | 1 - .../ops/gemm_universal/gemm_rcr_bias_gelu.py | 1 - .../gemm_universal/gemm_rcr_bias_hardswish.py | 1 - .../gemm_universal/gemm_rcr_bias_permute.py | 9 +- .../ops/gemm_universal/gemm_rcr_bias_relu.py | 1 - .../gemm_universal/gemm_rcr_bias_sigmoid.py | 1 - .../ops/gemm_universal/gemm_rcr_bias_swish.py | 1 - .../ops/gemm_universal/gemm_rcr_bias_tanh.py | 1 - .../ops/gemm_universal/gemm_rcr_fast_gelu.py | 41 + .../ops/gemm_universal/gemm_rcr_permute.py | 37 +- .../compiler/ops/gemm_universal/gemm_rrr.py | 1 - .../gemm_universal/gemm_rrr_bias_permute.py | 3 +- .../ops/gemm_universal/gemm_rrr_permute.py | 8 +- .../ops/gemm_universal/group_gemm_rcr.py | 13 +- .../ops/gemm_universal/group_gemm_rcr_bias.py | 5 +- .../group_gemm_rcr_bias_relu.py | 1 - .../group_gemm_rcr_bias_sigmoid.py | 1 - .../ops/gemm_universal/perm021fc_ccr.py | 1 - .../ops/gemm_universal/perm021fc_ccr_bias.py | 1 - .../perm021fc_ccr_bias_permute.py | 4 +- .../ops/gemm_universal/perm102_bmm_rcr.py | 1 - .../gemm_universal/perm102_bmm_rcr_bias.py | 1 - .../ops/gemm_universal/perm102_bmm_rrr.py | 1 - .../gemm_universal/perm102_bmm_rrr_bias.py | 1 - .../compiler/ops/groupnorm/groupnorm.py | 23 +- .../compiler/ops/layernorm/group_layernorm.py | 1 - .../layernorm/group_layernorm_sigmoid_mul.py | 1 - .../compiler/ops/layernorm/layernorm.py | 38 +- .../ops/layernorm/layernorm_sigmoid_mul.py | 17 +- 
.../compiler/ops/padding/nhwc_pad_common.py | 6 + .../compiler/ops/padding/pad_last_dim.py | 3 + python/aitemplate/compiler/ops/pool/pool2d.py | 10 + .../compiler/ops/reduce/reduce_common.py | 92 +- python/aitemplate/compiler/ops/reduce/var.py | 8 + .../compiler/ops/reduce/vector_norm.py | 8 + .../compiler/ops/softmax/softmax.py | 30 +- .../aitemplate/compiler/ops/tensor/argmax.py | 13 +- .../compiler/ops/tensor/concatenate.py | 5 - .../aitemplate/compiler/ops/tensor/permute.py | 44 +- .../ops/tensor/slice_reshape_scatter.py | 13 +- .../compiler/ops/tensor/slice_scatter.py | 15 +- python/aitemplate/compiler/ops/tensor/topk.py | 25 +- .../ops/upsample/upsampling_common.py | 6 + .../ops/vision_ops/nms/batched_nms.py | 6 + .../ops/vision_ops/nms/efficient_nms.py | 20 +- .../compiler/ops/vision_ops/nms/nms.py | 20 +- .../roi_ops/multi_level_roi_align.py | 6 + .../ops/vision_ops/roi_ops/roi_ops.py | 17 + python/aitemplate/compiler/public/__init__.py | 4 + python/aitemplate/compiler/stable_set.py | 100 + python/aitemplate/compiler/tensor_accessor.py | 2 +- .../aitemplate/compiler/transform/__init__.py | 3 +- .../compiler/transform/apply_padding.py | 2 +- .../compiler/transform/constant_folding.py | 16 +- .../compiler/transform/fuse_group_ops.py | 2 +- .../transform/fuse_mm_elementwise_patterns.py | 16 + .../transform/fuse_mm_reshape_permute.py | 189 ++ .../aitemplate/compiler/transform/fuse_ops.py | 12 + .../transform/fuse_permute_bmm_and_gemm.py | 246 ++ .../compiler/transform/fuse_split.py | 4 +- .../compiler/transform/name_graph.py | 13 +- .../compiler/transform/optimize_graph.py | 9 +- .../aitemplate/compiler/transform/profile.py | 90 +- .../compiler/transform/profile_dynamic_dim.py | 35 +- .../transform/split_large_concat_ops.py | 124 + .../transform/transform_memory_ops.py | 16 +- .../transform_strided_op_and_view_op.py | 3 +- .../transform/transform_strided_ops.py | 14 +- .../compiler/transform/transform_utils.py | 6 +- python/aitemplate/frontend/nn/__init__.py | 
5 +- python/aitemplate/frontend/nn/attention.py | 189 +- .../aitemplate/frontend/nn/conv2d/conv2d.py | 1 + .../frontend/nn/conv2d/conv2d_bias.py | 41 + .../nn/conv2d/conv2d_bias_add_hardswish.py | 21 + .../nn/conv2d/conv2d_bias_add_relu.py | 21 + .../nn/conv2d/conv2d_bias_few_channels.py | 7 +- .../nn/conv2d/conv2d_bias_hardswish.py | 2 + .../conv2d_bias_hardswish_few_channels.py | 7 +- .../frontend/nn/conv2d/conv2d_bias_relu.py | 2 + .../conv2d/conv2d_bias_relu_few_channels.py | 7 +- .../frontend/nn/conv2d/conv2d_bias_sigmoid.py | 2 + .../nn/conv2d/transposed_conv2d_bias.py | 49 +- .../nn/conv2d/transposed_conv2d_bias_relu.py | 2 + python/aitemplate/frontend/nn/dropout.py | 5 + python/aitemplate/frontend/nn/dual_gemm.py | 72 + python/aitemplate/frontend/nn/embedding.py | 11 + python/aitemplate/frontend/nn/identity.py | 2 + python/aitemplate/frontend/nn/linear.py | 37 + python/aitemplate/frontend/nn/padding.py | 2 + python/aitemplate/frontend/nn/pool2d.py | 50 + python/aitemplate/frontend/nn/roi_ops.py | 71 + python/aitemplate/frontend/nn/upsample.py | 21 + python/aitemplate/frontend/nn/view_ops.py | 25 + python/aitemplate/testing/detect_target.py | 10 +- python/aitemplate/utils/__init__.py | 1 + python/aitemplate/utils/alignment.py | 36 + python/aitemplate/utils/graph_utils.py | 20 +- python/aitemplate/utils/logger.py | 20 + python/aitemplate/utils/mk_ck_lib/__init__.py | 18 + .../aitemplate/utils/mk_ck_lib/generator.py | 2 - .../utils/mk_cutlass_lib/extra_enum.py | 10 +- .../utils/mk_cutlass_lib/extra_gemm_emit.py | 136 + .../utils/serialization/ait_program.py | 90 + .../utils/serialization/serdes_code.py | 393 +++ python/aitemplate/utils/torch_utils.py | 1 + python/aitemplate/utils/visualization/plot.py | 76 +- python/setup.py | 32 +- static/csrc/debug_utility.cpp | 80 + static/csrc/model_container.cpp | 16 +- static/csrc/model_interface.cpp | 83 +- static/include/debug_utility.h | 30 + static/include/model_container.h | 20 +- static/include/model_interface.h | 
46 +- static/include/raii_wrapper.h | 29 +- tests/unittest/backend/test_cuda_graph.py | 79 + tests/unittest/backend/test_model_api.py | 44 +- tests/unittest/backend/test_profiler.py | 77 + .../unittest/benchmark/test_gemm_benchmark.py | 321 +++ .../compiler/test_constant_folding.py | 6 +- .../compiler/test_fuse_mm_elementwise.py | 26 + .../compiler/test_fuse_mm_reshape_permute.py | 125 + .../compiler/test_fuse_permute_gemm.py | 86 + ...st_fused_elementwise_complex_dependency.py | 86 + tests/unittest/compiler/test_group_fusions.py | 5 + .../test_pad_gemm_with_elementwise.py | 45 + .../compiler/test_split_large_concat.py | 462 ++++ .../compiler/test_strided_group_gemm.py | 5 +- .../compiler/test_strided_op_cat_pattern.py | 226 +- .../compiler/test_strided_reshape_cat.py | 5 +- tests/unittest/compiler/test_tensor.py | 54 + .../unittest/compiler/test_transform_utils.py | 3 +- tests/unittest/ops/test_activation.py | 93 +- tests/unittest/ops/test_argmax.py | 12 +- tests/unittest/ops/test_attention.py | 336 ++- tests/unittest/ops/test_bmm_add.py | 2 +- tests/unittest/ops/test_bmm_permute.py | 16 +- tests/unittest/ops/test_bmm_rcr_n1.py | 2 +- tests/unittest/ops/test_bmm_rrr_k1_tanh.py | 2 +- tests/unittest/ops/test_bmm_softmax.py | 3 + tests/unittest/ops/test_bmm_softmax_bmm.py | 23 +- tests/unittest/ops/test_concatenate_tanh.py | 10 +- tests/unittest/ops/test_conv.py | 10 +- tests/unittest/ops/test_conv2d_bias_add.py | 8 +- tests/unittest/ops/test_conv3d.py | 89 + tests/unittest/ops/test_conv_bias.py | 8 +- .../ops/test_conv_bias_act_few_channels.py | 16 +- .../ops/test_conv_bias_add_hardswish.py | 8 +- tests/unittest/ops/test_conv_bias_add_relu.py | 8 +- .../unittest/ops/test_conv_bias_hardswish.py | 8 +- tests/unittest/ops/test_conv_bias_relu.py | 8 +- tests/unittest/ops/test_conv_bias_sigmoid.py | 8 +- tests/unittest/ops/test_cross_attention.py | 133 + tests/unittest/ops/test_depthwise_conv3d.py | 123 + tests/unittest/ops/test_dual_gemm.py | 193 ++ 
tests/unittest/ops/test_dynamic_conv.py | 96 +- tests/unittest/ops/test_efficient_nms.py | 24 +- tests/unittest/ops/test_fpn_roi_align.py | 14 + tests/unittest/ops/test_fused_elementwise.py | 337 ++- tests/unittest/ops/test_gemm.py | 46 + tests/unittest/ops/test_gemm_bias.py | 3 +- .../unittest/ops/test_gemm_bias_broadcast.py | 2 + .../unittest/ops/test_gemm_bias_hardswish.py | 2 +- tests/unittest/ops/test_gemm_bias_permute.py | 26 +- tests/unittest/ops/test_gemm_bias_relu.py | 2 +- tests/unittest/ops/test_gemm_bias_sigmoid.py | 2 +- tests/unittest/ops/test_gemm_bias_softmax.py | 5 + tests/unittest/ops/test_gemm_bias_swish.py | 2 +- tests/unittest/ops/test_gemm_bias_tanh.py | 2 +- tests/unittest/ops/test_gemm_permute.py | 119 +- .../ops/test_gemm_rcr_bias_fast_gelu.py | 2 +- tests/unittest/ops/test_gemm_rcr_fast_gelu.py | 91 + tests/unittest/ops/test_gemm_rrr_small_nk.py | 2 +- tests/unittest/ops/test_gemm_softmax.py | 7 +- tests/unittest/ops/test_group_gemm_rcr.py | 4 +- .../unittest/ops/test_group_gemm_rcr_bias.py | 2 +- .../test_group_gemm_rcr_bias_activation.py | 2 +- .../ops/test_group_gemm_rcr_bias_cat.py | 2 +- tests/unittest/ops/test_group_gemm_rcr_cat.py | 2 +- tests/unittest/ops/test_groupnorm.py | 23 +- .../test_int_elementwise_dynamic_reshape.py | 114 + tests/unittest/ops/test_nms.py | 29 +- tests/unittest/ops/test_norm.py | 13 + tests/unittest/ops/test_pad_last_dim.py | 10 +- tests/unittest/ops/test_permute.py | 34 +- tests/unittest/ops/test_reduce.py | 85 +- tests/unittest/ops/test_size_getitem_ops.py | 11 +- tests/unittest/ops/test_topk.py | 33 +- tests/unittest/ops/test_transpose_conv2d.py | 8 +- .../ops/test_transpose_conv2d_bias.py | 8 +- tests/unittest/ops/test_var.py | 13 + tests/unittest/test_stable_set.py | 68 + tests/unittest/util/test_debug_utils.py | 138 + tests/unittest/util/test_serdes.py | 290 ++ 463 files changed, 19731 insertions(+), 3247 deletions(-) create mode 100644 python/aitemplate/backend/cuda/attention/mem_eff_attention.py 
create mode 100644 python/aitemplate/backend/cuda/conv3d/__init__.py create mode 100644 python/aitemplate/backend/cuda/conv3d/common.py create mode 100644 python/aitemplate/backend/cuda/conv3d/conv3d.py create mode 100644 python/aitemplate/backend/cuda/conv3d/depthwise_conv3d.py create mode 100644 python/aitemplate/backend/cuda/elementwise/int_elementwise.py create mode 100644 python/aitemplate/backend/cuda/gemm_epilogue_vistor/common_dual_gemm.py create mode 100644 python/aitemplate/backend/cuda/gemm_epilogue_vistor/dual_gemm_rcr_fast_gelu.py create mode 100644 python/aitemplate/backend/cuda/gemm_epilogue_vistor/dual_gemm_rcr_silu.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/common_no_bias.py create mode 100644 python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_fast_gelu.py create mode 100644 python/aitemplate/backend/cuda/groupnorm/layer_norm.cuh create mode 100644 python/aitemplate/backend/cuda/tensor/permute.cuh create mode 100644 python/aitemplate/backend/cuda/tensor/permute.py create mode 100644 python/aitemplate/compiler/dtype.py create mode 100644 python/aitemplate/compiler/ops/attention/mem_eff_attention.py create mode 100644 python/aitemplate/compiler/ops/common/int_elementwise.py create mode 100644 python/aitemplate/compiler/ops/conv/conv3d.py create mode 100644 python/aitemplate/compiler/ops/conv/depthwise_conv3d.py create mode 100644 python/aitemplate/compiler/ops/gemm_epilogue_vistor/dual_gemm_rcr_fast_gelu.py create mode 100644 python/aitemplate/compiler/ops/gemm_epilogue_vistor/dual_gemm_rcr_silu.py create mode 100644 python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_fast_gelu.py create mode 100644 python/aitemplate/compiler/stable_set.py create mode 100644 python/aitemplate/compiler/transform/fuse_mm_reshape_permute.py create mode 100644 python/aitemplate/compiler/transform/fuse_permute_bmm_and_gemm.py create mode 100644 python/aitemplate/compiler/transform/split_large_concat_ops.py create mode 100644 
python/aitemplate/frontend/nn/dual_gemm.py create mode 100644 python/aitemplate/utils/alignment.py create mode 100644 python/aitemplate/utils/mk_ck_lib/__init__.py create mode 100644 python/aitemplate/utils/serialization/ait_program.py create mode 100644 python/aitemplate/utils/serialization/serdes_code.py create mode 100644 static/csrc/debug_utility.cpp create mode 100644 static/include/debug_utility.h create mode 100644 tests/unittest/backend/test_cuda_graph.py create mode 100644 tests/unittest/backend/test_profiler.py create mode 100644 tests/unittest/benchmark/test_gemm_benchmark.py create mode 100644 tests/unittest/compiler/test_fuse_mm_reshape_permute.py create mode 100644 tests/unittest/compiler/test_fuse_permute_gemm.py create mode 100644 tests/unittest/compiler/test_split_large_concat.py create mode 100644 tests/unittest/compiler/test_tensor.py create mode 100644 tests/unittest/ops/test_conv3d.py create mode 100644 tests/unittest/ops/test_cross_attention.py create mode 100644 tests/unittest/ops/test_depthwise_conv3d.py create mode 100644 tests/unittest/ops/test_dual_gemm.py create mode 100644 tests/unittest/ops/test_gemm_rcr_fast_gelu.py create mode 100644 tests/unittest/ops/test_int_elementwise_dynamic_reshape.py create mode 100644 tests/unittest/test_stable_set.py create mode 100644 tests/unittest/util/test_debug_utils.py create mode 100644 tests/unittest/util/test_serdes.py diff --git a/.gitmodules b/.gitmodules index 2aeb63ba5..a82a39064 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,6 @@ [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass - url = https://github.com/NVIDIA/cutlass.git + url = https://github.com/AITemplate/cutlass.git [submodule "3rdparty/cub"] path = 3rdparty/cub url = https://github.com/NVIDIA/cub.git diff --git a/3rdparty/cutlass b/3rdparty/cutlass index dadc881a9..f434be22a 160000 --- a/3rdparty/cutlass +++ b/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit dadc881a9606f95cba1b20acda03c9d07c286239 +Subproject commit 
f434be22a6270f9f000712286f92545ccca045b7 diff --git a/examples/03_bert/benchmark_mi250.sh b/examples/03_bert/benchmark_mi250.sh index dab4ae50c..4bacb3407 100644 --- a/examples/03_bert/benchmark_mi250.sh +++ b/examples/03_bert/benchmark_mi250.sh @@ -4,8 +4,8 @@ HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 benchmark_ait.py #1GCD -HIP_VISIBLE_DEVICES=0 python3 benchmark_ait.py --batch-size $1 +HIP_VISIBLE_DEVICES=0 python3 benchmark_ait.py --batch-size "$1" #2GCD -HIP_VISIBLE_DEVICES=0 python3 benchmark_ait.py --batch-size $1 & -HIP_VISIBLE_DEVICES=1 python3 benchmark_ait.py --batch-size $1 && fg +HIP_VISIBLE_DEVICES=0 python3 benchmark_ait.py --batch-size "$1" & +HIP_VISIBLE_DEVICES=1 python3 benchmark_ait.py --batch-size "$1" && fg diff --git a/examples/05_stable_diffusion/benchmark_pt.py b/examples/05_stable_diffusion/benchmark_pt.py index 13b8738cc..05c65e9bf 100644 --- a/examples/05_stable_diffusion/benchmark_pt.py +++ b/examples/05_stable_diffusion/benchmark_pt.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# + import click import torch diff --git a/examples/05_stable_diffusion/compile.py b/examples/05_stable_diffusion/compile.py index e91af3bb9..9c87f4155 100644 --- a/examples/05_stable_diffusion/compile.py +++ b/examples/05_stable_diffusion/compile.py @@ -321,7 +321,9 @@ def compile_vae( @click.option("--batch-size", default=1, help="batch size") @click.option("--use-fp16-acc", default=True, help="use fp16 accumulation") @click.option("--convert-conv-to-gemm", default=True, help="convert 1x1 conv to gemm") -def compile_diffusers(token, width, height, batch_size, use_fp16_acc=True, convert_conv_to_gemm=True): +def compile_diffusers( + token, width, height, batch_size, use_fp16_acc=True, convert_conv_to_gemm=True +): logging.getLogger().setLevel(logging.INFO) np.random.seed(0) torch.manual_seed(4896) @@ -344,7 +346,11 @@ def compile_diffusers(token, width, height, batch_size, use_fp16_acc=True, conve hh = height // 8 # CLIP - compile_clip(batch_size=batch_size, use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm) + compile_clip( + batch_size=batch_size, + use_fp16_acc=use_fp16_acc, + convert_conv_to_gemm=convert_conv_to_gemm, + ) # UNet compile_unet( batch_size=batch_size * 2, @@ -354,7 +360,13 @@ def compile_diffusers(token, width, height, batch_size, use_fp16_acc=True, conve convert_conv_to_gemm=convert_conv_to_gemm, ) # VAE - compile_vae(batch_size=batch_size, width=ww, height=hh, use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm) + compile_vae( + batch_size=batch_size, + width=ww, + height=hh, + use_fp16_acc=use_fp16_acc, + convert_conv_to_gemm=convert_conv_to_gemm, + ) if __name__ == "__main__": diff --git a/examples/05_stable_diffusion/demo_img2img.py b/examples/05_stable_diffusion/demo_img2img.py index 844aac726..569a713ed 100644 --- a/examples/05_stable_diffusion/demo_img2img.py +++ b/examples/05_stable_diffusion/demo_img2img.py @@ -17,9 +17,9 @@ import click import requests import torch -from PIL import Image from 
aitemplate.testing.benchmark_pt import benchmark_torch_function +from PIL import Image from pipeline_stable_diffusion_img2img_ait import StableDiffusionImg2ImgAITPipeline diff --git a/examples/05_stable_diffusion/modeling/attention.py b/examples/05_stable_diffusion/modeling/attention.py index efabc3c0c..14993e6d9 100644 --- a/examples/05_stable_diffusion/modeling/attention.py +++ b/examples/05_stable_diffusion/modeling/attention.py @@ -69,6 +69,7 @@ def __init__( self.num_heads, qkv_bias=True, has_residual=True, + use_mem_eff=True, ) self.rescale_output_factor = rescale_output_factor diff --git a/examples/05_stable_diffusion/modeling/clip.py b/examples/05_stable_diffusion/modeling/clip.py index c66ecfb90..f9687d64a 100644 --- a/examples/05_stable_diffusion/modeling/clip.py +++ b/examples/05_stable_diffusion/modeling/clip.py @@ -85,14 +85,12 @@ def forward(self, x, context=None, mask=None, residual=None): ) if USE_CUDA: - q = q * self.scale - attn = ops.bmm_rcr()( - (ops.reshape()(q, [bs * nheads, -1, d])), - (ops.reshape()(k, [bs * nheads, -1, d])), + attn_op = ops.mem_eff_attention(causal=False) + out = attn_op( + (ops.reshape()(q, [bs, nheads, -1, d])), + (ops.reshape()(k, [bs, nheads, -1, d])), + (ops.reshape()(v, [bs, nheads, -1, d])), ) - attn = ops.softmax()(attn, -1) - v = ops.reshape()(v, [bs * nheads, -1, d]) - out = ops.bmm_rrr_permute((nheads,))(attn, v) else: OP = ops.bmm_softmax_bmm_permute(shape=(nheads,), scale=self.scale) out = OP( diff --git a/examples/05_stable_diffusion/pipeline_stable_diffusion_img2img_ait.py b/examples/05_stable_diffusion/pipeline_stable_diffusion_img2img_ait.py index d6c75ab05..251326b55 100644 --- a/examples/05_stable_diffusion/pipeline_stable_diffusion_img2img_ait.py +++ b/examples/05_stable_diffusion/pipeline_stable_diffusion_img2img_ait.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# + +# flakes8: noqa import inspect import os from typing import List, Optional, Union @@ -346,7 +348,7 @@ def __call__( if isinstance(self.scheduler, LMSDiscreteScheduler): sigma = self.scheduler.sigmas[t_index] # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma ** 2 + 1) ** 0.5) + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) latent_model_input = latent_model_input.to(self.unet.dtype) t = t.to(self.unet.dtype) diff --git a/python/aitemplate/__init__.py b/python/aitemplate/__init__.py index ed1d8a72e..9adca1347 100644 --- a/python/aitemplate/__init__.py +++ b/python/aitemplate/__init__.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import logging -import os import sys from . import backend, compiler, frontend, testing, utils @@ -25,18 +23,4 @@ __all__ = ["backend", "compiler", "frontend", "testing", "utils"] -root_logger = logging.getLogger(__name__) -info_handle = logging.StreamHandler() -formatter = logging.Formatter("%(asctime)s %(levelname)s <%(name)s> %(message)s") -info_handle.setFormatter(formatter) -root_logger.addHandler(info_handle) -root_logger.propagate = False - -DEFAULT_LOGLEVEL = logging.getLogger().level -log_level_str = os.environ.get("LOGLEVEL", None) -LOG_LEVEL = ( - getattr(logging, log_level_str.upper()) - if log_level_str is not None - else DEFAULT_LOGLEVEL -) -root_logger.setLevel(LOG_LEVEL) +root_logger = utils.logger.setup_logger(__name__) diff --git a/python/aitemplate/backend/backend_spec.py b/python/aitemplate/backend/backend_spec.py index 44daa1f3c..62fd07ade 100644 --- a/python/aitemplate/backend/backend_spec.py +++ b/python/aitemplate/backend/backend_spec.py @@ -26,8 +26,24 @@ from .target import Target -@dataclass class BackendSpec: + pass + + +@dataclass +class CPUBackendSpec(BackendSpec): + func_enum_to_func_name: Dict[FuncEnum, str] = field( + 
default_factory=lambda: { + FuncEnum.ADD: "+", + FuncEnum.SUB: "-", + FuncEnum.MUL: "*", + FuncEnum.DIV: "/", + } + ) + + +@dataclass +class GPUBackendSpec(BackendSpec): dtype_to_backend_fp16_dtype: Dict[str, str] = field( default_factory=lambda: { "float16": "half", @@ -70,7 +86,6 @@ class BackendSpec: "float", ] ) - func_enum_to_func_name: Dict[FuncEnum, Dict[str, str]] = field( default_factory=lambda: { FuncEnum.ADD: { @@ -174,6 +189,24 @@ class BackendSpec: "half": "hsilu", "float": "fsilu", }, + FuncEnum.POW: { + "half2": "h2pow", + "half": "hpow", + "float": "fpow", + }, + FuncEnum.GELU: { + "half": "hgelu", + "float": "fgelu", + }, + FuncEnum.FASTGELU: { + "half": "h_fast_gelu", + "float": "f_fast_gelu", + }, + FuncEnum.SOFTPLUS: { + "half2": "h2softplus", + "half": "hsoftplus", + "float": "fsoftplus", + }, } ) @@ -183,10 +216,10 @@ def get_backend_type( dtype: str, num_elements_to_backend_type_list: List[Tuple[int, str]], ) -> str: - if dtype != "float16": + if dtype not in ("float16", "float"): raise NotImplementedError("Unsupported dtype {}!".format(dtype)) - for num, backend_type in num_elements_to_backend_type_list: - if num_elements % num == 0: + for alignment, backend_type in num_elements_to_backend_type_list: + if num_elements % alignment == 0: return backend_type raise RuntimeError( "Failed to infer data type! 
num_elements: {}, num_elements_to_backend_type_list: {}".format( @@ -216,9 +249,12 @@ def get_fp16_dtype(self, dtype: str): def dtype_to_backend_type(self, dtype: str): return self.get_dtype_to_dtype(dtype, self.dtype_to_backend_dtype) + def dtype_to_lib_type(self, dtype: str): + raise NotImplementedError + @dataclass -class ROCMSpec(BackendSpec): +class ROCMSpec(GPUBackendSpec): backend_name = "rocm" index_type = "int64_t" prefix = "hip" @@ -250,7 +286,7 @@ def dtype_to_lib_type(self, dtype: str): @dataclass -class CUDASpec(BackendSpec): +class CUDASpec(GPUBackendSpec): backend_name = "cuda" index_type = "int64_t" prefix = "cuda" diff --git a/python/aitemplate/backend/builder.py b/python/aitemplate/backend/builder.py index 80699a79b..bd0b8c4eb 100644 --- a/python/aitemplate/backend/builder.py +++ b/python/aitemplate/backend/builder.py @@ -23,6 +23,8 @@ import os import pathlib import re +import shlex +import subprocess import typing from typing import Optional @@ -35,6 +37,30 @@ # pylint: disable=W0221,C0103 +def _run_make_cmds(cmds, timeout): + logger.debug(__name__, f"make {cmds=}") + proc = subprocess.Popen( + [" && ".join(cmds)], + shell=True, + env=os.environ.copy(), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + try: + out, err = proc.communicate(timeout) + except subprocess.TimeoutExpired as e: + proc.kill() + out, err = proc.communicate() + raise e + finally: + if proc.returncode != 0: + # Let's always print out more info upon any failures. + logger_f = logger.info + else: + logger_f = logger.debug + logger_f(__name__, f"make stdout: {out.decode()}\nmake stderr: {err.decode()}") + + def process_task(task: Task) -> None: """This function extracts stdout and stderr from a finished task. 
If the task process return code is not 0, will mark the task as @@ -156,6 +182,8 @@ def __init__(self, n_jobs: int = -1, timeout: int = 180) -> None: if num_builder is not None: n_jobs = int(num_builder) self._runner = Runner(n_jobs, timeout) + self._n_jobs = n_jobs + self._timeout = timeout def build_objs( self, @@ -250,14 +278,17 @@ def gen_makefile(self, file_pairs, dll_name, workdir, test_name): %.obj : %.bin {{bfile_cmd}} -.PHONY: all +.PHONY: all clean clean_constants all: {{target}} {{target}}: $(obj_files) $(CC) -shared $(fPIC_flag) $(CFLAGS) -o $@ $(obj_files) clean: - rm -f *.obj test.so + rm -f *.obj {{target}} test.so + +clean_constants: + rm -f constants.bin """ ) @@ -293,3 +324,70 @@ def gen_makefile(self, file_pairs, dll_name, workdir, test_name): with open(dumpfile, "w+") as f: # fix the makefile indentation f.write(re.sub("^ ", "\t", makefile_str, flags=re.M)) + + def _gen_makefile_for_profilers(self, file_pairs, profiler_dir): + makefile_template = jinja2.Template( + """ +programs = {{programs}} +all: $(programs) +.PHONY: all clean + +$(programs): %: %.{{cpp}} + {{cc_cmd}} + +clean: + rm -f $(programs) +""" + ) + program_relative_paths = sorted( + {f[1].split(os.path.join(profiler_dir, ""))[-1] for f in file_pairs} + ) + logger.info(__name__, f"compiling {len(program_relative_paths)} profiler srcs") + programs = " ".join(program_relative_paths) + cc_cmd = Target.current().compile_cmd(True).format(target="$@", src="$<") + makefile_str = makefile_template.render( + cpp="cu", + programs=programs, + cc_cmd=cc_cmd, + ) + + dumpfile = os.path.join(profiler_dir, "Makefile") + with open(dumpfile, "w+") as f: + # fix the makefile indentation + f.write(re.sub("^ ", "\t", makefile_str, flags=re.M)) + + def make_profilers(self, generated_profilers, workdir): + file_pairs = [f for gp in generated_profilers for f in gp] + if not file_pairs: + return + build_dir = shlex.quote(os.path.join(workdir, "profiler")) + self._gen_makefile_for_profilers(file_pairs, 
build_dir) + make_path = shlex.quote(Target.current().make()) + make_flags = " ".join( + [ + "--output-sync", + f"-C {build_dir}", + ] + ) + make_clean_cmd = f" {make_path} {make_flags} clean " + make_all_cmd = f" {make_path} {make_flags} -j{self._n_jobs} all " + cmds = [make_clean_cmd, make_all_cmd] + _run_make_cmds(cmds, self._timeout) + + def make(self, file_pairs, dll_name, workdir, test_name): + self.gen_makefile(file_pairs, dll_name, workdir, test_name) + make_path = shlex.quote(Target.current().make()) + build_dir = shlex.quote(os.path.join(workdir, test_name)) + make_flags = " ".join( + [ + "--output-sync", + f"-C {build_dir}", + ] + ) + make_clean_cmd = f" {make_path} {make_flags} clean " + make_all_cmd = f" {make_path} {make_flags} -j{self._n_jobs} all " + make_clean_constants_cmd = f" {make_path} {make_flags} clean_constants " + cmds = [make_clean_cmd, make_all_cmd] + if not logger.is_debug(): + cmds.append(make_clean_constants_cmd) + _run_make_cmds(cmds, self._timeout) diff --git a/python/aitemplate/backend/codegen.py b/python/aitemplate/backend/codegen.py index fcd806882..6ad72b854 100644 --- a/python/aitemplate/backend/codegen.py +++ b/python/aitemplate/backend/codegen.py @@ -28,11 +28,13 @@ from aitemplate.backend.main_templates import MODEL_CONTAINER_TEMPLATE, MODEL_TEMPLATE from aitemplate.compiler.base import Operator +from aitemplate.compiler.dtype import dtype_to_enumerator, get_dtype_size from aitemplate.compiler.tensor_accessor import TensorAccessor from aitemplate.compiler.transform.memory_planning import Workspace +from aitemplate.utils import logger -from ..compiler.base import get_dtype_size, IntImm, IntVar, Tensor +from ..compiler.base import IntImm, IntVar, IntVarTensor, Tensor from . 
import registry from .target import Target @@ -44,6 +46,7 @@ "int": "int32_t*", "int32": "int32_t*", "int64": "int64_t*", + "bool": "bool*", } @@ -61,10 +64,12 @@ def gen_profiler(sorted_graph: list[Tensor], workdir: str, dynamic_profiling_str Pass-through to gen_profiler kernels of nodes in the graph. See also: :func:`~aitemplate.compiler.transform.profile.profile` """ + results = [] for node in sorted_graph: for func in node.src_ops(): if "has_profiler" in func._attrs and func._attrs["has_profiler"]: - func.gen_profiler(workdir, dynamic_profiling_strategy) + results.append(func.gen_profiler(workdir, dynamic_profiling_strategy)) + return results def gen_function_src( @@ -100,6 +105,7 @@ def gen_function_src( with open(src_path, "w") as fo: fo.write(func.gen_function()) exist_func.add(fname) + logger.info(__name__, f"generated {len(file_pairs)} function srcs") return file_pairs @@ -171,22 +177,6 @@ def set_value_from_map(map_name: Any, var_name: Any, indent: str = " ") -> st return f'{indent}{value} = static_cast({map_name}["{key}"]);' -def dtype_to_enumerator(dtype): - def _impl(dtype): - if dtype == "float16": - return "kHalf" - elif dtype == "float32" or dtype == "float": - return "kFloat" - elif dtype == "int32" or dtype == "int": - return "kInt" - elif dtype == "int64": - return "kLong" - else: - raise AssertionError(f"unknown dtype {dtype}") - - return f"AITemplateDtype::{_impl(dtype)}" - - def count_inputs_outputs(graph): n_inputs = n_outputs = 0 for node in graph: @@ -217,7 +207,7 @@ def check_not_null( if tensor_idx is None: check = name else: - check = f"params[{tensor_idx}].ptr" + check = f"params_[{tensor_idx}].ptr" shape = ["1"] lower_bound_is_zero = False @@ -249,7 +239,7 @@ def check_not_null( def device_copy(dst_tensor: Tensor, src_tensor: Tensor, dst_idx: int) -> str: src_name = src_tensor._attrs["name"] - dst_ptr = f"params[{dst_idx}].ptr" + dst_ptr = f"params_[{dst_idx}].ptr" shape = ["1"] for dim in dst_tensor._attrs["shape"]: if isinstance(dim, 
IntImm): @@ -271,10 +261,12 @@ def __init__( num_outputs: int, constants_data_file: io.BytesIO, output_name_to_idx: Dict[str, int], + check_all_nan_and_inf: bool = False, + check_all_outputs: bool = False, ): self.target = Target.current() self.f_var_decl = registry.get(self.target.name() + ".lib.var_decl") - self.f_ptr_decl = registry.get(self.target.name() + ".lib.ptr_decl") + self.f_ptr_decl = registry.get(self.target.name() + ".lib.void_ptr_decl") self.constants_data_file = constants_data_file @@ -321,6 +313,12 @@ def __init__( num_outputs, ) + self.check_all_nan_and_inf = check_all_nan_and_inf + self.check_all_outputs = check_all_outputs + + # This records whether or not we should debug header. + self.debug_header = False + def _tensor_slice_func( self, node: Tensor, @@ -351,7 +349,7 @@ def max_value(var_or_imm): for dim in tensor._attrs["shape"] ) self.set_up_param_dynamic_shapes.append( - set_value(f"params[{idx}].shape_ptrs", f"{{{param_shape_init}}}") + set_value(f"params_[{idx}].shape_ptrs", f"{{{param_shape_init}}}") ) name = tensor._attrs["name"] self.set_up_param_names.append(set_value(f"param_names_[{idx}]", f'"{name}"')) @@ -384,7 +382,7 @@ def _codegen_param_setup( self.owned_constants_init.append(constant_info) self.constants_data_size += num_bytes self.num_constants += 1 - else: + elif not isinstance(tensor, IntVarTensor): # Unbound constant. We will expect the user to set this via SetConstant. 
self.set_up_constant_names.append( set_value( @@ -393,7 +391,8 @@ def _codegen_param_setup( ) ) self._record_param_tensor_info( - tensor, self.unbound_constant_idx + self.num_inputs + self.num_outputs + tensor, + self.unbound_constant_idx + self.num_inputs + self.num_outputs, ) self.unbound_constant_idx += 1 self.set_inputs.append(check_not_null(tensor)) @@ -413,7 +412,7 @@ def _codegen_input_tensor(self, tensor: Tensor) -> None: self.set_inputs.append( set_value( name, - f"static_cast(params[{self.input_idx}].ptr)", + f"static_cast(params_[{self.input_idx}].ptr)", ) ) self.set_inputs.append(check_not_null(tensor)) @@ -444,7 +443,7 @@ def _codegen_output_aliases_tensor(self, tensor: Tensor) -> None: self.set_inputs.append( set_value( name, - f"static_cast(params[{ptr_idx}].ptr)", + f"static_cast(params_[{ptr_idx}].ptr)", ) ) @@ -488,7 +487,7 @@ def _codegen_output_tensor(self, tensor: Tensor) -> None: self.set_inputs.append( set_value( name, - f"static_cast(params[{self.input_idx}].ptr)", + f"static_cast(params_[{self.input_idx}].ptr)", ) ) self._record_param_tensor_info(tensor, self.input_idx) @@ -525,6 +524,9 @@ def _process_dims_for_op(self, node: Operator) -> None: def _process_src_ops(self, node: Tensor) -> None: funcs = node.src_ops() + if len(funcs) == 0: + return + for func in funcs: f_func_decl = registry.get( ".".join((self.target.name(), func._attrs["op"], "func_decl")) @@ -550,12 +552,36 @@ def _process_src_ops(self, node: Tensor) -> None: self.state_record.add(func._attrs["name"]) self._process_dims_for_op(func) + if self.check_all_nan_and_inf or node._attrs.get("check_nan_and_inf", False): + self._append_check_nan_and_inf(node) + if self.check_all_outputs or node._attrs.get("check_outputs", False): + self._append_check_outputs(node) + + def _append_check_nan_and_inf(self, node: Tensor): + self.debug_header = True + tensor_name = node._attrs["name"] + elem_cnt = "*".join([shape.pseudo_code() for shape in node.shape()]) + self.func_seq.append( + f' 
InvokeInfAndNanChecker(reinterpret_cast({tensor_name}), "{tensor_name}", {elem_cnt}, stream);\n' + ) + + def _append_check_outputs(self, node: Tensor): + self.debug_header = True + tensor_name = node._attrs["name"] + elem_cnt = "*".join([shape.pseudo_code() for shape in node.shape()]) + self.func_seq.append( + f' InvokeOutputsChecker(reinterpret_cast({tensor_name}), "{tensor_name}", {elem_cnt}, stream);\n' + ) + def append_tensor(self, node: Tensor) -> None: if node._attrs["nop"]: return name = node._attrs["name"] dtype = node._attrs["dtype"] - self.tensor_decl.append(self.f_ptr_decl(name=name, dtype=dtype)) + if isinstance(node, IntVarTensor): + self.tensor_decl.append(self.f_var_decl(name=name)) + else: + self.tensor_decl.append(self.f_ptr_decl(name=name, dtype=dtype)) is_param = node._attrs["is_param"] is_output = node._attrs["is_output"] @@ -576,14 +602,14 @@ def append_tensor(self, node: Tensor) -> None: elif has_output_aliases: # Special case: internal tensor that aliases an output. 
self._codegen_output_aliases_tensor(node) - elif not is_view: + elif not is_view and not isinstance(node, IntVarTensor): # Normal, internal tensor that is not a view: point it to the # internal blob of memory assert ( node._attrs["offset"] >= 0 ), f"Non-parameter node '{name}' must have non-negative offset" self.tensor_slice.append(self._tensor_slice_func(node, "blob_ptr")) - else: + elif not isinstance(node, IntVarTensor): # Normal view, point it to the same memory as whatever it # aliases self.set_inputs.append(set_value(name, view._attrs["name"])) @@ -621,6 +647,7 @@ def generate_source(self) -> Dict[str, str]: function_state="\n".join(self.function_state), target_has_graph_mode=target_has_graph_mode, unique_workspace_size=self.workspace.unique_size, + debug_header=self.debug_header, ) result["model-generated.h"] = model_def @@ -678,6 +705,8 @@ def gen_library_src( # noqa: C901 workdir: str, output_tensors: List[Tensor], model_name: str = "", + check_all_nan_and_inf: bool = False, + check_all_outputs: bool = False, ) -> list[Tuple[str, str]]: """Generate model driver source code files for the given graph @@ -722,6 +751,8 @@ def to_obj_name(name: str): num_outputs, constants_data_file, output_name_to_index, + check_all_nan_and_inf, + check_all_outputs, ) for node in sorted_graph: model_container_generator.append_tensor(node) @@ -741,4 +772,5 @@ def to_obj_name(name: str): for fname in sources: to_build.append((fname, to_obj_name(fname))) + logger.info(__name__, f"generated {len(to_build)} library srcs") return to_build diff --git a/python/aitemplate/backend/common/concatenate_common.py b/python/aitemplate/backend/common/concatenate_common.py index 99f24bb03..001afe0ac 100644 --- a/python/aitemplate/backend/common/concatenate_common.py +++ b/python/aitemplate/backend/common/concatenate_common.py @@ -22,9 +22,9 @@ FUNC_DECL_TEMPLATE = jinja2.Template( """ void {{func_name}}( - {{elem_output_type}} * /*output*/, + void * /*output*/, {{index_type}} *[] 
/*output_shape*/, - const {{elem_input_type}} *[] /*inputs*/, + const void *[] /*inputs*/, const {{index_type}} *[], /* real_input_shapes, representing shapes of those inputs whose masks are False, i.e. inputs that will be copied to the output tensor by concat.*/ @@ -161,7 +161,7 @@ constexpr unsigned read_t_sz = sizeof(READ_T); constexpr unsigned elem_t_sz = sizeof(ELEM_T); - assert(read_t_sz >= elem_t_sz && (read_t_sz % elem_t_sz == 0)); + static_assert(read_t_sz >= elem_t_sz && (read_t_sz % elem_t_sz == 0)); constexpr INDEX_T n_of_elem_t = read_t_sz / elem_t_sz; // number of READ_T elements per thread INDEX_T reads_per_thread_in_read_t = ElemsPerThread / n_of_elem_t; @@ -225,9 +225,9 @@ template void concatenate_kernel_launcher( - ELEM_T *output, + void *output, const {{index_type}} *output_shape, - const ELEM_T *inputs[], + const void *inputs[], const {{index_type}} *real_input_shapes[], const TensorAccessor *input_accessors[], const int64_t concat_dim_offsets[], @@ -248,7 +248,7 @@ INDEX_T max_num_input_elems = 0; for (INDEX_T i = 0; i < NumInputs; i++) { INDEX_T num_elems = get_num_elems(real_input_shapes[i], Rank); - input_meta.inputs[i] = inputs[i]; + input_meta.inputs[i] = static_cast(inputs[i]); input_meta.input_accessors[i] = *(input_accessors[i]); input_meta.concat_dim_offsets[i] = concat_dim_offsets[i]; input_meta.concat_dim_values[i] = real_input_shapes[i][concat_dim]; @@ -272,7 +272,7 @@ } \\ concatenate_kernel \\ <<>>( \\ - output, \\ + static_cast(output), \\ output_meta, \\ input_meta, \\ concat_dim, \\ @@ -309,9 +309,9 @@ {{header_src}} void {{func_name}}( - {{elem_output_type}} *output, + void *output, {{index_type}} *output_shape[], - const {{elem_input_type}} *inputs[], + const void *inputs[], const {{index_type}} *real_input_shapes[], const {{index_type}} *all_input_shapes[], const bool input_masks[], @@ -322,6 +322,7 @@ {{index_type}} num_all_inputs, {{prefix}}Stream_t stream ) { + // DO NOTHING } """ ) @@ -406,9 +407,9 @@ {{kernel_src}} 
void {{func_name}}( - {{elem_output_type}} *output, + void *output, {{index_type}} *output_shape[], - const {{elem_input_type}} *inputs[], + const void *inputs[], const {{index_type}} *real_input_shapes[], /* real_input_shapes, representing shapes of those inputs whose masks are False, i.e. inputs that will be copied to the output @@ -520,7 +521,7 @@ """ {{indent}}{ -{{indent}} const {{input_elem_type}} *inputs[] = { +{{indent}} const void *inputs[] = { {{indent}} {{inputs}} {{indent}} }; @@ -579,16 +580,8 @@ def gen_function_decl(func_attrs, backend_spec): str Rendered function declaration. """ - # get dtype from orig_x in case actual "inputs" is turned into empty - # by some transformation - orig_x = func_attrs["original_inputs"][0] - y = func_attrs["outputs"][0] - input_type = backend_spec.dtype_to_backend_type(orig_x._attrs["dtype"]) - output_type = backend_spec.dtype_to_backend_type(y._attrs["dtype"]) return FUNC_DECL_TEMPLATE.render( func_name=func_attrs["name"], - elem_output_type=output_type, - elem_input_type=input_type, index_type=backend_spec.index_type, prefix=backend_spec.prefix, ) @@ -691,8 +684,6 @@ def _stride(shape, dim): return SRC_TEMPLATE.render( kernel_src=kernel_src, func_name=func_attrs["name"], - elem_input_type=input_type, - elem_output_type=output_type, exec_paths=exec_paths, index_type=backend_spec.index_type, prefix=backend_spec.prefix, @@ -700,8 +691,6 @@ def _stride(shape, dim): return DUMMY_KERNEL_TEMPLATE.render( func_name=func_attrs["name"], - elem_input_type=input_type, - elem_output_type=output_type, header_src=header_src, index_type=backend_spec.index_type, prefix=backend_spec.prefix, @@ -719,14 +708,8 @@ def gen_function_call( ---------- func_attrs : Dict[str, Any] Stores the operation attributes. - index_type: str - Index type. - cast_to_const_half_ptr_template: jinja template - Cast to const half ptr template. - cast_to_half_ptr_template: jinja template - Cast to half ptr template. 
- dtype_to_backend_type: Dict[str, str] - Stores python dtype to backend (rocm, cuda) type. + backend_spec : BackendSpec + CUDA / RocM type definitions indent : str, optional Indent for template, by default " ". @@ -746,12 +729,7 @@ def gen_function_call( y = func_attrs["outputs"][0] concat_dim = func_attrs["concat_dim"] - input_names = ",\n ".join( - [ - backend_spec.cast_to_const_half_ptr_template.render(name=i._attrs["name"]) - for i in inputs - ] - ) + input_names = ",\n ".join([i._attrs["name"] for i in inputs]) real_input_shape_defs = [] real_input_shape_names = [] for idx, (i, input_accessor) in enumerate(zip(inputs, input_accessors)): @@ -769,7 +747,6 @@ def gen_function_call( y_shape = y._attrs["shape"] y_dim_refs = ", ".join(["&" + dim._attrs["name"] for dim in y_shape]) - casted_y_ptr = backend_spec.cast_to_half_ptr_template.render(name=y._attrs["name"]) input_masks = func_attrs["input_masks"] input_indices = [idx for idx, m in enumerate(input_masks) if m is True] @@ -819,7 +796,6 @@ def gen_function_call( return FUNC_CALL_TEMPLATE.render( indent=indent, - input_elem_type=backend_spec.dtype_to_backend_type(orig_x._attrs["dtype"]), inputs=input_names, real_input_shape_defs="".join(real_input_shape_defs), real_input_shapes=", ".join(real_input_shape_names), @@ -830,7 +806,7 @@ def gen_function_call( output_dim_refs=y_dim_refs, func_name=func_attrs["name"], output=y._attrs["name"], - output_ptr=casted_y_ptr, + output_ptr=y._attrs["name"], concat_dim=concat_dim, rank=len(orig_x._attrs["shape"]), num_real_inputs=len(inputs), diff --git a/python/aitemplate/backend/common/elementwise_common.py b/python/aitemplate/backend/common/elementwise_common.py index 14872058a..546763505 100644 --- a/python/aitemplate/backend/common/elementwise_common.py +++ b/python/aitemplate/backend/common/elementwise_common.py @@ -20,11 +20,11 @@ from typing import Any, Dict, List, Tuple import jinja2 +from aitemplate.backend.backend_spec import BackendSpec from ...compiler.base import 
IntImm, IntVar, Operator, Tensor from ...compiler.tensor_accessor import TensorAccessor from ...utils import shape_utils -from ..backend_spec import BackendSpec from . import tensor_accessor_codegen CONSTANT_TEMPLATE = jinja2.Template( @@ -96,11 +96,11 @@ KERNEL_TEMPLATE = jinja2.Template( """ __global__ void -{{func_name}}({{output_params}}, {{input_params}}, {{dynamic_dims}} int n_elements) { +{{func_name}}({{output_params}}, {{input_params}}, {{dynamic_dims}} {{index_type}} n_elements) { const int bid = blockIdx.x; const int tid = threadIdx.x; - const int idx = bid * FUSED_ELE_THREAD_SIZE + tid; - const int idx_elem = idx * N_ELEMENTS_PER_THREAD; + const {{index_type}} idx = bid * FUSED_ELE_THREAD_SIZE + tid; + const {{index_type}} idx_elem = idx * N_ELEMENTS_PER_THREAD; if (idx_elem >= n_elements) { return; } @@ -115,8 +115,8 @@ """ ) -FUNC_DECL_INPUT_PARAM_TEMPLATE = jinja2.Template("const {{data_t}}* input{{idx}}") -FUNC_DECL_OUTPUT_PARAM_TEMPLATE = jinja2.Template("{{data_t}}* output{{idx}}") +FUNC_DECL_INPUT_PARAM_TEMPLATE = jinja2.Template("const void* input{{idx}}") +FUNC_DECL_OUTPUT_PARAM_TEMPLATE = jinja2.Template("void* output{{idx}}") KERNEL_CALL_INPUT_PARAM_TEMPLATE = jinja2.Template( "reinterpret_cast(input{{idx}})" ) @@ -140,7 +140,7 @@ } // namespace -void invoke_{{func_name}}({{output_params}}, {{input_params}}, {{dynamic_dims_decl}} int n_elements, {{prefix}}Stream_t stream) { +void invoke_{{func_name}}({{output_params}}, {{input_params}}, {{dynamic_dims_decl}} {{index_type}} n_elements, {{prefix}}Stream_t stream) { if (n_elements == 0) { return; } @@ -157,14 +157,14 @@ FUNC_DECL_TEMPLATE = jinja2.Template( """ -void invoke_{{func_name}}({{output_params}}, {{input_params}}, {{dynamic_dims}} int n_elements, {{prefix}}Stream_t stream); +void invoke_{{func_name}}({{output_params}}, {{input_params}}, {{dynamic_dims}} {{index_type}} n_elements, {{prefix}}Stream_t stream); """ ) FUNC_CALL_TEMPLATE = jinja2.Template( """ {{indent}}{ - {{indent}}int 
{{func_name}}_n_elements = {{calculate_n}}; + {{indent}}{{index_type}} {{func_name}}_n_elements = {{calculate_n}}; {{indent}}invoke_{{func_name}}({{output_params}}, {{input_params}}, {{dynamic_dims}} {{func_name}}_n_elements, {{stream}}); {{indent}}} """ @@ -357,14 +357,10 @@ def _get_types_and_sizes( # Handle input broadcast. output_shape = output_accessors[0].original_shapes - dtype = "float16" + dtype = inputs[0]._attrs["dtype"] input_broadcast_sizes = [] min_num_elements = None - for input_tensor, input_accessor in zip(inputs, input_accessors): - if input_tensor._attrs["dtype"] != "float16": - raise NotImplementedError( - "Unsupported dtype {}!".format(input_tensor._attrs["dtype"]) - ) + for input_accessor in input_accessors: input_shape = input_accessor.original_shapes broadcastable, _ = shape_utils.get_broadcast_max_shape( output_shape, input_shape @@ -433,7 +429,7 @@ def _parse_func_metadata( op_type = backend_spec.get_backend_type( alignment, dtype, backend_spec.op_num_elements_to_backend_type ) - data_type = backend_spec.get_fp16_dtype(dtype) + data_type = backend_spec.dtype_to_backend_type(dtype) sub_func_metadata, op_type = _get_sub_func_metadata( ops, data_type, op_type, backend_spec ) @@ -645,6 +641,7 @@ def _gen_kernel_function( kernel_func = KERNEL_TEMPLATE.render( func_name=func_attrs["name"], + index_type=index_type, output_params=output_params_decl, input_params=input_params_decl, dynamic_dims=_gen_dynamic_dim_str( @@ -699,17 +696,13 @@ def fused_elementwise_gen_function( ) output_params_decl = ",".join( [ - FUNC_DECL_OUTPUT_PARAM_TEMPLATE.render( - data_t=fused_elementwise_metadata.data_t, idx=i - ) + FUNC_DECL_OUTPUT_PARAM_TEMPLATE.render(idx=i) for i, _ in enumerate(fused_elementwise_metadata.outputs) ] ) input_params_decl = ",".join( [ - FUNC_DECL_INPUT_PARAM_TEMPLATE.render( - data_t=fused_elementwise_metadata.data_t, idx=i - ) + FUNC_DECL_INPUT_PARAM_TEMPLATE.render(idx=i) for i, _ in enumerate(fused_elementwise_metadata.inputs) ] ) @@ 
-737,6 +730,7 @@ def fused_elementwise_gen_function( function = FUNC_TEMPLATE.render( prefix=backend_spec.prefix, + index_type=backend_spec.index_type, head=backend_spec.header_src_template.render(extra_header=head_template), constant=constant, custom_libs=custom_libs, @@ -787,23 +781,20 @@ def fused_elementwise_gen_function_decl( ) output_params_decl = ",".join( [ - FUNC_DECL_OUTPUT_PARAM_TEMPLATE.render( - data_t=fused_elementwise_metadata.data_t, idx=i - ) + FUNC_DECL_OUTPUT_PARAM_TEMPLATE.render(idx=i) for i, _ in enumerate(fused_elementwise_metadata.outputs) ] ) input_params_decl = ",".join( [ - FUNC_DECL_INPUT_PARAM_TEMPLATE.render( - data_t=fused_elementwise_metadata.data_t, idx=i - ) + FUNC_DECL_INPUT_PARAM_TEMPLATE.render(idx=i) for i, _ in enumerate(fused_elementwise_metadata.inputs) ] ) function_decl = FUNC_DECL_TEMPLATE.render( prefix=backend_spec.prefix, + index_type=backend_spec.index_type, func_name=func_name, output_params=output_params_decl, input_params=input_params_decl, @@ -840,27 +831,9 @@ def fused_elementwise_gen_function_call( backend_spec, ) - output_params_vec = [] - for output in outputs: - if output._attrs["dtype"] != "float16": - raise NotImplementedError( - "Unsupported dtype {}".format(output._attrs["dtype"]) - ) - output_params_vec.append( - backend_spec.cast_to_half_ptr_template.render(name=output._attrs["name"]) - ) - output_params = ",".join(output_params_vec) + output_params = ",".join([output._attrs["name"] for output in outputs]) - input_params_vec = [] - for inp in inputs: - if inp._attrs["dtype"] != "float16": - raise NotImplementedError( - "Unsupported dtype {}".format(inp._attrs["dtype"]) - ) - input_params_vec.append( - backend_spec.cast_to_half_ptr_template.render(name=inp._attrs["name"]) - ) - input_params = ",".join(input_params_vec) + input_params = ",".join([input._attrs["name"] for input in inputs]) num_elements_calculator = _gen_int_var_product_str( output_accessors[0].original_shapes @@ -869,6 +842,7 @@ def 
fused_elementwise_gen_function_call( return FUNC_CALL_TEMPLATE.render( stream=backend_spec.stream, func_name=func_attrs["name"], + index_type=backend_spec.index_type, calculate_n=num_elements_calculator, output_params=output_params, input_params=input_params, diff --git a/python/aitemplate/backend/common/split_common.py b/python/aitemplate/backend/common/split_common.py index 9205c90ee..a1dbaa930 100644 --- a/python/aitemplate/backend/common/split_common.py +++ b/python/aitemplate/backend/common/split_common.py @@ -20,9 +20,9 @@ FUNC_DECL_TEMPLATE = jinja2.Template( """ void {{func_name}}( - {{elem_output_type}} *[] /*outputs*/, + void *[] /*outputs*/, {{index_type}} **[] /*output_shapes*/, - const {{elem_input_type}} * /*input*/, + const void * /*input*/, const {{index_type}} * /*input_shape*/, {{index_type}} /*num_splits*/, {{index_type}} [] /*split_sizes*/, @@ -127,9 +127,9 @@ int64_t split_dim_size = output_meta.split_dim_sizes[blockIdx.y]; int64_t input_offset = output_offset * input_split_dim_stride; - unsigned read_t_sz = sizeof(READ_T); - unsigned elem_t_sz = sizeof(ELEM_T); - assert(read_t_sz >= elem_t_sz && (read_t_sz % elem_t_sz == 0)); + unsigned constexpr read_t_sz = sizeof(READ_T); + unsigned constexpr elem_t_sz = sizeof(ELEM_T); + static_assert(read_t_sz >= elem_t_sz && (read_t_sz % elem_t_sz == 0)); {{index_type}} n_of_elem_t = read_t_sz / elem_t_sz; // number of READ_T elements per thread {{index_type}} reads_per_thread_in_read_t = ElemsPerThread / n_of_elem_t; @@ -196,9 +196,9 @@ template void split_kernel_launcher( - ELEM_T *outputs[], + void *outputs[], {{index_type}} *output_shapes[], - const ELEM_T *input, + const void *input, const {{index_type}} *input_shape, const {{index_type}} split_dim, {{prefix}}Stream_t stream @@ -217,7 +217,7 @@ {{index_type}} offset = 0; LoadVecType min_vec_type = LoadVecType::VT_FLOAT4; for ({{index_type}} i = 0; i < NumSplits; i++) { - output_meta.outputs[i] = outputs[i]; + output_meta.outputs[i] = 
static_cast(outputs[i]); output_meta.split_dim_offsets[i] = offset; output_meta.split_dim_sizes[i] = output_shapes[i][split_dim]; output_meta.num_elems[i] = get_num_elems(output_shapes[i], Rank); @@ -246,7 +246,7 @@ } \\ split_kernel \\ <<>>( \\ - input, \\ + static_cast(input), \\ input_meta, \\ output_meta, \\ split_dim, \\ @@ -309,9 +309,9 @@ """ {{kernel_src}} void {{func_name}}( - {{elem_output_type}}* outputs[], + void* outputs[], {{index_type}} **output_shapes[], - const {{elem_input_type}}* input, + const void* input, const {{index_type}} *input_shape, {{index_type}} num_splits, {{index_type}} split_sizes[], @@ -390,7 +390,7 @@ """ {{indent}}{ -{{indent}} {{output_elem_type}} *outputs[] = { +{{indent}} void *outputs[] = { {{indent}} {{outputs}} {{indent}} }; @@ -431,21 +431,17 @@ def gen_function_decl(func_attrs, backend_spec): ---------- func_attrs : Dict[str, Any] Stores the operation attributes. + backend_spec : BackendSpec + Cuda/Rocm type definitions Returns ------- str Rendered function declaration. 
""" - x = func_attrs["inputs"][0] - y = func_attrs["outputs"][0] - input_type = backend_spec.dtype_to_backend_type(x._attrs["dtype"]) - output_type = backend_spec.dtype_to_backend_type(y._attrs["dtype"]) return FUNC_DECL_TEMPLATE.render( index_type=backend_spec.index_type, prefix=backend_spec.prefix, func_name=func_attrs["name"], - elem_output_type=output_type, - elem_input_type=input_type, ) @@ -470,6 +466,9 @@ def gen_function(func_attrs, backend_spec): input_type = backend_spec.dtype_to_backend_type(x._attrs["dtype"]) output_type = backend_spec.dtype_to_backend_type(y._attrs["dtype"]) + if input_type != output_type: + raise NotImplementedError("input type must equal to output type") + # TODO: consider to add profiling paths for tuning # elems_per_thread and threads_per_block exec_paths = EXEC_COND_TEMPLATE.render( @@ -490,8 +489,6 @@ def gen_function(func_attrs, backend_spec): return SRC_TEMPLATE.render( kernel_src=kernel_src, func_name=func_attrs["name"], - elem_input_type=input_type, - elem_output_type=output_type, exec_paths=exec_paths, index_type=backend_spec.index_type, prefix=backend_spec.prefix, @@ -515,16 +512,10 @@ def gen_function_call(func_attrs, backend_spec, indent=" "): """ x = func_attrs["inputs"][0] outputs = func_attrs["outputs"] - y = outputs[0] split_dim = func_attrs["split_dim"] num_splits = len(func_attrs["split_sizes"]) - output_names = ",\n ".join( - [ - backend_spec.cast_to_half_ptr_template.render(name=i._attrs["name"]) - for i in outputs - ] - ) + output_names = ",\n ".join([i._attrs["name"] for i in outputs]) output_shape_defs = [] output_shape_names = [] @@ -545,22 +536,18 @@ def gen_function_call(func_attrs, backend_spec, indent=" "): x_shape = x._attrs["shape"] x_dims = ", ".join([dim._attrs["name"] for dim in x_shape]) - casted_x_ptr = backend_spec.cast_to_const_half_ptr_template.render( - name=x._attrs["name"] - ) split_sizes = ", ".join([str(i) for i in func_attrs["split_sizes"]]) return FUNC_CALL_TEMPLATE.render( indent=indent, 
- output_elem_type=backend_spec.dtype_to_backend_type(y._attrs["dtype"]), outputs=output_names, output_shape_defs="".join(output_shape_defs), output_shapes=", ".join(output_shape_names), input_dims=x_dims, func_name=func_attrs["name"], input_name=x._attrs["name"], - input_ptr=casted_x_ptr, + input_ptr=x._attrs["name"], split_dim=split_dim, rank=len(x._attrs["shape"]), num_splits=num_splits, diff --git a/python/aitemplate/backend/common/tensor/argmax_common.py b/python/aitemplate/backend/common/tensor/argmax_common.py index bb422646e..67c3d4b94 100644 --- a/python/aitemplate/backend/common/tensor/argmax_common.py +++ b/python/aitemplate/backend/common/tensor/argmax_common.py @@ -21,9 +21,6 @@ import jinja2 -from ... import builder -from ...target import Target - # pylint: disable=C0301 FUNC_CALL_FP16_PARAM_TEMPLATE = jinja2.Template( @@ -285,7 +282,7 @@ class MultiplyFunctor final { {{indent}} {{elem_cnt}}, {{indent}} {{instance_size}}, {{indent}} {{instance_num}}, -{{indent}} global_workspace, stream /* default stream */ +{{indent}} global_workspace_, stream /* default stream */ {{indent}}); """ ) @@ -450,7 +447,4 @@ def gen_profiler( ) op_name = func_attrs["op"] add_profiler(file_pairs, workdir, op_type, op_name, code) - # build - target = Target.current() - compile_engine = builder.Builder() - compile_engine.build_objs(file_pairs, target.compile_cmd(executable=True)) + return file_pairs diff --git a/python/aitemplate/backend/common/tensor/batch_gather_common.py b/python/aitemplate/backend/common/tensor/batch_gather_common.py index 86bbea7a0..97e8aee77 100644 --- a/python/aitemplate/backend/common/tensor/batch_gather_common.py +++ b/python/aitemplate/backend/common/tensor/batch_gather_common.py @@ -22,11 +22,6 @@ # pylint: disable=C0301 -FUNC_CALL_FP16_PARAM_TEMPLATE = jinja2.Template( - """reinterpret_cast( - {% if is_cuda %}&({% endif %}{{name}}{% if is_cuda %}->raw()){% endif %})""" -) - FUNC_CALL_INT64_PARAM_TEMPLATE = 
jinja2.Template("reinterpret_cast({{name}})") FUNC_TEMPLATE = jinja2.Template( @@ -41,15 +36,15 @@ {{func_signature}} { - batch_gather_launcher(stream, batch_num, indices_num, instance_size, gather_dim_size, input, indices, workspace, output); + batch_gather_launcher<{{dtype}}, int64_t>(stream, batch_num, indices_num, instance_size, gather_dim_size, static_cast(input), indices, workspace, static_cast<{{dtype}}*>(output)); } """ ) FUNC_SIGNATURE = jinja2.Template( """ -void {{func_name}}(half* output, - const half* input, +void {{func_name}}(void* output, + const void* input, const int64_t* indices, const {{index_type}} batch_num, const {{index_type}} indices_num, @@ -74,7 +69,7 @@ {{indent}} {{indices_num}}, {{indent}} {{instance_size}}, {{indent}} {{gather_dim_size}}, -{{indent}} global_workspace, stream /* default stream */ +{{indent}} global_workspace_, stream /* default stream */ {{indent}}); """ ) @@ -156,12 +151,10 @@ def gen_function_call(func_attrs: Dict[str, Any], indent=" ", is_cuda=False) -> assert len(func_attrs["outputs"]) == 1 assert len(func_attrs["inputs"]) == 2 - output_name = FUNC_CALL_FP16_PARAM_TEMPLATE.render( - name=func_attrs["outputs"][0]._attrs["name"], is_cuda=is_cuda - ) - input_name = FUNC_CALL_FP16_PARAM_TEMPLATE.render( - name=func_attrs["inputs"][0]._attrs["name"], is_cuda=is_cuda - ) + output_name = func_attrs["outputs"][0]._attrs["name"] + + input_name = func_attrs["inputs"][0]._attrs["name"] + indices_name = FUNC_CALL_INT64_PARAM_TEMPLATE.render( name=func_attrs["inputs"][1]._attrs["name"] ) @@ -208,6 +201,9 @@ def gen_function(func_attrs: Dict[str, Any], header_files: str, backend_spec) -> func_signature=FUNC_SIGNATURE.render( func_name=func_attrs["name"], index_type=index_type, prefix=prefix ), + dtype=backend_spec.dtype_to_backend_dtype[ + func_attrs["inputs"][0]._attrs["dtype"] + ], ) diff --git a/python/aitemplate/backend/common/tensor/permute021_common.py b/python/aitemplate/backend/common/tensor/permute021_common.py index 
db5ed63fd..30ab97b80 100644 --- a/python/aitemplate/backend/common/tensor/permute021_common.py +++ b/python/aitemplate/backend/common/tensor/permute021_common.py @@ -28,15 +28,15 @@ FUNC_DECL_TEMPLATE = jinja2.Template( """ void {{func_name}}( - {{lib_dtype}}*, - {{lib_dtype}}*, - int64_t*, - int64_t*, - int64_t*, - int64_t*, - int64_t*, - int64_t*, - {{prefix}}Stream_t + const void* /*input*/, + void* /* output */, + int64_t* /* x_dim0 */, + int64_t* /* x_dim1 */, + int64_t* /* x_dim2 */, + int64_t* /* y_dim0 */, + int64_t* /* y_dim1 */, + int64_t* /* y_dim2 */, + {{prefix}}Stream_t /* stream */ ); """ ) @@ -44,8 +44,8 @@ FUNC_CALL_TEMPLATE = jinja2.Template( """ {{indent}}{{func_name}}( -{{indent}} ({{lib_dtype}}*)({{in_ptr}}), -{{indent}} ({{lib_dtype}}*)({{out_ptr}}), +{{indent}} {{in_ptr}}, +{{indent}} {{out_ptr}}, {{indent}} {{x_dim0}}, {{indent}} {{x_dim1}}, {{indent}} {{x_dim2}}, @@ -138,8 +138,8 @@ } } -void permute021_launcher({{lib_dtype}}* in_ptr, - {{lib_dtype}}* out_ptr, +void permute021_launcher(const void* in_ptr, + void* out_ptr, int x_dim0, int x_dim1, int x_dim2, @@ -151,8 +151,8 @@ dim3 grid((c + 31)/32, (h*w + 31)/32, n); dim3 block(32, 8); nhwc_to_nchw_kernel<{{lib_dtype}}><<>>( - ({{lib_dtype}}*)out_ptr, - (const {{lib_dtype}}*)in_ptr, + static_cast<{{lib_dtype}}*>(out_ptr), + static_cast(in_ptr), n, h, w, @@ -162,8 +162,8 @@ } // namespace void {{function_name}} ( - {{lib_dtype}}* in_ptr, - {{lib_dtype}}* out_ptr, + const void* in_ptr, + void* out_ptr, int64_t* x_dim0, int64_t* x_dim1, int64_t* x_dim2, @@ -258,11 +258,8 @@ def gen_function_decl(func_attrs: Dict[str, Any], backend_spec) -> str: """ func_name = func_attrs["name"] - x = func_attrs["inputs"][0] - xdtype = x._attrs["dtype"] return FUNC_DECL_TEMPLATE.render( func_name=func_name, - lib_dtype=backend_spec.dtype_to_lib_type(xdtype), prefix=backend_spec.prefix, ) @@ -286,7 +283,6 @@ def gen_function_call(func_attrs: Dict[str, Any], backend_spec, indent=" ") -> x = 
func_attrs["inputs"][0] xshape = x._attrs["shape"] - xdtype = x._attrs["dtype"] y = func_attrs["outputs"][0] yshape = y._attrs["shape"] return FUNC_CALL_TEMPLATE.render( @@ -300,5 +296,4 @@ def gen_function_call(func_attrs: Dict[str, Any], backend_spec, indent=" ") -> y_dim1="&" + yshape[1]._attrs["name"], y_dim2="&" + yshape[2]._attrs["name"], indent=indent, - lib_dtype=backend_spec.dtype_to_lib_type(xdtype), ) diff --git a/python/aitemplate/backend/common/tensor/permute102_common.py b/python/aitemplate/backend/common/tensor/permute102_common.py index 807e65bef..7c367ed8a 100644 --- a/python/aitemplate/backend/common/tensor/permute102_common.py +++ b/python/aitemplate/backend/common/tensor/permute102_common.py @@ -36,14 +36,14 @@ FUNC_DECL_TEMPLATE = jinja2.Template( """ void {{func_name}}( - {{lib_dtype}}*, - {{lib_dtype}}*, - int64_t*, - int64_t*, - int64_t*, - int64_t*, - int64_t*, - int64_t*, + const void* /* input */, + void* /* output */, + int64_t* /* x_dim0 */, + int64_t* /* x_dim1 */, + int64_t* /* x_dim2 */, + int64_t* /* y_dim0 */, + int64_t* /* y_dim1 */, + int64_t* /* y_dim2 */, {{prefix}}Stream_t ); """ @@ -52,8 +52,8 @@ FUNC_CALL_TEMPLATE = jinja2.Template( """ {{indent}}{{func_name}}( -{{indent}} ({{lib_dtype}}*){{in_ptr}}, -{{indent}} ({{lib_dtype}}*){{out_ptr}}, +{{indent}} {{in_ptr}}, +{{indent}} {{out_ptr}}, {{indent}} {{x_dim0}}, {{indent}} {{x_dim1}}, {{indent}} {{x_dim2}}, @@ -149,8 +149,8 @@ } } -void permute102_launcher({{lib_dtype}}* in_ptr, - {{lib_dtype}}* out_ptr, +void permute102_launcher(const void* in_ptr, + void* out_ptr, int x_dim0, int x_dim1, int x_dim2, @@ -162,8 +162,8 @@ dim3 grid((c + TILE_SIZE - 1)/TILE_SIZE, (h*w + TILE_SIZE -1)/TILE_SIZE, n); dim3 block(TILE_SIZE, TILE_SIZE / CH_K); nhwc_to_nchw_kernel<{{lib_dtype}}><<>>( - out_ptr, - (const {{lib_dtype}}*)in_ptr, + static_cast<{{lib_dtype}}*>(out_ptr), + static_cast(in_ptr), n, h, w, @@ -173,8 +173,8 @@ } // namespace void {{function_name}} ( - {{lib_dtype}}* in_ptr, - 
{{lib_dtype}}* out_ptr, + const void* in_ptr, + void* out_ptr, int64_t* x_dim0, int64_t* x_dim1, int64_t* x_dim2, @@ -265,11 +265,8 @@ def gen_function_decl(func_attrs: Dict[str, Any], backend_spec) -> str: Function declaration """ func_name = func_attrs["name"] - x = func_attrs["inputs"][0] - xdtype = x._attrs["dtype"] return FUNC_DECL_TEMPLATE.render( func_name=func_name, - lib_dtype=backend_spec.dtype_to_lib_type(xdtype), prefix=backend_spec.prefix, ) @@ -292,7 +289,6 @@ def gen_function_call(func_attrs: Dict[str, Any], backend_spec, indent=" ") -> """ x = func_attrs["inputs"][0] xshape = x._attrs["shape"] - xdtype = x._attrs["dtype"] y = func_attrs["outputs"][0] yshape = y._attrs["shape"] return FUNC_CALL_TEMPLATE.render( @@ -306,5 +302,4 @@ def gen_function_call(func_attrs: Dict[str, Any], backend_spec, indent=" ") -> y_dim1="&" + yshape[1]._attrs["name"], y_dim2="&" + yshape[2]._attrs["name"], indent=indent, - lib_dtype=backend_spec.dtype_to_lib_type(xdtype), ) diff --git a/python/aitemplate/backend/common/tensor/permute210_common.py b/python/aitemplate/backend/common/tensor/permute210_common.py index fa1d5d25a..35894b315 100644 --- a/python/aitemplate/backend/common/tensor/permute210_common.py +++ b/python/aitemplate/backend/common/tensor/permute210_common.py @@ -35,15 +35,15 @@ FUNC_DECL_TEMPLATE = jinja2.Template( """ void {{func_name}}( - {{lib_dtype}}*, - {{lib_dtype}}*, - int64_t*, - int64_t*, - int64_t*, - int64_t*, - int64_t*, - int64_t*, - {{prefix}}Stream_t + const void* /* input */, + void* /* output */, + int64_t* /* x_dim0 */, + int64_t* /* x_dim1 */, + int64_t* /* x_dim2 */, + int64_t* /* y_dim0 */, + int64_t* /* y_dim1 */, + int64_t* /* y_dim2 */, + {{prefix}}Stream_t /* stream */ ); """ ) @@ -51,8 +51,8 @@ FUNC_CALL_TEMPLATE = jinja2.Template( """ {{indent}}{{func_name}}( -{{indent}} static_cast<{{lib_dtype}}*>({{in_ptr}}), -{{indent}} static_cast<{{lib_dtype}}*>({{out_ptr}}), +{{indent}} {{in_ptr}}, +{{indent}} {{out_ptr}}, {{indent}} 
{{x_dim0}}, {{indent}} {{x_dim1}}, {{indent}} {{x_dim2}}, @@ -158,8 +158,8 @@ } } -void permute210_launcher({{lib_dtype}}* in_ptr, - {{lib_dtype}}* out_ptr, +void permute210_launcher(const void* in_ptr, + void* out_ptr, int x_dim0, int x_dim1, int x_dim2, @@ -167,8 +167,8 @@ dim3 grid((x_dim2 + (TILE_SIZE-1))/TILE_SIZE, x_dim1, (x_dim0 + (TILE_SIZE-1))/TILE_SIZE); dim3 block(TILE_SIZE, TILE_SIZE/4); permute210_kernel<{{lib_dtype}}><<>>( - out_ptr, - (const {{lib_dtype}}*)in_ptr, + static_cast<{{lib_dtype}}*>(out_ptr), + static_cast(in_ptr), x_dim0, x_dim1, x_dim2 @@ -177,8 +177,8 @@ } // namespace void {{function_name}} ( - {{lib_dtype}}* in_ptr, - {{lib_dtype}}* out_ptr, + const void* in_ptr, + void* out_ptr, int64_t* x_dim0, int64_t* x_dim1, int64_t* x_dim2, @@ -244,12 +244,9 @@ def gen_function_decl(func_attrs: Dict[str, Any], backend_spec) -> str: Function declaration """ func_name = func_attrs["name"] - x = func_attrs["inputs"][0] - xdtype = x._attrs["dtype"] return FUNC_DECL_TEMPLATE.render( func_name=func_name, prefix=backend_spec.prefix, - lib_dtype=backend_spec.dtype_to_lib_type(xdtype), ) @@ -271,7 +268,6 @@ def gen_function_call(func_attrs: Dict[str, Any], backend_spec, indent=" ") -> """ x = func_attrs["inputs"][0] xshape = x._attrs["shape"] - xdtype = x._attrs["dtype"] y = func_attrs["outputs"][0] yshape = y._attrs["shape"] return FUNC_CALL_TEMPLATE.render( @@ -285,5 +281,4 @@ def gen_function_call(func_attrs: Dict[str, Any], backend_spec, indent=" ") -> y_dim1="&" + yshape[1]._attrs["name"], y_dim2="&" + yshape[2]._attrs["name"], indent=indent, - lib_dtype=backend_spec.dtype_to_lib_type(xdtype), ) diff --git a/python/aitemplate/backend/common/tensor/slice_common.py b/python/aitemplate/backend/common/tensor/slice_common.py index fb17116de..f42f213f2 100644 --- a/python/aitemplate/backend/common/tensor/slice_common.py +++ b/python/aitemplate/backend/common/tensor/slice_common.py @@ -17,11 +17,6 @@ """ import jinja2 -CAST_TO_CONST_HALF_PTR_TEMPLATE = 
jinja2.Template("reinterpret_cast({{name}})") - - -CAST_TO_HALF_PTR_TEMPLATE = jinja2.Template("reinterpret_cast({{name}})") - SHAPE_UPDATE_FUNC = jinja2.Template( """ @@ -59,9 +54,9 @@ FUNC_DECL_TEMPLATE = jinja2.Template( """ void {{func_name}}( - {{elem_output_type}} * /*output*/, + void * /*output*/, int64_t *[] /*output_shape*/, - const {{elem_input_type}} *[] /*inputs*/, + const void *[] /*inputs*/, const int64_t *[] /*input_shapes*/, const int64_t *[] /*orig_slice_start_indices*/, const int64_t *[] /*orig_slice_end_indices*/, @@ -203,9 +198,9 @@ int64_t scatter_dim_size = slice_meta_data.dim_sizes[block_y]; int64_t scatter_offset = slice_meta_data.offsets[block_y]; - unsigned read_t_sz = sizeof(READ_T); - unsigned elem_t_sz = sizeof(ELEM_T); - assert(read_t_sz >= elem_t_sz && (read_t_sz % elem_t_sz == 0)); + constexpr unsigned read_t_sz = sizeof(READ_T); + constexpr unsigned elem_t_sz = sizeof(ELEM_T); + static_assert(read_t_sz >= elem_t_sz && (read_t_sz % elem_t_sz == 0)); {{index_type}} n_of_elem_t = read_t_sz / elem_t_sz; // number of READ_T elements per thread {{index_type}} reads_per_thread_in_read_t = ElemsPerThread / n_of_elem_t; @@ -275,7 +270,6 @@ template static LoadVecType get_input_vec_type( const int64_t *output_strides, - const ELEM_T *input, const int64_t *input_shape, const int64_t *input_strides, const int64_t *slice_start_indices, @@ -404,7 +398,6 @@ for ({{index_type}} i = 0; i < NumInputs; i++) { LoadVecType vec_type = get_input_vec_type( scatter_meta_data.output_strides, - inputs[i], input_shapes[i], slice_meta_data.input_strides[i], slice_start_indices[i].data(), @@ -516,7 +509,7 @@ {{indent}} {{num_inputs}}/*NumInputs*/, {{indent}} {{elems_per_thread}}/*ElemsPerThread*/, {{indent}} {{threads_per_block}}/*ThreadsPerBlock*/>( -{{indent}} output, local_output_shape, inputs, input_shapes, +{{indent}} static_cast<{{elem_type}}*>(output), local_output_shape, reinterpret_cast(inputs), input_shapes, {{indent}} slice_start_indices, 
slice_end_indices, scatter_dim, stream); {{indent}} return; {{indent}}} @@ -529,9 +522,9 @@ {{kernel_src}} void {{func_name}}( - {{elem_output_type}} *output, + void *output, int64_t *output_shape[], - const {{elem_input_type}} *inputs[], + const void *inputs[], const int64_t *input_shapes[], const int64_t *orig_slice_start_indices[], const int64_t *orig_slice_end_indices[], @@ -615,7 +608,7 @@ {{indent}}{ {{output_shape_def}} -{{indent}} const half *inputs[] = { +{{indent}} const void *inputs[] = { {{indent}} {{inputs}} {{indent}} }; @@ -687,14 +680,8 @@ def gen_function_decl(func_attrs, backend_spec): str Rendered function declaration. """ - x = func_attrs["inputs"][0] - y = func_attrs["outputs"][0] - input_type = backend_spec.dtype_to_backend_type(x._attrs["dtype"]) - output_type = backend_spec.dtype_to_backend_type(y._attrs["dtype"]) return FUNC_DECL_TEMPLATE.render( func_name=func_attrs["name"], - elem_output_type=output_type, - elem_input_type=input_type, index_type=backend_spec.index_type, prefix=backend_spec.prefix, ) @@ -742,6 +729,9 @@ def gen_function( input_type = backend_spec.dtype_to_backend_type(x._attrs["dtype"]) output_type = backend_spec.dtype_to_backend_type(y._attrs["dtype"]) + if input_type != output_type: + raise NotImplementedError("input type must equal to output type") + # TODO: consider to add profiling paths for tuning # elems_per_thread and threads_per_block exec_paths = EXEC_COND_TEMPLATE.render( @@ -774,8 +764,6 @@ def gen_function( return SRC_TEMPLATE.render( kernel_src=kernel_src, func_name=func_attrs["name"], - elem_input_type=input_type, - elem_output_type=output_type, shape_function=shape_func, exec_paths=exec_paths, index_type=backend_spec.index_type, @@ -827,12 +815,7 @@ def gen_function_call( x = inputs[0] y = outputs[0] - input_names = ",\n ".join( - [ - backend_spec.cast_to_const_half_ptr_template.render(name=i._attrs["name"]) - for i in inputs - ] - ) + input_names = ",\n ".join([i._attrs["name"] for i in inputs]) 
input_shape_defs = [] input_shape_names = [] @@ -880,14 +863,11 @@ def gen_function_call( indent=indent, output_name=y._attrs["name"], output_dim_refs=y_dim_refs ) - casted_y_ptr = backend_spec.cast_to_half_ptr_template.render(name=y._attrs["name"]) - return FUNC_CALL_TEMPLATE.render( indent=indent, func_name=func_name, - output_elem_type=backend_spec.dtype_to_backend_type(y._attrs["dtype"]), output_name=y._attrs["name"], - output_ptr=casted_y_ptr, + output_ptr=y._attrs["name"], output_shape_def=output_shape_def, inputs=input_names, input_shape_defs="".join(input_shape_defs), diff --git a/python/aitemplate/backend/common/tensor/topk_common.py b/python/aitemplate/backend/common/tensor/topk_common.py index 6b82ef531..044833bc0 100644 --- a/python/aitemplate/backend/common/tensor/topk_common.py +++ b/python/aitemplate/backend/common/tensor/topk_common.py @@ -21,9 +21,6 @@ import jinja2 -from ... import builder -from ...target import Target - # pylint: disable=C0301 FUNC_CALL_INT64_PARAM_TEMPLATE = jinja2.Template("reinterpret_cast({{name}})") @@ -40,7 +37,7 @@ {{func_signature}} { - topk_launcher(stream, elem_cnt, instance_size, instance_num, top_k, input, workspace, output); + topk_launcher<{{dtype}}>(stream, elem_cnt, instance_size, instance_num, top_k, input, workspace, output); } """ ) @@ -64,10 +61,10 @@ int instance_num = std::stoi(argv[3]); float runtime_ms = 0; - const int64_t sorted_in_aligned_bytes = GetAlignedSize(elem_cnt * sizeof(half)); + const int64_t sorted_in_aligned_bytes = GetAlignedSize(elem_cnt * sizeof({{dtype}})); const int64_t indices_aligned_bytes = GetAlignedSize(elem_cnt * sizeof(int64_t)); const int64_t sorted_indices_aligned_bytes = indices_aligned_bytes; - int64_t temp_storage_bytes = InferTempStorageForSortPairsDescending(instance_size, instance_num); + int64_t temp_storage_bytes = InferTempStorageForSortPairsDescending<{{dtype}}, int64_t>(instance_size, instance_num); GLOBAL_WORKSPACE_SIZE = GetAlignedSize(sorted_in_aligned_bytes + 
indices_aligned_bytes + sorted_indices_aligned_bytes + temp_storage_bytes); std::cout << "TIME:" << runtime_ms << std::endl; std::cout << "WS:" << GLOBAL_WORKSPACE_SIZE << std::endl; @@ -78,7 +75,7 @@ FUNC_SIGNATURE = jinja2.Template( """ void {{func_name}}(int64_t* output, - const half* input, + const void* input, const {{index_type}} elem_cnt, const {{index_type}} instance_size, const {{index_type}} instance_num, @@ -102,7 +99,7 @@ {{indent}} {{instance_size}}, {{indent}} {{instance_num}}, {{indent}} {{top_k}}, -{{indent}} global_workspace, stream /* default stream */ +{{indent}} global_workspace_, stream /* default stream */ {{indent}}); """ ) @@ -624,12 +621,14 @@ def gen_function(func_attrs: Dict[str, Any], header_files: str, backend_spec) -> """ index_type = backend_spec.index_type prefix = backend_spec.prefix + dtype = backend_spec.dtype_to_backend_type(func_attrs["inputs"][0]._attrs["dtype"]) return FUNC_TEMPLATE.render( header_files=header_files, func_signature=FUNC_SIGNATURE.render( func_name=func_attrs["name"], index_type=index_type, prefix=prefix ), kernel=KERNEL_TEMPLATE.render(cub=backend_spec.cub, prefix=prefix), + dtype=dtype, ) @@ -681,9 +680,7 @@ def gen_function_call(func_attrs: Dict[str, Any], backend_spec, indent=" ") -> output_name = FUNC_CALL_INT64_PARAM_TEMPLATE.render( name=func_attrs["outputs"][0]._attrs["name"] ) - input_name = backend_spec.cast_to_half_ptr_template.render( - name=func_attrs["inputs"][0]._attrs["name"] - ) + input_name = func_attrs["inputs"][0]._attrs["name"] x = func_attrs["inputs"][0] xshape = x._attrs["shape"] @@ -754,16 +751,16 @@ def gen_profiler( file_pairs = [] index_type = backend_spec.index_type prefix = backend_spec.prefix + dtype = backend_spec.dtype_to_backend_type(func_attrs["inputs"][0]._attrs["dtype"]) + code = PROFILER_TEMPLATE.render( header_files=header_files, func_signature=FUNC_SIGNATURE.render( func_name=func_attrs["name"], index_type=index_type, prefix=prefix ), 
kernel=KERNEL_TEMPLATE.render(cub=backend_spec.cub, prefix=prefix), + dtype=dtype, ) op_name = func_attrs["op"] add_profiler(file_pairs, workdir, op_type, op_name, code) - # build - target = Target.current() - compile_engine = builder.Builder() - compile_engine.build_objs(file_pairs, target.compile_cmd(executable=True)) + return file_pairs diff --git a/python/aitemplate/backend/common/tensor_accessor_codegen.py b/python/aitemplate/backend/common/tensor_accessor_codegen.py index e2e873647..6d6174d27 100644 --- a/python/aitemplate/backend/common/tensor_accessor_codegen.py +++ b/python/aitemplate/backend/common/tensor_accessor_codegen.py @@ -22,6 +22,7 @@ import jinja2 from ...compiler.tensor_accessor import TensorAccessor +from ...utils import alignment from ..target import Target # Template used to transform a Python TensorAccessor object @@ -71,27 +72,6 @@ def get_libs() -> str: ) -# Currently read4, add2 is best for both backend, so two backend seems identical. -# They may diverge when we got deeper understanding / further optimization. -ALIGNMENTS = [ - 8, - 4, - 2, - 1, -] - - -def _find_max_alignment(number: int) -> int: - """ - Return the first alignment value that meets the alignment requirement - for accessing the `number` of elements. 
- """ - for alignment in ALIGNMENTS: - if number % alignment == 0: - return alignment - return 1 - - def find_max_alignment_for_accessor(accessor: TensorAccessor) -> int: """the max alignment value that meets the requirement specified by the accessor @@ -105,17 +85,21 @@ def find_max_alignment_for_accessor(accessor: TensorAccessor) -> int: int the max alignment value """ - alignment = _find_max_alignment(accessor.offset) + align = alignment.find_max_alignment(accessor.offset) if not accessor.is_contiguous: - alignment = min( - alignment, - _find_max_alignment(accessor.original_total_elements_from_stride_dim), + align = min( + align, + alignment.find_max_alignment( + accessor.original_total_elements_from_stride_dim + ), ) - alignment = min( - alignment, - _find_max_alignment(accessor.actual_total_elements_from_stride_dim), + align = min( + align, + alignment.find_max_alignment( + accessor.actual_total_elements_from_stride_dim + ), ) - return alignment + return align def find_max_alignment_for_accessors(accessors: List[TensorAccessor]) -> int: @@ -132,11 +116,11 @@ def find_max_alignment_for_accessors(accessors: List[TensorAccessor]) -> int: int the max alignment value """ - alignment = max(ALIGNMENTS) + align = max(alignment.ALIGNMENTS) # Handle accessors for accessor in accessors: - alignment = min(alignment, find_max_alignment_for_accessor(accessor)) - return alignment + align = min(align, find_max_alignment_for_accessor(accessor)) + return align def find_max_alignment(num_elements: int, accessors: List[TensorAccessor]) -> int: @@ -158,6 +142,6 @@ def find_max_alignment(num_elements: int, accessors: List[TensorAccessor]) -> in the max alignment value """ # get initial alignment based on the number of elements being accessed - alignment = _find_max_alignment(num_elements) + align = alignment.find_max_alignment(num_elements) accessor_alignment = find_max_alignment_for_accessors(accessors) - return min(alignment, accessor_alignment) + return min(align, 
accessor_alignment) diff --git a/python/aitemplate/backend/common/upsampling2d_common.py b/python/aitemplate/backend/common/upsampling2d_common.py index 6d7aadd3c..736ee6482 100644 --- a/python/aitemplate/backend/common/upsampling2d_common.py +++ b/python/aitemplate/backend/common/upsampling2d_common.py @@ -23,12 +23,12 @@ EXEC_TEMPLATE = jinja2.Template( """ -{{indent}}bilinear_upsampling_luncher( -{{indent}} in_ptr, +{{indent}}bilinear_upsampling_launcher( +{{indent}} static_cast(in_ptr), {% if bias_add %} - {{indent}} res_ptr, + {{indent}} static_cast(res_ptr), {% endif %} -{{indent}} out_ptr, +{{indent}} static_cast<{{dtype}}*>(out_ptr), {{indent}} NI, {{indent}} HI, {{indent}} WI, @@ -200,11 +200,12 @@ return (n + m - 1) / m; } -void bilinear_upsampling_luncher({{elem_input_type}}* input, +template +void bilinear_upsampling_launcher(const ELEM_T* input, {% if bias_add %} - {{elem_input_type}}* input_res, + const ELEM_T* input_res, {% endif %} - {{elem_output_type}}* output, + ELEM_T* output, const {{index_type}} N, const {{index_type}} H, const {{index_type}} W, @@ -257,11 +258,11 @@ } // namespace void {{function_name}} ( - {{elem_input_type}}* in_ptr, + const void* in_ptr, {% if bias_add %} - {{elem_input_type}}* res_ptr, + const void* res_ptr, {% endif %} - {{elem_output_type}}* out_ptr, + void* out_ptr, {{index_type}}* batch, {{index_type}}* in_h, {{index_type}}* in_w, @@ -284,11 +285,11 @@ FUNC_DECL_TEMPLATE = jinja2.Template( """ void {{func_name}}( - {{elem_input_type}}*, + const void*, {% if bias_add %} - {{elem_input_type}}*, + const void*, {% endif %} - {{elem_output_type}}*, + void*, {{index_type}}*, {{index_type}}*, {{index_type}}*, @@ -304,11 +305,11 @@ FUNC_CALL_TEMPLATE = jinja2.Template( """ {{indent}}{{func_name}}( -{{indent}} static_cast<{{elem_input_type}}*>({{in_ptr}}), +{{indent}} {{in_ptr}}, {% if bias_add %} - {{indent}} static_cast<{{elem_input_type}}*>({{res_ptr}}), +{{indent}} {{res_ptr}}, {% endif %} -{{indent}} 
static_cast<{{elem_output_type}}*>({{out_ptr}}), +{{indent}} {{out_ptr}}, {{indent}} {{p_batch}}, {{indent}} {{p_in_h}}, {{indent}} {{p_in_w}}, @@ -337,16 +338,10 @@ def gen_function_decl(func_attrs, backend_spec, bias_add=False): str Rendered function declaration stmt """ - x = func_attrs["inputs"][0] - y = func_attrs["outputs"][0] - input_type = backend_spec.dtype_to_lib_type(x._attrs["dtype"]) - output_type = backend_spec.dtype_to_lib_type(y._attrs["dtype"]) return FUNC_DECL_TEMPLATE.render( index_type=backend_spec.index_type, prefix=backend_spec.prefix, func_name=func_attrs["name"], - elem_input_type=input_type, - elem_output_type=output_type, bias_add=bias_add, ) @@ -383,14 +378,10 @@ def gen_function_call(func_attrs, backend_spec, indent=" ", bias_add=False): xshape = x._attrs["shape"] y = func_attrs["outputs"][0] yshape = y._attrs["shape"] - input_type = backend_spec.dtype_to_lib_type(x._attrs["dtype"]) - output_type = backend_spec.dtype_to_lib_type(y._attrs["dtype"]) if bias_add: r = func_attrs["inputs"][1] return FUNC_CALL_TEMPLATE.render( func_name=func_attrs["name"], - elem_input_type=input_type, - elem_output_type=output_type, index_type=backend_spec.index_type, in_ptr=x._attrs["name"], res_ptr=r._attrs["name"], @@ -408,8 +399,6 @@ def gen_function_call(func_attrs, backend_spec, indent=" ", bias_add=False): else: return FUNC_CALL_TEMPLATE.render( func_name=func_attrs["name"], - elem_input_type=input_type, - elem_output_type=output_type, index_type=backend_spec.index_type, in_ptr=x._attrs["name"], out_ptr=y._attrs["name"], diff --git a/python/aitemplate/backend/common/vision_ops/efficient_nms_common.py b/python/aitemplate/backend/common/vision_ops/efficient_nms_common.py index 8431e5d87..fd0ca6c50 100644 --- a/python/aitemplate/backend/common/vision_ops/efficient_nms_common.py +++ b/python/aitemplate/backend/common/vision_ops/efficient_nms_common.py @@ -21,14 +21,10 @@ import jinja2 -from ... 
import builder -from ...target import Target from .efficient_nms_kernel import kernel # pylint: disable=C0301 -FUNC_CALL_INT64_PARAM_TEMPLATE = jinja2.Template("reinterpret_cast({{name}})") - FUNC_TEMPLATE = jinja2.Template( """ {{header_files}} @@ -96,7 +92,7 @@ int batchSize = std::stoi(argv[1]); int numScoreElements = std::stoi(argv[2]); int numClasses = std::stoi(argv[3]); - GLOBAL_WORKSPACE_SIZE = EfficientNMSWorkspaceSize(batchSize, numScoreElements, numClasses); + GLOBAL_WORKSPACE_SIZE = EfficientNMSWorkspaceSize<{{elem_input_type}}>(batchSize, numScoreElements, numClasses); std::cout << "TIME:" << runtime_ms << std::endl; std::cout << "WS:" << GLOBAL_WORKSPACE_SIZE << std::endl; @@ -106,12 +102,12 @@ FUNC_SIGNATURE = jinja2.Template( """ -void {{func_name}}(int64_t* num_detections, - half* detection_boxes, - half* detection_scores, - int64_t* detection_classe, - const half* proposals, - const half* fgScores, +void {{func_name}}(void* num_detections, + void* detection_boxes, + void* detection_scores, + void* detection_classe, + const void* proposals, + const void* fgScores, int64_t* batch, int64_t* num_rois, int64_t* num_classes, @@ -147,7 +143,7 @@ {{indent}} {{nmsMaxOut}}, {{indent}} {{iouThreshold}}, {{indent}} {{minBoxSize}}, -{{indent}} global_workspace, stream /* default stream */ +{{indent}} global_workspace_, stream /* default stream */ {{indent}}); """ ) @@ -155,9 +151,16 @@ def gen_function(func_attrs: Dict[str, Any], header_files, backend_spec) -> str: """the function for generating nms kernel""" + elem_input_type = backend_spec.dtype_to_backend_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) return FUNC_TEMPLATE.render( header_files=header_files, - kernel=kernel.render(prefix=backend_spec.prefix, cub=backend_spec.cub), + kernel=kernel.render( + prefix=backend_spec.prefix, + cub=backend_spec.cub, + elem_input_type=elem_input_type, + ), func_signature=FUNC_SIGNATURE.render( func_name=func_attrs["name"], prefix=backend_spec.prefix ), @@ -178,21 
+181,12 @@ def gen_function_call(func_attrs, backend_spec, indent=" "): assert len(func_attrs["outputs"]) == 4 assert len(func_attrs["inputs"]) == 2 - num_detections = FUNC_CALL_INT64_PARAM_TEMPLATE.render( - name=func_attrs["outputs"][0]._attrs["name"] - ) - detection_boxes = backend_spec.cast_to_half_ptr_template.render( - name=func_attrs["outputs"][1]._attrs["name"] - ) - detection_scores = backend_spec.cast_to_half_ptr_template.render( - name=func_attrs["outputs"][2]._attrs["name"] - ) - detection_classes = FUNC_CALL_INT64_PARAM_TEMPLATE.render( - name=func_attrs["outputs"][3]._attrs["name"] - ) + num_detections = func_attrs["outputs"][0]._attrs["name"] + detection_boxes = func_attrs["outputs"][1]._attrs["name"] + detection_scores = func_attrs["outputs"][2]._attrs["name"] + detection_classes = func_attrs["outputs"][3]._attrs["name"] (input_name, score_name) = ( - backend_spec.cast_to_half_ptr_template.render(name=input_tensor._attrs["name"]) - for input_tensor in func_attrs["inputs"] + input_tensor._attrs["name"] for input_tensor in func_attrs["inputs"] ) x = func_attrs["inputs"][0] @@ -235,16 +229,21 @@ def gen_profiler(func_attrs, workdir, header_files, backend_spec): """the function for generating profiler for nms op""" op_type = func_attrs["op"] file_pairs = [] + elem_input_type = backend_spec.dtype_to_backend_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) code = PROFILER_TEMPLATE.render( header_files=header_files, - kernel=kernel.render(prefix=backend_spec.prefix, cub=backend_spec.cub), + elem_input_type=elem_input_type, + kernel=kernel.render( + prefix=backend_spec.prefix, + cub=backend_spec.cub, + elem_input_type=elem_input_type, + ), func_signature=FUNC_SIGNATURE.render( func_name=func_attrs["name"], prefix=backend_spec.prefix ), ) op_name = func_attrs["op"] add_profiler(file_pairs, workdir, op_type, op_name, code) - # build - target = Target.current() - compile_engine = builder.Builder() - compile_engine.build_objs(file_pairs, 
target.compile_cmd(executable=True)) + return file_pairs diff --git a/python/aitemplate/backend/common/vision_ops/efficient_nms_kernel.py b/python/aitemplate/backend/common/vision_ops/efficient_nms_kernel.py index 5d5631f14..c40b01e7c 100644 --- a/python/aitemplate/backend/common/vision_ops/efficient_nms_kernel.py +++ b/python/aitemplate/backend/common/vision_ops/efficient_nms_kernel.py @@ -1143,7 +1143,7 @@ if (param.scoreBits <= 0 || param.scoreBits > 10) { param.scoreBits = -1; } - EfficientNMSDispatch<__half>( + EfficientNMSDispatch<{{elem_input_type}}>( param, boxesInput, scoresInput, diff --git a/python/aitemplate/backend/common/vision_ops/nms_common.py b/python/aitemplate/backend/common/vision_ops/nms_common.py index 50cc5e356..53e2b6f31 100644 --- a/python/aitemplate/backend/common/vision_ops/nms_common.py +++ b/python/aitemplate/backend/common/vision_ops/nms_common.py @@ -21,8 +21,6 @@ import jinja2 -from ... import builder -from ...target import Target from .nms_kernel import KERNEL_TEMPLATE # pylint: disable=C0301 @@ -43,7 +41,9 @@ const int N = *batch; const int R = *num_rois; - nmsGpu(stream, N, R, preNmsTop, nmsMaxOut, iouThreshold, minBoxSize, fgScores, proposals, workspace, rois); + nmsGpu<{{elem_scores_type}}, {{elem_rois_type}}>( + stream, N, R, preNmsTop, nmsMaxOut, iouThreshold, minBoxSize, + fgScores, proposals, workspace, rois); } """ ) @@ -70,9 +70,9 @@ float runtime_ms = 0; const int64_t offsets_bytes = GetCudaAlignedSize((instance_num+1) * sizeof(int64_t)); - const int64_t scores_bytes = GetCudaAlignedSize(elem_cnt * sizeof(half)); - const int64_t boxes_bytes = GetCudaAlignedSize(elem_cnt * 4 * sizeof(half)); - int64_t temp_storage_bytes = InferTempStorageForSortPairsDescending(instance_num, instance_size); + const int64_t scores_bytes = GetCudaAlignedSize(elem_cnt * sizeof({{elem_scores_type}})); + const int64_t boxes_bytes = GetCudaAlignedSize(elem_cnt * 4 * sizeof({{elem_rois_type}})); + int64_t temp_storage_bytes = 
InferTempStorageForSortPairsDescending<{{elem_scores_type}}, int64_t>(instance_num, instance_size); GLOBAL_WORKSPACE_SIZE = GetCudaAlignedSize(offsets_bytes + scores_bytes + boxes_bytes + temp_storage_bytes); @@ -84,9 +84,9 @@ FUNC_SIGNATURE = jinja2.Template( """ -void {{func_name}}(half* rois, - const half* proposals, - const half* fgScores, +void {{func_name}}(void* rois, + const void* proposals, + const void* fgScores, int64_t* batch, int64_t* num_rois, const {{index_type}} preNmsTop, @@ -114,7 +114,7 @@ {{indent}} {{nmsMaxOut}}, {{indent}} {{iouThreshold}}, {{indent}} {{minBoxSize}}, -{{indent}} global_workspace, stream /* default stream */ +{{indent}} global_workspace_, stream /* default stream */ {{indent}}); """ ) @@ -129,8 +129,16 @@ def gen_function(func_attrs: Dict[str, Any], header_files: str, backend_spec) -> else: cuda_hmaxmin = False + elem_rois_type = backend_spec.dtype_to_backend_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + elem_scores_type = backend_spec.dtype_to_backend_type( + func_attrs["inputs"][1]._attrs["dtype"] + ) return FUNC_TEMPLATE.render( T_SIZE=t_size, + elem_scores_type=elem_scores_type, + elem_rois_type=elem_rois_type, header_files=header_files, kernel=KERNEL_TEMPLATE.render( prefix=backend_spec.prefix, cub=backend_spec.cub, cuda_hmaxmin=cuda_hmaxmin @@ -159,12 +167,9 @@ def gen_function_call(func_attrs: Dict[str, Any], backend_spec, indent: str) -> assert len(func_attrs["outputs"]) == 1 assert len(func_attrs["inputs"]) == 2 - output_name = backend_spec.cast_to_half_ptr_template.render( - name=func_attrs["outputs"][0]._attrs["name"] - ) + output_name = func_attrs["outputs"][0]._attrs["name"] (input_name, score_name) = ( - backend_spec.cast_to_half_ptr_template.render(name=input_tensor._attrs["name"]) - for input_tensor in func_attrs["inputs"] + input_tensor._attrs["name"] for input_tensor in func_attrs["inputs"] ) x = func_attrs["inputs"][0] @@ -215,8 +220,16 @@ def gen_profiler( else: cuda_hmaxmin = False + elem_rois_type 
= backend_spec.dtype_to_backend_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + elem_scores_type = backend_spec.dtype_to_backend_type( + func_attrs["inputs"][1]._attrs["dtype"] + ) code = PROFILER_TEMPLATE.render( T_SIZE=t_size, + elem_scores_type=elem_scores_type, + elem_rois_type=elem_rois_type, header_files=header_files, kernel=KERNEL_TEMPLATE.render( prefix=backend_spec.prefix, cub=backend_spec.cub, cuda_hmaxmin=cuda_hmaxmin @@ -229,7 +242,4 @@ def gen_profiler( ) op_name = func_attrs["op"] add_profiler(file_pairs, workdir, op_type, op_name, code) - # build - target = Target.current() - compile_engine = builder.Builder() - compile_engine.build_objs(file_pairs, target.compile_cmd(executable=True)) + return file_pairs diff --git a/python/aitemplate/backend/common/vision_ops/roi_align_common.py b/python/aitemplate/backend/common/vision_ops/roi_align_common.py index b658b711f..d7c64d60e 100644 --- a/python/aitemplate/backend/common/vision_ops/roi_align_common.py +++ b/python/aitemplate/backend/common/vision_ops/roi_align_common.py @@ -23,10 +23,10 @@ EXEC_TEMPLATE = jinja2.Template( """ -{{indent}}roi_align_launcher( -{{indent}} in_ptr, -{{indent}} rois_ptr, -{{indent}} out_ptr, +{{indent}}roi_align_launcher<{{library_dtype}}, float, {{num_rois}}, {{pooled_size}}>( +{{indent}} static_cast(in_ptr), +{{indent}} static_cast(rois_ptr), +{{indent}} static_cast<{{library_dtype}}*>(out_ptr), {{indent}} NI, {{indent}} HI, {{indent}} WI, @@ -212,10 +212,10 @@ } -template -void roi_align_launcher({{elem_input_type}}* input, - {{elem_input_type}}* rois, - {{elem_output_type}}* output, +template +void roi_align_launcher(const LibraryT* input, + const LibraryT* rois, + LibraryT* output, const {{index_type}} N, const {{index_type}} H, const {{index_type}} W, @@ -243,9 +243,9 @@ } // namespace void {{function_name}} ( - {{elem_input_type}}* in_ptr, - {{elem_input_type}}* rois_ptr, - {{elem_output_type}}* out_ptr, + const void* in_ptr, + const void* rois_ptr, + void* 
out_ptr, {{index_type}}* batch, {{index_type}}* in_h, {{index_type}}* in_w, @@ -273,9 +273,9 @@ FUNC_DECL_TEMPLATE = jinja2.Template( """ void {{func_name}}( - {{elem_input_type}}*, - {{elem_input_type}}*, - {{elem_output_type}}*, + const void*, + const void*, + void*, {{index_type}}*, {{index_type}}*, {{index_type}}*, @@ -295,9 +295,9 @@ FUNC_CALL_TEMPLATE = jinja2.Template( """ {{indent}}{{func_name}}( -{{indent}} static_cast<{{elem_input_type}}*>({{in_ptr}}), -{{indent}} static_cast<{{elem_input_type}}*>({{rois_ptr}}), -{{indent}} static_cast<{{elem_output_type}}*>({{out_ptr}}), +{{indent}} {{in_ptr}}, +{{indent}} {{rois_ptr}}, +{{indent}} {{out_ptr}}, {{indent}} {{p_batch}}, {{indent}} {{p_in_h}}, {{indent}} {{p_in_w}}, @@ -330,16 +330,10 @@ def gen_function_decl(func_attrs, backend_spec): str Rendered function declaration stmt """ - x = func_attrs["inputs"][0] - y = func_attrs["outputs"][0] - input_type = backend_spec.dtype_to_lib_type(x._attrs["dtype"]) - output_type = backend_spec.dtype_to_lib_type(y._attrs["dtype"]) return FUNC_DECL_TEMPLATE.render( index_type=backend_spec.index_type, prefix=backend_spec.prefix, func_name=func_attrs["name"], - elem_input_type=input_type, - elem_output_type=output_type, ) @@ -364,9 +358,6 @@ def gen_function_call(func_attrs, backend_spec, indent=" "): y = func_attrs["outputs"][0] yshape = y._attrs["shape"] - input_type = backend_spec.dtype_to_lib_type(x._attrs["dtype"]) - output_type = backend_spec.dtype_to_lib_type(y._attrs["dtype"]) - return FUNC_CALL_TEMPLATE.render( func_name=func_attrs["name"], in_ptr=x._attrs["name"], @@ -386,7 +377,5 @@ def gen_function_call(func_attrs, backend_spec, indent=" "): if func_attrs["continuous_coordinate"] else "false", backend_spec=backend_spec, - elem_input_type=input_type, - elem_output_type=output_type, indent=indent, ) diff --git a/python/aitemplate/backend/cuda/__init__.py b/python/aitemplate/backend/cuda/__init__.py index 38586aab5..f2ff7c11f 100644 --- 
a/python/aitemplate/backend/cuda/__init__.py +++ b/python/aitemplate/backend/cuda/__init__.py @@ -19,6 +19,7 @@ from . import cuda_common, lib_template, target_def, utils from .common import * from .conv2d import * +from .conv3d import * from .elementwise import * from .embedding import * from .gemm_special import * diff --git a/python/aitemplate/backend/cuda/attention/__init__.py b/python/aitemplate/backend/cuda/attention/__init__.py index 61a47c3ad..9636980b4 100644 --- a/python/aitemplate/backend/cuda/attention/__init__.py +++ b/python/aitemplate/backend/cuda/attention/__init__.py @@ -15,6 +15,6 @@ """ cuda flash_attention module init """ -from . import flash_attention +from . import flash_attention, mem_eff_attention -__all__ = ["flash_attention"] +__all__ = ["flash_attention", "mem_eff_attention"] diff --git a/python/aitemplate/backend/cuda/attention/flash_attention.py b/python/aitemplate/backend/cuda/attention/flash_attention.py index b2fe5c0ca..55d781ceb 100644 --- a/python/aitemplate/backend/cuda/attention/flash_attention.py +++ b/python/aitemplate/backend/cuda/attention/flash_attention.py @@ -23,10 +23,6 @@ # pylint: disable=C0301 -FUNC_CALL_FP16_PARAM_TEMPLATE = jinja2.Template( - "reinterpret_cast(&({{name}}->raw()))" -) - FUNC_CALL_INT32_PARAM_TEMPLATE = jinja2.Template("reinterpret_cast({{name}})") FUNC_CALL_FP32_PARAM_TEMPLATE = jinja2.Template("reinterpret_cast({{name}})") @@ -202,8 +198,8 @@ FUNC_SIGNATURE = jinja2.Template( """ -void {{func_name}}(half* output, - const half* qkv, +void {{func_name}}(void* output, + const void* qkv, const int* cu_seqlens, float* softmax_lse, float* o_tmp, @@ -275,13 +271,9 @@ def flash_attention_gen_function_call(func_attrs, indent=" "): assert len(func_attrs["outputs"]) == 1 assert len(func_attrs["inputs"]) == 2 - output_name = FUNC_CALL_FP16_PARAM_TEMPLATE.render( - name=func_attrs["outputs"][0]._attrs["name"] - ) + output_name = func_attrs["outputs"][0]._attrs["name"] - qkv_name = 
FUNC_CALL_FP16_PARAM_TEMPLATE.render( - name=func_attrs["inputs"][0]._attrs["name"] - ) + qkv_name = func_attrs["inputs"][0]._attrs["name"] seqlens_name = FUNC_CALL_INT32_PARAM_TEMPLATE.render( name=func_attrs["inputs"][1]._attrs["name"] @@ -303,8 +295,8 @@ def flash_attention_gen_function_call(func_attrs, indent=" "): output=output_name, qkv=qkv_name, cu_seqlens=seqlens_name, - softmax_lse="reinterpret_cast(global_workspace)", - o_tmp="reinterpret_cast(global_workspace + {} * sizeof(float))".format( + softmax_lse="reinterpret_cast(global_workspace_)", + o_tmp="reinterpret_cast(global_workspace_ + {} * sizeof(float))".format( batch_size * num_heads * seq_len ), batch_size=batch_size, diff --git a/python/aitemplate/backend/cuda/attention/mem_eff_attention.py b/python/aitemplate/backend/cuda/attention/mem_eff_attention.py new file mode 100644 index 000000000..3948182d2 --- /dev/null +++ b/python/aitemplate/backend/cuda/attention/mem_eff_attention.py @@ -0,0 +1,262 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +attention kernel codegen for CUDA. +""" +from typing import Any, Dict + +import jinja2 + +from ... 
import registry +from ...backend_spec import CUDASpec + +# pylint: disable=C0301 + +FUNC_TEMPLATE = jinja2.Template( + """ +#include +#include +#include "cutlass/cutlass.h" +#include "kernel_forward.h" + +{{func_signature}} +{ + + /* + problem_sizes0 [b, m, n, k] + [head_number * batch_size, m, mkv, k0] + [head_number * batch_size, seq_length, seq_length_kv, head_size] + + problem_sizes1 + [head_number * batch_size, m, k1, mkv] + [head_number * batch_size, seq_length, head_size_v, seq_length_kv] + + m = seq_len + n = seq_len + k = head_size + + Q: B, M, K + K: B, N, K + P: B, M, N + V: B, N, K + O: B, M, K + output: bs, num_head, seq_len, head_size + */ + + + using ArchTag = cutlass::arch::Sm80; + constexpr bool kIs64x64 = {{kIs64x64}}; + constexpr bool kSingleValueIteration = {{kSingleValueIteration}}; + + // Set grid size + constexpr int64_t kQueriesPerBlock = kIs64x64 ? 64 : 32; + constexpr int64_t kKeysPerBlock = kIs64x64 ? 64 : 128; + if (kIs64x64 && head_size_v > kKeysPerBlock) { + std::cerr << "WARNING: you will get better performance with `kIs64x64=false`"; + } + if (kSingleValueIteration && head_size_v > kKeysPerBlock) { + std::cerr << "ERROR : Use kSingleValueIteration to keep output in RF. 
" \ + "This requires to have `head_size <= kKeysPerBlock` " \ + "but head_size_v=" << head_size_v << " and kKeysPerBlock=" << kKeysPerBlock << ""; + return; + } + if (!kSingleValueIteration && head_size_v <= kKeysPerBlock) { + std::cerr << "WARNING: you will get better performance with `kSingleValueIteration=true` (keeps the output in RF rather than GMEM)"; + } + + using Attention = AttentionKernel< + {{elem_input_type}}, // scalar_t + ArchTag, + true, // memory is aligned + kQueriesPerBlock, + kKeysPerBlock, + kSingleValueIteration + >; + + int block_O_size = (*batch_size) * seq_len * num_heads * head_size_v; + typename Attention::Params p; + { + // set parameters + p.query_ptr = static_cast<{{elem_input_type}}*>(query); + p.key_ptr = static_cast<{{elem_input_type}}*>(key); + p.value_ptr = static_cast<{{elem_input_type}}*>(value); + p.logsumexp_ptr = nullptr; // Only needed for bw + p.output_accum_ptr = nullptr; + if (Attention::kNeedsOutputAccumulatorBuffer) { + p.output_accum_ptr = accum_ptr; + } + p.output_ptr = static_cast<{{elem_input_type}}*>(output); + + p.num_heads = num_heads; + p.num_batches = *batch_size; + p.head_dim = head_size; + p.head_dim_value = head_size_v; + p.num_queries = seq_len; + p.num_keys = seq_len_kv; + p.causal = is_causal; + + + p.q_strideM = head_size; + p.k_strideM = head_size; + p.v_strideM = head_size_v; + + p.q_strideH = p.q_strideM * seq_len; + p.k_strideH = p.k_strideM * seq_len_kv; + p.v_strideH = p.v_strideM * seq_len_kv; + p.o_strideH = head_size_v; + p.q_strideB = p.q_strideH * num_heads; + p.k_strideB = p.k_strideH * num_heads; + p.v_strideB = p.v_strideH * num_heads; + p.o_strideB = head_size_v * seq_len * num_heads; + } + + // launch kernel + constexpr auto kernel_fn = attention_kernel_batched_impl; + int smem_bytes = sizeof(typename Attention::SharedStorage); + if (smem_bytes > 0xc000) { + cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + } + if (!Attention::check_supported(p)) { 
+ std::cerr << "Kernel does not support these inputs" << std::endl; + return; + } + kernel_fn<<>>(p); + + cudaError_t err = cudaDeviceSynchronize(); + + if (err != cudaSuccess) { + std::cerr << "Kernel execution error: " << cudaGetErrorString(err); + return; + } + +} + """ +) + + +FUNC_SIGNATURE = jinja2.Template( + """ +void {{func_name}}(void* output, + void* query, + void* key, + void* value, + float* accum_ptr, + int64_t* batch_size, + int seq_len, + int seq_len_kv, + int num_heads, + int head_size, + int head_size_v, + float p_dropout, + float softmax_scale, + bool is_causal, + cudaStream_t stream) + """ +) + +FUNC_DECL = jinja2.Template( + """ + {{func_signature}}; + """ +) + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}( +{{indent}} {{output}}, +{{indent}} {{query}}, {{key}}, {{value}}, +{{indent}} {{accum_ptr}}, +{{indent}} {{batch_size}}, +{{indent}} {{seq_len}}, +{{indent}} {{seq_len_kv}}, +{{indent}} {{num_heads}}, +{{indent}} {{head_size}}, +{{indent}} {{head_size_v}}, +{{indent}} {{p_dropout}}, +{{indent}} {{softmax_scale}}, +{{indent}} {{is_causal}}, stream /* default stream */ +{{indent}}); + """ +) + + +@registry.reg("cuda.mem_eff_attention.gen_function") +def mem_eff_attention_gen_function(func_attrs: Dict[str, Any]) -> str: + """the function for generating attention kernel""" + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + return FUNC_TEMPLATE.render( + elem_input_type=elem_input_type, + func_signature=FUNC_SIGNATURE.render(func_name=func_attrs["name"]), + kIs64x64="true" if func_attrs["head_size"] <= 64 else "false", + kSingleValueIteration="true" if func_attrs["head_size"] <= 128 else "false", + ) + + +@registry.reg("cuda.mem_eff_attention.func_decl") +def mem_eff_attention_gen_function_decl(func_attrs: Dict[str, Any]): + return FUNC_DECL.render( + func_signature=FUNC_SIGNATURE.render(func_name=func_attrs["name"]).strip() + ) + + 
+@registry.reg("cuda.mem_eff_attention.func_call") +def mem_eff_attention_gen_function_call(func_attrs, indent=" "): + """the function for generating a function call for attention""" + output_name = "" + assert len(func_attrs["outputs"]) == 1 + assert len(func_attrs["inputs"]) == 3 + + output_name = func_attrs["outputs"][0]._attrs["name"] + + q_name = func_attrs["inputs"][0]._attrs["name"] + k_name = func_attrs["inputs"][1]._attrs["name"] + v_name = func_attrs["inputs"][2]._attrs["name"] + + x = func_attrs["inputs"][0] + xshape = x._attrs["shape"] + batch_size = "&" + xshape[0]._attrs["name"] + seq_len = x._attrs["shape"][2]._attrs["values"][0] + + num_heads = x._attrs["shape"][1]._attrs["values"][0] + head_size = x._attrs["shape"][3]._attrs["values"][0] + p_dropout = func_attrs["dropout"] + is_causal = func_attrs["causal"] + softmax_scale = head_size ** (-0.5) + + v = func_attrs["inputs"][2] + seq_len_kv = v._attrs["shape"][2]._attrs["values"][0] + head_size_v = v._attrs["shape"][3]._attrs["values"][0] + + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + output=output_name, + query=q_name, + key=k_name, + value=v_name, + accum_ptr="reinterpret_cast(global_workspace_)", + batch_size=batch_size, + seq_len=seq_len, + seq_len_kv=seq_len_kv, + num_heads=num_heads, + head_size=head_size, + head_size_v=head_size_v, + p_dropout=p_dropout, + softmax_scale=softmax_scale, + is_causal="true" if is_causal else "false", + indent=indent, + ) diff --git a/python/aitemplate/backend/cuda/attention/src/fmha.h b/python/aitemplate/backend/cuda/attention/src/fmha.h index 9cc516722..066f236c7 100644 --- a/python/aitemplate/backend/cuda/attention/src/fmha.h +++ b/python/aitemplate/backend/cuda/attention/src/fmha.h @@ -1,17 +1,3 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * diff --git a/python/aitemplate/backend/cuda/attention/src/fmha/gemm.h b/python/aitemplate/backend/cuda/attention/src/fmha/gemm.h index 433676370..254abe31b 100644 --- a/python/aitemplate/backend/cuda/attention/src/fmha/gemm.h +++ b/python/aitemplate/backend/cuda/attention/src/fmha/gemm.h @@ -1,17 +1,3 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. 
* diff --git a/python/aitemplate/backend/cuda/attention/src/fmha/gmem_tile.h b/python/aitemplate/backend/cuda/attention/src/fmha/gmem_tile.h index 119ac6a6f..fa00d5984 100644 --- a/python/aitemplate/backend/cuda/attention/src/fmha/gmem_tile.h +++ b/python/aitemplate/backend/cuda/attention/src/fmha/gmem_tile.h @@ -1,17 +1,3 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * diff --git a/python/aitemplate/backend/cuda/attention/src/fmha/kernel_traits.h b/python/aitemplate/backend/cuda/attention/src/fmha/kernel_traits.h index 27aad1b80..3b7487e3b 100644 --- a/python/aitemplate/backend/cuda/attention/src/fmha/kernel_traits.h +++ b/python/aitemplate/backend/cuda/attention/src/fmha/kernel_traits.h @@ -1,17 +1,3 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * diff --git a/python/aitemplate/backend/cuda/attention/src/fmha/mask.h b/python/aitemplate/backend/cuda/attention/src/fmha/mask.h index ec07012af..358acb90a 100644 --- a/python/aitemplate/backend/cuda/attention/src/fmha/mask.h +++ b/python/aitemplate/backend/cuda/attention/src/fmha/mask.h @@ -1,17 +1,3 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * diff --git a/python/aitemplate/backend/cuda/attention/src/fmha/smem_tile.h b/python/aitemplate/backend/cuda/attention/src/fmha/smem_tile.h index 0bb8285d2..c3f87a71d 100644 --- a/python/aitemplate/backend/cuda/attention/src/fmha/smem_tile.h +++ b/python/aitemplate/backend/cuda/attention/src/fmha/smem_tile.h @@ -1,17 +1,3 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * diff --git a/python/aitemplate/backend/cuda/attention/src/fmha/softmax.h b/python/aitemplate/backend/cuda/attention/src/fmha/softmax.h index 02e82c427..ec5461966 100644 --- a/python/aitemplate/backend/cuda/attention/src/fmha/softmax.h +++ b/python/aitemplate/backend/cuda/attention/src/fmha/softmax.h @@ -1,17 +1,3 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. 
* diff --git a/python/aitemplate/backend/cuda/attention/src/fmha/utils.h b/python/aitemplate/backend/cuda/attention/src/fmha/utils.h index 7bc0b3df9..4a95ccce6 100644 --- a/python/aitemplate/backend/cuda/attention/src/fmha/utils.h +++ b/python/aitemplate/backend/cuda/attention/src/fmha/utils.h @@ -1,17 +1,3 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * diff --git a/python/aitemplate/backend/cuda/attention/src/fmha_block_fprop_fp16_kernel.sm80.cu b/python/aitemplate/backend/cuda/attention/src/fmha_block_fprop_fp16_kernel.sm80.cu index 46bddc48e..92756cc6f 100644 --- a/python/aitemplate/backend/cuda/attention/src/fmha_block_fprop_fp16_kernel.sm80.cu +++ b/python/aitemplate/backend/cuda/attention/src/fmha_block_fprop_fp16_kernel.sm80.cu @@ -1,17 +1,3 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * diff --git a/python/aitemplate/backend/cuda/attention/src/fmha_block_fprop_kernel_1xN.h b/python/aitemplate/backend/cuda/attention/src/fmha_block_fprop_kernel_1xN.h index 89776414a..d90ab5065 100644 --- a/python/aitemplate/backend/cuda/attention/src/fmha_block_fprop_kernel_1xN.h +++ b/python/aitemplate/backend/cuda/attention/src/fmha_block_fprop_kernel_1xN.h @@ -1,17 +1,3 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// /*************************************************************************************************** * Copyright (c) 2022, Tri Dao. * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. 
diff --git a/python/aitemplate/backend/cuda/attention/src/fmha_blockmask.h b/python/aitemplate/backend/cuda/attention/src/fmha_blockmask.h index 9de497e7f..94dd66718 100644 --- a/python/aitemplate/backend/cuda/attention/src/fmha_blockmask.h +++ b/python/aitemplate/backend/cuda/attention/src/fmha_blockmask.h @@ -1,17 +1,3 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * diff --git a/python/aitemplate/backend/cuda/attention/src/fmha_fprop_fp16_kernel.sm80.cu b/python/aitemplate/backend/cuda/attention/src/fmha_fprop_fp16_kernel.sm80.cu index 5031d81a0..aa4138983 100644 --- a/python/aitemplate/backend/cuda/attention/src/fmha_fprop_fp16_kernel.sm80.cu +++ b/python/aitemplate/backend/cuda/attention/src/fmha_fprop_fp16_kernel.sm80.cu @@ -1,17 +1,3 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * diff --git a/python/aitemplate/backend/cuda/attention/src/fmha_fprop_kernel_1xN.h b/python/aitemplate/backend/cuda/attention/src/fmha_fprop_kernel_1xN.h index 1cd4c191c..86f39f3c7 100644 --- a/python/aitemplate/backend/cuda/attention/src/fmha_fprop_kernel_1xN.h +++ b/python/aitemplate/backend/cuda/attention/src/fmha_fprop_kernel_1xN.h @@ -1,17 +1,3 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// /*************************************************************************************************** * Copyright (c) 2022, Tri Dao. * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. diff --git a/python/aitemplate/backend/cuda/attention/src/fmha_kernel.h b/python/aitemplate/backend/cuda/attention/src/fmha_kernel.h index 43692802b..41f49ffda 100644 --- a/python/aitemplate/backend/cuda/attention/src/fmha_kernel.h +++ b/python/aitemplate/backend/cuda/attention/src/fmha_kernel.h @@ -1,17 +1,3 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * diff --git a/python/aitemplate/backend/cuda/attention/src/fmha_utils.h b/python/aitemplate/backend/cuda/attention/src/fmha_utils.h index af8456621..a27bd40d9 100644 --- a/python/aitemplate/backend/cuda/attention/src/fmha_utils.h +++ b/python/aitemplate/backend/cuda/attention/src/fmha_utils.h @@ -1,17 +1,3 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// /****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. 
* diff --git a/python/aitemplate/backend/cuda/attention/src/philox.cuh b/python/aitemplate/backend/cuda/attention/src/philox.cuh index 36e788400..4ab1a63ff 100644 --- a/python/aitemplate/backend/cuda/attention/src/philox.cuh +++ b/python/aitemplate/backend/cuda/attention/src/philox.cuh @@ -1,17 +1,3 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// // Pytorch also has an implementation of Philox RNG: // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu #pragma once diff --git a/python/aitemplate/backend/cuda/conv2d/common.py b/python/aitemplate/backend/cuda/conv2d/common.py index 9e0de0d91..61c6c05f8 100644 --- a/python/aitemplate/backend/cuda/conv2d/common.py +++ b/python/aitemplate/backend/cuda/conv2d/common.py @@ -16,12 +16,15 @@ common template for conv2d """ import re + from collections import OrderedDict from hashlib import sha1 from typing import List import jinja2 +from aitemplate.backend.backend_spec import CUDASpec + from ...target import Target from ..gemm_universal.common import add_profiler, build_profiler # noqa: F401 @@ -153,15 +156,19 @@ def gen_function( inst_def_flag = set() instances = {} instance_decl = "" + backend_spec = CUDASpec() + dtype = backend_spec.dtype_to_lib_type(func_attrs["inputs"][0]._attrs["dtype"]) for key, value in exec_path.items(): fname = "f" + sha1(key.encode()).hexdigest() + + emit_instance = 
f_emit_instance(op_instance[value]) if value not in inst_def_flag: - config = f_emit_instance(op_instance[value]) inst_def_flag.add(value) + config = emit_instance else: config = "" inst = instance_template.render( - config=config, name=fname, config_name=extract_config_name(config) + config=config, name=fname, config_name=extract_config_name(emit_instance) ) instances[key] = inst instance_decl += inst @@ -191,13 +198,16 @@ def gen_function( exec_paths = "" for key in instances: fname = "f" + sha1(key.encode()).hexdigest() - program = exec_template.render(indent=" ", instance=fname) + program = exec_template.render( + indent=" " * 4, + instance=fname, + dtype=dtype, + ) exec_inst = exec_cond_remplate.render(indent=" ", cond=key, program=program) exec_paths += exec_inst return src_template.render( instances=instance_decl, function_name=func_name, - dtype="cutlass::half_t", shape_function=shape_func, exec_paths=exec_paths, extra_header=extra_header, diff --git a/python/aitemplate/backend/cuda/conv2d/common_conv2d_bias_activation.py b/python/aitemplate/backend/cuda/conv2d/common_conv2d_bias_activation.py index ddcef02b3..aa48d92f9 100644 --- a/python/aitemplate/backend/cuda/conv2d/common_conv2d_bias_activation.py +++ b/python/aitemplate/backend/cuda/conv2d/common_conv2d_bias_activation.py @@ -17,6 +17,8 @@ """ import jinja2 +from aitemplate.backend.backend_spec import CUDASpec + from . 
import common # pylint: disable=C0103,C0301 @@ -34,10 +36,10 @@ // TODO: cast to right dtype {{indent}}typename {{instance}}::Arguments arguments{ {{indent}} problem_size, -{{indent}} {(cutlass::half_t*)(in_ptr), layout_A}, -{{indent}} {(cutlass::half_t*)(weight_ptr), layout_B}, -{{indent}} {(cutlass::half_t*)(bias_ptr), cutlass::layout::TensorNHWC::Stride(0)}, -{{indent}} {(cutlass::half_t*)(out_ptr), layout_C}, +{{indent}} {static_cast<{{dtype}}*>(in_ptr), layout_A}, +{{indent}} {static_cast<{{dtype}}*>(weight_ptr), layout_B}, +{{indent}} {static_cast<{{dtype}}*>(bias_ptr), cutlass::layout::TensorNHWC::Stride(0)}, +{{indent}} {static_cast<{{dtype}}*>(out_ptr), layout_C}, {{indent}} {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, {{indent}}}; {{indent}}{{instance}} implicit_gemm_op; @@ -89,10 +91,10 @@ {{instances_def}} void {{function_name}} ( - cutlass::half_t* in_ptr, - cutlass::half_t* weight_ptr, - cutlass::half_t* out_ptr, - cutlass::half_t* bias_ptr, + void* in_ptr, + void* weight_ptr, + void* out_ptr, + void* bias_ptr, uint8_t* workspace, int64_t* batch, int64_t* out_ch, @@ -177,10 +179,10 @@ cutlass::HostTensor y({NO, HO, WO, CO}); // // warmup - conv((cutlass::half_t*) x.device_data(), - (cutlass::half_t*) w.device_data(), - (cutlass::half_t*) y.device_data(), - (cutlass::half_t*) b.device_data(), + conv(x.device_data(), + w.device_data(), + y.device_data(), + b.device_data(), global_workspace, &NI, &CO, @@ -200,12 +202,12 @@ for (auto & event : events) { cudaEventCreate(&event); } - cudaEventRecord(events[0]); + cudaEventRecord(events[0], stream); for (int i = 0; i < 5; ++i) { - conv((cutlass::half_t*) x.device_data(), - (cutlass::half_t*) w.device_data(), - (cutlass::half_t*) y.device_data(), - (cutlass::half_t*) b.device_data(), + conv(x.device_data(), + w.device_data(), + y.device_data(), + b.device_data(), global_workspace, &NI, &CO, @@ -222,7 +224,7 @@ pad, stream); } - cudaEventRecord(events[1]); + cudaEventRecord(events[1], stream); 
cudaEventSynchronize(events[1]); float runtime_ms = 0; cudaEventElapsedTime(&runtime_ms, events[0], events[1]); @@ -245,10 +247,10 @@ FUNC_DECL_TEMPLATE = jinja2.Template( """ void {{func_name}}( - cutlass::half_t*, - cutlass::half_t*, - cutlass::half_t*, - cutlass::half_t*, + void*, + void*, + void*, + void*, uint8_t*, int64_t*, int64_t*, @@ -275,7 +277,7 @@ {{indent}} {{weight_ptr}}, {{indent}} {{out_ptr}}, {{indent}} {{bias_ptr}}, -{{indent}} global_workspace, +{{indent}} global_workspace_, {{indent}} {{p_batch}}, {{indent}} {{p_out_ch}}, {{indent}} {{p_in_ch}}, @@ -314,6 +316,9 @@ def gen_profiler(func_attrs, workdir, shape_template, extra_header=""): dilate="dilation", pad="pad", ) + + backend_spec = CUDASpec() + dtype = backend_spec.dtype_to_lib_type(func_attrs["inputs"][0]._attrs["dtype"]) file_pairs = [] for op_name, op in op_instance.items(): config = common.emit_instance(op) @@ -324,12 +329,14 @@ def gen_profiler(func_attrs, workdir, shape_template, extra_header=""): config_name=config_name, name=name, config=config ) exec_program = EXEC_TEMPLATE.render( - indent=" ", is_profiler=True, instance=name + indent=" ", + is_profiler=True, + instance=name, + dtype=dtype, ) op_func = SRC_TEMPLATE.render( instances=instance, function_name="conv", - dtype="cutlass::half_t", shape_func="", exec_paths=exec_program, extra_header=extra_header, @@ -339,7 +346,7 @@ def gen_profiler(func_attrs, workdir, shape_template, extra_header=""): ) common.add_profiler(file_pairs, workdir, op_type, op_name, code) # build - common.build_profiler(file_pairs) + return common.build_profiler(file_pairs) def gen_function_call(func_attrs, indent=" "): diff --git a/python/aitemplate/backend/cuda/conv2d/common_conv2d_bias_add_activation.py b/python/aitemplate/backend/cuda/conv2d/common_conv2d_bias_add_activation.py index 0647769a1..5439f1fc0 100644 --- a/python/aitemplate/backend/cuda/conv2d/common_conv2d_bias_add_activation.py +++ 
b/python/aitemplate/backend/cuda/conv2d/common_conv2d_bias_add_activation.py @@ -17,6 +17,8 @@ """ import jinja2 +from aitemplate.backend.backend_spec import CUDASpec + from . import common # pylint: disable=C0301,C0103 @@ -34,13 +36,13 @@ // TODO: cast to right dtype {{indent}}typename {{instance}}::Arguments arguments{ {{indent}} problem_size, -{{indent}} {(cutlass::half_t*)(in_ptr), layout_A}, -{{indent}} {(cutlass::half_t*)(weight_ptr), layout_B}, -{{indent}} {(cutlass::half_t*)(res_ptr), layout_C}, -{{indent}} {(cutlass::half_t*)(out_ptr), layout_C}, +{{indent}} {static_cast<{{dtype}}*>(in_ptr), layout_A}, +{{indent}} {static_cast<{{dtype}}*>(weight_ptr), layout_B}, +{{indent}} {static_cast<{{dtype}}*>(res_ptr), layout_C}, +{{indent}} {static_cast<{{dtype}}*>(out_ptr), layout_C}, {{indent}} {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, {{indent}} cutlass::conv::SplitKMode::kSerial, -{{indent}} (cutlass::half_t*)(bias_ptr), +{{indent}} static_cast<{{dtype}}*>(bias_ptr), {{indent}} nullptr, 0, *out_ch {{indent}}}; {{indent}}{{instance}} implicit_gemm_op; @@ -90,11 +92,11 @@ {{instances_def}} void {{function_name}} ( - cutlass::half_t* in_ptr, - cutlass::half_t* weight_ptr, - cutlass::half_t* out_ptr, - cutlass::half_t* bias_ptr, - cutlass::half_t* res_ptr, + void* in_ptr, + void* weight_ptr, + void* out_ptr, + void* bias_ptr, + void* res_ptr, uint8_t* workspace, int64_t* batch, int64_t* out_ch, @@ -180,11 +182,11 @@ cutlass::HostTensor y({NO, HO, WO, CO}); // // warmup - conv((cutlass::half_t*) x.device_data(), - (cutlass::half_t*) w.device_data(), - (cutlass::half_t*) y.device_data(), - (cutlass::half_t*) b.device_data(), - (cutlass::half_t*) r.device_data(), + conv(x.device_data(), + w.device_data(), + y.device_data(), + b.device_data(), + r.device_data(), global_workspace, &NI, &CO, @@ -204,13 +206,13 @@ for (auto & event : events) { cudaEventCreate(&event); } - cudaEventRecord(events[0]); + cudaEventRecord(events[0], stream); for (int i = 0; i < 5; 
++i) { - conv((cutlass::half_t*) x.device_data(), - (cutlass::half_t*) w.device_data(), - (cutlass::half_t*) y.device_data(), - (cutlass::half_t*) b.device_data(), - (cutlass::half_t*) r.device_data(), + conv(x.device_data(), + w.device_data(), + y.device_data(), + b.device_data(), + r.device_data(), global_workspace, &NI, &CO, @@ -227,7 +229,7 @@ pad, stream); } - cudaEventRecord(events[1]); + cudaEventRecord(events[1], stream); cudaEventSynchronize(events[1]); float runtime_ms = 0; cudaEventElapsedTime(&runtime_ms, events[0], events[1]); @@ -251,11 +253,11 @@ FUNC_DECL_TEMPLATE = jinja2.Template( """ void {{func_name}}( - cutlass::half_t*, - cutlass::half_t*, - cutlass::half_t*, - cutlass::half_t*, - cutlass::half_t*, + void*, + void*, + void*, + void*, + void*, uint8_t*, int64_t*, int64_t*, @@ -283,7 +285,7 @@ {{indent}} {{out_ptr}}, {{indent}} {{bias_ptr}}, {{indent}} {{res_ptr}}, -{{indent}} global_workspace, +{{indent}} global_workspace_, {{indent}} {{p_batch}}, {{indent}} {{p_out_ch}}, {{indent}} {{p_in_ch}}, @@ -322,6 +324,8 @@ def gen_profiler(func_attrs, workdir, shape_template): dilate="dilation", pad="pad", ) + backend_spec = CUDASpec() + dtype = backend_spec.dtype_to_lib_type(func_attrs["inputs"][0]._attrs["dtype"]) file_pairs = [] for op_name, op in op_instance.items(): config = common.emit_instance(op) @@ -331,12 +335,11 @@ def gen_profiler(func_attrs, workdir, shape_template): config_name=config_name, name=name, config=config ) exec_program = EXEC_TEMPLATE.render( - indent=" ", is_profiler=True, instance=name + indent=" ", is_profiler=True, instance=name, dtype=dtype ) op_func = SRC_TEMPLATE.render( instances=instance, function_name="conv", - dtype="cutlass::half_t", shape_func="", exec_paths=exec_program, ) @@ -345,4 +348,4 @@ def gen_profiler(func_attrs, workdir, shape_template): ) common.add_profiler(file_pairs, workdir, op_type, op_name, code) # build - common.build_profiler(file_pairs) + return common.build_profiler(file_pairs) diff --git 
a/python/aitemplate/backend/cuda/conv2d/conv2d.py b/python/aitemplate/backend/cuda/conv2d/conv2d.py index 7e5da403f..3279e2ff7 100644 --- a/python/aitemplate/backend/cuda/conv2d/conv2d.py +++ b/python/aitemplate/backend/cuda/conv2d/conv2d.py @@ -17,6 +17,8 @@ """ import jinja2 +from aitemplate.backend.backend_spec import CUDASpec + from ... import registry from . import common @@ -35,10 +37,10 @@ // TODO: cast to right dtype {{indent}}typename {{instance}}::Arguments arguments{ {{indent}} problem_size, -{{indent}} {(cutlass::half_t*)(in_ptr), layout_A}, -{{indent}} {(cutlass::half_t*)(weight_ptr), layout_B}, -{{indent}} {(cutlass::half_t*)(out_ptr), layout_C}, -{{indent}} {(cutlass::half_t*)(out_ptr), layout_C}, +{{indent}} {static_cast<{{dtype}}*>(in_ptr), layout_A}, +{{indent}} {static_cast<{{dtype}}*>(weight_ptr), layout_B}, +{{indent}} {static_cast<{{dtype}}*>(out_ptr), layout_C}, +{{indent}} {static_cast<{{dtype}}*>(out_ptr), layout_C}, {{indent}} {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}, {{indent}}}; {{indent}}{{instance}} implicit_gemm_op; @@ -87,9 +89,9 @@ {{instances_def}} void {{function_name}} ( - cutlass::half_t* in_ptr, - cutlass::half_t* weight_ptr, - cutlass::half_t* out_ptr, + void* in_ptr, + void* weight_ptr, + void* out_ptr, uint8_t* workspace, int64_t* batch, int64_t* out_ch, @@ -175,9 +177,9 @@ // // warmup - conv((cutlass::half_t*) x.device_data(), - (cutlass::half_t*) w.device_data(), - (cutlass::half_t*) y.device_data(), + conv(x.device_data(), + w.device_data(), + y.device_data(), global_workspace, &NI, &CO, @@ -197,11 +199,11 @@ for (auto & event : events) { cudaEventCreate(&event); } - cudaEventRecord(events[0]); + cudaEventRecord(events[0], stream); for (int i = 0; i < 5; ++i) { - conv((cutlass::half_t*) x.device_data(), - (cutlass::half_t*) w.device_data(), - (cutlass::half_t*) y.device_data(), + conv(x.device_data(), + w.device_data(), + y.device_data(), global_workspace, &NI, &CO, @@ -218,7 +220,7 @@ pad, stream); } - 
cudaEventRecord(events[1]); + cudaEventRecord(events[1], stream); cudaEventSynchronize(events[1]); float runtime_ms = 0; cudaEventElapsedTime(&runtime_ms, events[0], events[1]); @@ -241,9 +243,9 @@ FUNC_DECL_TEMPLATE = jinja2.Template( """ void {{func_name}}( - cutlass::half_t*, - cutlass::half_t*, - cutlass::half_t*, + void*, + void*, + void*, uint8_t*, int64_t*, int64_t*, @@ -269,7 +271,7 @@ {{indent}} {{in_ptr}}, {{indent}} {{weight_ptr}}, {{indent}} {{out_ptr}}, -{{indent}} global_workspace, +{{indent}} global_workspace_, {{indent}} {{p_batch}}, {{indent}} {{p_out_ch}}, {{indent}} {{p_in_ch}}, @@ -317,6 +319,8 @@ def gen_profiler(func_attrs, workdir, shape_template): pad="pad", ) file_pairs = [] + backend_spec = CUDASpec() + dtype = backend_spec.dtype_to_lib_type(func_attrs["inputs"][0]._attrs["dtype"]) for op_name, op in op_instance.items(): config = common.emit_instance(op) config_name = common.extract_config_name(config) @@ -325,12 +329,11 @@ def gen_profiler(func_attrs, workdir, shape_template): config_name=config_name, name=name, config=config ) exec_program = EXEC_TEMPLATE.render( - indent=" ", is_profiler=True, instance=name + indent=" ", is_profiler=True, instance=name, dtype=dtype ) op_func = SRC_TEMPLATE.render( instances=instance, function_name="conv", - dtype="cutlass::half_t", shape_func="", exec_paths=exec_program, ) @@ -339,7 +342,7 @@ def gen_profiler(func_attrs, workdir, shape_template): ) common.add_profiler(file_pairs, workdir, op_type, op_name, code) # build - common.build_profiler(file_pairs) + return common.build_profiler(file_pairs) @registry.reg("cuda.conv2d.gen_function") diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias.py index c1ce2ac94..c4fb32c42 100644 --- a/python/aitemplate/backend/cuda/conv2d/conv2d_bias.py +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias.py @@ -30,7 +30,7 @@ def conv2d_config(func_attrs, dtype="float16"): 
@registry.reg("cuda.conv2d_bias.gen_profiler") def gen_profiler(func_attrs, workdir, shape_template): """Codegen for conv2d profiler.""" - cba.gen_profiler(func_attrs, workdir, shape_template) + return cba.gen_profiler(func_attrs, workdir, shape_template) @registry.reg("cuda.conv2d_bias.gen_function") diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add.py index 663495f22..07ecbbff6 100644 --- a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add.py +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add.py @@ -67,7 +67,7 @@ def fproc_f16(op): @registry.reg("cuda.conv2d_bias_add_identity.gen_profiler") def gen_profiler(func_attrs, workdir, shape_template): - cbaa.gen_profiler(func_attrs, workdir, shape_template) + return cbaa.gen_profiler(func_attrs, workdir, shape_template) @registry.reg("cuda.conv2d_bias_add_identity.gen_function") diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_hardswish.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_hardswish.py index 10aa46619..09d975ae4 100644 --- a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_hardswish.py +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_hardswish.py @@ -67,7 +67,7 @@ def fproc_f16(op): @registry.reg("cuda.conv2d_bias_add_hardswish.gen_profiler") def gen_profiler(func_attrs, workdir, shape_template): - cbaa.gen_profiler(func_attrs, workdir, shape_template) + return cbaa.gen_profiler(func_attrs, workdir, shape_template) @registry.reg("cuda.conv2d_bias_add_hardswish.gen_function") diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_relu.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_relu.py index b6b96704f..5a5e7314b 100644 --- a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_relu.py +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_relu.py @@ -67,7 +67,7 @@ def fproc_f16(op): @registry.reg("cuda.conv2d_bias_add_relu.gen_profiler") def 
gen_profiler(func_attrs, workdir, shape_template): - cbaa.gen_profiler(func_attrs, workdir, shape_template) + return cbaa.gen_profiler(func_attrs, workdir, shape_template) @registry.reg("cuda.conv2d_bias_add_relu.gen_function") diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_few_channels.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_few_channels.py index b8ddfa205..584eddbfe 100644 --- a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_few_channels.py +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_few_channels.py @@ -140,7 +140,7 @@ def conv2d_config(func_attrs, dtype="float16"): @registry.reg("cuda.conv2d_bias_few_channels.gen_profiler") def gen_profiler(func_attrs, workdir, shape_template): """generate code for profiling""" - cba.gen_profiler(func_attrs, workdir, shape_template) + return cba.gen_profiler(func_attrs, workdir, shape_template) @registry.reg("cuda.conv2d_bias_few_channels.gen_function") diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_hardswish.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_hardswish.py index e31ad9095..ccdc3ae1e 100644 --- a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_hardswish.py +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_hardswish.py @@ -28,7 +28,7 @@ def conv2d_config(func_attrs, dtype="float16"): @registry.reg("cuda.conv2d_bias_hardswish.gen_profiler") def gen_profiler(func_attrs, workdir, shape_template): - cba.gen_profiler(func_attrs, workdir, shape_template) + return cba.gen_profiler(func_attrs, workdir, shape_template) @registry.reg("cuda.conv2d_bias_hardswish.gen_function") diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_hardswish_few_channels.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_hardswish_few_channels.py index f305f3344..f8de585fa 100644 --- a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_hardswish_few_channels.py +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_hardswish_few_channels.py @@ -52,7 +52,7 @@ 
def conv2d_config(func_attrs, dtype="float16"): @registry.reg("cuda.conv2d_bias_hardswish_few_channels.gen_profiler") def gen_profiler(func_attrs, workdir, shape_template): """generate code for profiling""" - cba.gen_profiler(func_attrs, workdir, shape_template) + return cba.gen_profiler(func_attrs, workdir, shape_template) @registry.reg("cuda.conv2d_bias_hardswish_few_channels.gen_function") diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_relu.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_relu.py index ea75bdd9d..920e13d5c 100644 --- a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_relu.py +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_relu.py @@ -28,7 +28,7 @@ def conv2d_config(func_attrs, dtype="float16"): @registry.reg("cuda.conv2d_bias_relu.gen_profiler") def gen_profiler(func_attrs, workdir, shape_template): - cba.gen_profiler(func_attrs, workdir, shape_template) + return cba.gen_profiler(func_attrs, workdir, shape_template) @registry.reg("cuda.conv2d_bias_relu.gen_function") diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_relu_few_channels.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_relu_few_channels.py index e207bc10a..39019c5f1 100644 --- a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_relu_few_channels.py +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_relu_few_channels.py @@ -44,7 +44,7 @@ def conv2d_config(func_attrs, dtype="float16"): @registry.reg("cuda.conv2d_bias_relu_few_channels.gen_profiler") def gen_profiler(func_attrs, workdir, shape_template): """generate code for profiling""" - cba.gen_profiler(func_attrs, workdir, shape_template) + return cba.gen_profiler(func_attrs, workdir, shape_template) @registry.reg("cuda.conv2d_bias_relu_few_channels.gen_function") diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_sigmoid.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_sigmoid.py index 5ad4ccd6a..cbb896e71 100644 --- 
a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_sigmoid.py +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_sigmoid.py @@ -29,7 +29,7 @@ def conv2d_config(func_attrs, dtype="float16"): @registry.reg("cuda.conv2d_bias_sigmoid.gen_profiler") def gen_profiler(func_attrs, workdir, shape_template): - cba.gen_profiler(func_attrs, workdir, shape_template) + return cba.gen_profiler(func_attrs, workdir, shape_template) @registry.reg("cuda.conv2d_bias_sigmoid.gen_function") diff --git a/python/aitemplate/backend/cuda/conv2d/transposed_conv2d.py b/python/aitemplate/backend/cuda/conv2d/transposed_conv2d.py index b1b6acbc1..574f0d361 100644 --- a/python/aitemplate/backend/cuda/conv2d/transposed_conv2d.py +++ b/python/aitemplate/backend/cuda/conv2d/transposed_conv2d.py @@ -19,6 +19,8 @@ import jinja2 +from aitemplate.backend.backend_spec import CUDASpec + from ... import registry from . import common, conv2d @@ -51,9 +53,9 @@ {{instances_def}} void {{function_name}} ( - cutlass::half_t* in_ptr, - cutlass::half_t* weight_ptr, - cutlass::half_t* out_ptr, + void* in_ptr, + void* weight_ptr, + void* out_ptr, uint8_t* workspace, int64_t* batch, int64_t* out_ch, @@ -209,6 +211,9 @@ def gen_profiler(func_attrs, workdir, shape_template): pad="pad", ) file_pairs = [] + + backend_spec = CUDASpec() + dtype = backend_spec.dtype_to_lib_type(func_attrs["inputs"][0]._attrs["dtype"]) for op_name, op in op_instance.items(): config = emit_instance(op) @@ -218,12 +223,11 @@ def gen_profiler(func_attrs, workdir, shape_template): config_name=config_name, name=name, config=config ) exec_program = conv2d.EXEC_TEMPLATE.render( - indent=" ", is_profiler=True, instance=name + indent=" ", is_profiler=True, instance=name, dtype=dtype ) op_func = SRC_TEMPLATE.render( instances=instance, function_name="conv", - dtype="cutlass::half_t", shape_func="", exec_paths=exec_program, ) @@ -232,7 +236,7 @@ def gen_profiler(func_attrs, workdir, shape_template): ) common.add_profiler(file_pairs, workdir, 
op_type, op_name, code) # build - common.build_profiler(file_pairs) + return common.build_profiler(file_pairs) @registry.reg("cuda.transposed_conv2d.filter") diff --git a/python/aitemplate/backend/cuda/conv2d/transposed_conv2d_bias.py b/python/aitemplate/backend/cuda/conv2d/transposed_conv2d_bias.py index 2df9642fa..35b08d19f 100644 --- a/python/aitemplate/backend/cuda/conv2d/transposed_conv2d_bias.py +++ b/python/aitemplate/backend/cuda/conv2d/transposed_conv2d_bias.py @@ -19,6 +19,8 @@ import jinja2 +from aitemplate.backend.backend_spec import CUDASpec + from ... import registry from . import common, common_conv2d_bias_activation as cba @@ -51,10 +53,10 @@ {{instances_def}} void {{function_name}} ( - cutlass::half_t* in_ptr, - cutlass::half_t* weight_ptr, - cutlass::half_t* out_ptr, - cutlass::half_t* bias_ptr, + void* in_ptr, + void* weight_ptr, + void* out_ptr, + void* bias_ptr, uint8_t* workspace, int64_t* batch, int64_t* out_ch, @@ -215,6 +217,8 @@ def gen_profiler(func_attrs, workdir, shape_template): dilate="dilation", pad="pad", ) + backend_spec = CUDASpec() + dtype = backend_spec.dtype_to_lib_type(func_attrs["inputs"][0]._attrs["dtype"]) file_pairs = [] for op_name, op in op_instance.items(): config = emit_instance(op) @@ -225,12 +229,11 @@ def gen_profiler(func_attrs, workdir, shape_template): config_name=config_name, name=name, config=config ) exec_program = cba.EXEC_TEMPLATE.render( - indent=" ", is_profiler=True, instance=name + indent=" ", is_profiler=True, instance=name, dtype=dtype ) op_func = SRC_TEMPLATE.render( instances=instance, function_name="conv", - dtype="cutlass::half_t", shape_func="", exec_paths=exec_program, ) @@ -239,7 +242,7 @@ def gen_profiler(func_attrs, workdir, shape_template): ) common.add_profiler(file_pairs, workdir, op_type, op_name, code) # build - common.build_profiler(file_pairs) + return common.build_profiler(file_pairs) @registry.reg("cuda.transposed_conv2d_bias.filter") diff --git 
a/python/aitemplate/backend/cuda/conv3d/__init__.py b/python/aitemplate/backend/cuda/conv3d/__init__.py new file mode 100644 index 000000000..ba1388ae4 --- /dev/null +++ b/python/aitemplate/backend/cuda/conv3d/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +CUDA conv3d module init +""" +from . import conv3d, depthwise_conv3d + +__all__ = ["conv3d", "depthwise_conv3d"] diff --git a/python/aitemplate/backend/cuda/conv3d/common.py b/python/aitemplate/backend/cuda/conv3d/common.py new file mode 100644 index 000000000..461c4e6e9 --- /dev/null +++ b/python/aitemplate/backend/cuda/conv3d/common.py @@ -0,0 +1,364 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +""" +CUDA conv3d common functions +""" +import re +from collections import OrderedDict +from hashlib import sha1 +from typing import List + +import jinja2 + +from aitemplate.backend.backend_spec import CUDASpec + +from ...target import Target +from ..gemm_universal.common import add_profiler, build_profiler # noqa: F401 + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void {{func_name}}( + void*, + void*, + void*, + int64_t*, // kernel size + int64_t*, + int64_t*, + int, // strides + int, + int, + int, // padding + int, + int, + int, // dilation + int, + int, + int64_t*, // in_batch + int64_t*, // in_ch + int64_t*, // in_t + int64_t*, // in_h + int64_t*, // in_w + int64_t*, // out_ch + int64_t*, // out_t + int64_t*, // out_h + int64_t*, // out_w + cudaStream_t +); +""" +) + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}( +{{indent}} {{in_ptr}}, +{{indent}} {{weight_ptr}}, +{{indent}} {{out_ptr}}, +{{indent}} {{p_kernel_t}}, +{{indent}} {{p_kernel_h}}, +{{indent}} {{p_kernel_w}}, +{{indent}} {{stride_t}}, +{{indent}} {{stride_h}}, +{{indent}} {{stride_w}}, +{{indent}} {{padding_t}}, +{{indent}} {{padding_h}}, +{{indent}} {{padding_w}}, +{{indent}} {{dilation_t}}, +{{indent}} {{dilation_h}}, +{{indent}} {{dilation_w}}, +{{indent}} {{p_in_batch}}, +{{indent}} {{p_in_ch}}, +{{indent}} {{p_in_t}}, +{{indent}} {{p_in_h}}, +{{indent}} {{p_in_w}}, +{{indent}} {{p_out_ch}}, +{{indent}} {{p_out_t}}, +{{indent}} {{p_out_h}}, +{{indent}} {{p_out_w}}, +{{indent}} stream +{{indent}}); +""" +) + + +def gen_function_decl(func_name): + return FUNC_DECL_TEMPLATE.render(func_name=func_name) + + +def gen_function_call(func_attrs, indent=" "): + x = func_attrs["inputs"][0] + xshape = x._attrs["shape"] + w = func_attrs["inputs"][1] + wshape = w._attrs["shape"] + y = func_attrs["outputs"][0] + yshape = y._attrs["shape"] + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + in_ptr=x._attrs["name"], + weight_ptr=w._attrs["name"], + 
out_ptr=y._attrs["name"], + p_in_batch="&" + xshape[0]._attrs["name"], + p_out_ch="&" + wshape[0]._attrs["name"], + p_in_ch="&" + xshape[4]._attrs["name"], + p_kernel_t="&" + wshape[1]._attrs["name"], + p_kernel_h="&" + wshape[2]._attrs["name"], + p_kernel_w="&" + wshape[3]._attrs["name"], + p_in_t="&" + xshape[1]._attrs["name"], + p_in_h="&" + xshape[2]._attrs["name"], + p_in_w="&" + xshape[3]._attrs["name"], + p_out_t="&" + yshape[1]._attrs["name"], + p_out_h="&" + yshape[2]._attrs["name"], + p_out_w="&" + yshape[3]._attrs["name"], + stride_t=func_attrs["stride"][0], + stride_h=func_attrs["stride"][1], + stride_w=func_attrs["stride"][2], + padding_t=func_attrs["pad"][0], + padding_h=func_attrs["pad"][1], + padding_w=func_attrs["pad"][2], + dilation_t=func_attrs["dilate"][0], + dilation_h=func_attrs["dilate"][1], + dilation_w=func_attrs["dilate"][2], + indent=indent, + ) + + +KERNEL_KEY_TEMPLATE = jinja2.Template( + """ +cutlass{{opcode_class}}_{{extended_name}}_{{threadblock}}_{{layout}}_align_{{align_ab}}_{{align_c}} +""" +) + + +def kernel_name(op): + """generate cuda kernel name""" + from cutlass_lib import library + + threadblock = op.tile_description.procedural_name() + extended_name = op.extended_name() + opcode_class_name = library.OpcodeClassNames[ + op.tile_description.math_instruction.opcode_class + ] + layout = "ndhwc" # op.layout_name() + align_ab = op.A.alignment + align_c = op.C.alignment + name = KERNEL_KEY_TEMPLATE.render( + threadblock=threadblock, + extended_name=extended_name, + opcode_class_name=opcode_class_name, + layout=layout, + align_ab=align_ab, + align_c=align_c, + ) + return name.replace("\n", "") + + +def emit_instance(op): + """emit instance""" + import cutlass_lib + + # if hasattr(op, "binary_op"): + # emiter = cutlass_lib.conv3d_operation.EmitConv3dWithBroadcastInstance() + # else: + # emiter = cutlass_lib.conv3d_operation.EmitConv3dInstance() + emiter = cutlass_lib.conv3d_operation.EmitConv3dInstance() + op_def = emiter.emit(op) + 
return op_def + + +def extract_config(func_attrs, f_proc_op=None): + """Extracts cutlass config for conv kernels.""" + import copy + + import cutlass_lib + + def f_proc_op_default(op): + # import cutlass_lib + ret = [] + data_type = cutlass_lib.library.DataType.f16 + acc_type = cutlass_lib.library.DataType.f32 + # check target use fp16 acc + if "use_fp16_acc" in Target.current()._kwargs: + if Target.current()._kwargs["use_fp16_acc"]: + acc_type = cutlass_lib.library.DataType.f16 + + if ( + op.A.element == data_type + and op.B.element == data_type + and op.C.element == data_type + and op.iterator_algorithm == cutlass_lib.library.IteratorAlgorithm.Optimized + and op.tile_description.math_instruction.element_accumulator == acc_type + ): + + op = copy.deepcopy(op) + # set epilogue + epilogue_name = func_attrs["epilogue"] + op.epilogue_functor = cutlass_lib.library.EpilogueFunctorName[epilogue_name] + op.element_epilogue = acc_type + # set C alignment + for i in [8, 4, 2, 1]: + op = copy.deepcopy(op) + op.C.alignment = i + ret.append(op) + return ret + + op_kind = cutlass_lib.library.OperationKind.Conv3d + conv_kind = cutlass_lib.library.ConvKind.Fprop + ret = [] + conv3d_ops = OrderedDict() + extract_ops = list(Target.current()._operators[op_kind].items()) + + for _, value in extract_ops: + op = value[0] + if op.conv_kind == conv_kind: + if f_proc_op is None: + ret = f_proc_op_default(op) + else: + ret = f_proc_op(op) + if len(ret) > 0: + for op_inst in ret: + key = kernel_name(op_inst) + conv3d_ops[key] = op_inst + + return conv3d_ops + + +def extract_config_name(config): + """Extracts config name from a given config.""" + pattern = re.compile(r"\s*using\s(.*?)\s=") + decl = config.split("\n")[2] + match = pattern.match(decl) + if match is None: + raise RuntimeError("Invalid config: \n" + config) + return match.groups()[0] + + +def gen_function( + func_attrs, + instance_template, + exec_template, + src_template, + exec_cond_remplate, + shape_eval_template, + 
shape_save_template, + f_emit_instance=emit_instance, + extra_header="", +): + """Function definition codegen.""" + func_name = func_attrs["name"] + exec_path = func_attrs["exec_path"] + op_instance = func_attrs["op_instance"] + + backend_spec = CUDASpec() + dtype = backend_spec.dtype_to_lib_type(func_attrs["inputs"][0]._attrs["dtype"]) + + inst_def_flag = {}  # op name -> config name; reused when an op repeats in exec_path + instances = {} + instance_decl = "" + for key, value in exec_path.items(): + fname = "f" + sha1(key.encode()).hexdigest() + if value not in inst_def_flag: + config = f_emit_instance(op_instance[value]) + inst_def_flag[value] = extract_config_name(config) + else: + config = "" + inst = instance_template.render( + config=config, name=fname, config_name=inst_def_flag[value] + ) + instances[key] = inst + instance_decl += inst + shape_eval_func = shape_eval_template.render( + indent=" ", + dtype="int64_t ", + x_dim0="*batch", + x_dim1="*in_d", + x_dim2="*in_h", + x_dim3="*in_w", + x_dim4="*in_ch", + w_dim0="*out_ch", + w_dim1="*kernel_d", + w_dim2="*kernel_h", + w_dim3="*kernel_w", + stride_d="stride_d", + stride_h="stride_h", + stride_w="stride_w", + dilate_d="dilation_d", + dilate_h="dilation_h", + dilate_w="dilation_w", + pad_d="pad_d", + pad_h="pad_h", + pad_w="pad_w", + div="/", + ) + shape_save_func = shape_save_template.render( + indent=" ", + y_dim0="*out_batch", + y_dim1="*out_d", + y_dim2="*out_h", + y_dim3="*out_w", + y_dim4="*out_ch", + ) + shape_func = shape_eval_func + shape_save_func + exec_paths = "" + for key in instances: + fname = "f" + sha1(key.encode()).hexdigest() + program = exec_template.render(indent=" ", instance=fname, dtype=dtype) + exec_inst = exec_cond_remplate.render(indent=" ", cond=key, program=program) + exec_paths += exec_inst + return src_template.render( + instances=instance_decl, + function_name=func_name, + shape_function=shape_func, + exec_paths=exec_paths, + extra_header=extra_header, + ) + + +def cal_align_ab(x_shape: List[int]) -> int: + """Returns input alignment.""" + k = x_shape[4] #
CI + if k % 8 == 0: + return 8 + if k % 4 == 0: + return 4 + if k % 2 == 0: + return 2 + raise RuntimeError("a/b is not aligned") + + +def function_filter(cfg, func_attrs, x_shape): + """Checks whether a profiler config name matches the required alignments. + + Parameters + ---------- + cfg: str + The filename generated for profiler; its last two "_"-separated fields are parsed as ints (A/B alignment, then C alignment). + func_attrs : Dict + Stores the operation attributes; "epilogue_alignment" is read here. + x_shape: + Input shape, passed to cal_align_ab, which reads index 4. + + Returns + ------- + bool + True if both parsed alignments equal the required ones, False otherwise. + """ + ab_alignment = cal_align_ab(x_shape) + tmp = cfg.split("_") + align_c = int(tmp[-1]) + align_ab = int(tmp[-2]) + if align_c != func_attrs["epilogue_alignment"]: + return False + if align_ab != ab_alignment: + return False + return True diff --git a/python/aitemplate/backend/cuda/conv3d/conv3d.py b/python/aitemplate/backend/cuda/conv3d/conv3d.py new file mode 100644 index 000000000..045092131 --- /dev/null +++ b/python/aitemplate/backend/cuda/conv3d/conv3d.py @@ -0,0 +1,496 @@ +# Copyright (c) Meta Platform, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Codegen for conv3d. +""" +import jinja2 + +from aitemplate.backend.backend_spec import CUDASpec + +from ... import registry +from . 
import common + +# pylint: disable=C0103,C0415,W0613,C0301 + +INSTANCE_TEMPLATE = jinja2.Template( + """ +{{config}} +using {{name}} = cutlass::conv::device::ImplicitGemmConvolution<{{config_name}}>; +""" +) + +EXEC_TEMPLATE = jinja2.Template( + """ +{{indent}}using ElementComputeEpilogue = typename {{instance}}::ElementCompute; +// TODO: cast to right dtype +{{indent}}typename {{instance}}::Arguments arguments{ +{{indent}} problem_size, +{{indent}} {static_cast<{{dtype}}*>(in_ptr), layout_A}, +{{indent}} {static_cast<{{dtype}}*>(weight_ptr), layout_B}, +{{indent}} {static_cast<{{dtype}}*>(out_ptr), layout_C}, +{{indent}} {static_cast<{{dtype}}*>(out_ptr), layout_C}, +{{indent}} {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}, +{{indent}}}; +{{indent}}{{instance}} implicit_gemm_op; +{% if is_profiler %} +{{indent}}size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); +{{indent}}cutlass::device_memory::allocation local_workspace(workspace_size); +{{indent}}workspace = local_workspace.get(); +{{indent}}GLOBAL_WORKSPACE_SIZE = workspace_size; +{% endif %} +{{indent}}auto status = implicit_gemm_op.can_implement(arguments); +{{indent}}CUTLASS_CHECK(status); +{{indent}}status = implicit_gemm_op.initialize(arguments, workspace); +{{indent}}CUTLASS_CHECK(status); +{{indent}}status = implicit_gemm_op(stream); +{{indent}}CUTLASS_CHECK(status); +{{indent}}return; +""" +) + +SRC_TEMPLATE = jinja2.Template( + """ +#include +#include +#include +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv3d_fprop.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +{{extra_header}} + +#define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\ + cutlassGetStatusString(error) + " at: " + 
std::to_string(__LINE__); \\ + std::cerr << msg << std::endl; \\ + throw std::runtime_error(msg); \\ + } \\ + } + +{{instances}} + +{{instances_def}} + +void {{function_name}} ( + void* in_ptr, + void* weight_ptr, + void* out_ptr, + uint8_t* workspace, + int64_t* batch, + int64_t* out_ch, + int64_t* in_ch, + int64_t* kernel_d, + int64_t* kernel_h, + int64_t* kernel_w, + int64_t* in_d, + int64_t* in_h, + int64_t* in_w, + int64_t* out_batch, + int64_t* out_d, + int64_t* out_h, + int64_t* out_w, + int stride_d, + int stride_h, + int stride_w, + int dilation_d, + int dilation_h, + int dilation_w, + int pad_d, + int pad_h, + int pad_w, + cudaStream_t stream + ) { + + {{shape_function}} + int i32_batch = *batch; + int i32_in_d = *in_d; + int i32_in_h = *in_h; + int i32_in_w = *in_w; + int i32_in_ch = *in_ch; + int i32_out_ch = *out_ch; + int i32_kernel_d = *kernel_d; + int i32_kernel_h = *kernel_h; + int i32_kernel_w = *kernel_w; + int i32_out_batch = *out_batch; + int i32_out_d = *out_d; + int i32_out_h = *out_h; + int i32_out_w = *out_w; + + using cutlass::layout::TensorNDHWC; + TensorNDHWC layout_A(TensorNDHWC::packed(cutlass::make_Coord(i32_batch, i32_in_d, i32_in_h, i32_in_w, i32_in_ch))); + TensorNDHWC layout_B(TensorNDHWC::packed(cutlass::make_Coord(i32_out_ch, i32_kernel_d, i32_kernel_h, i32_kernel_w, i32_in_ch))); + TensorNDHWC layout_C(TensorNDHWC::packed(cutlass::make_Coord(i32_out_batch, i32_out_d, i32_out_h, i32_out_w, i32_out_ch))); + + cutlass::conv::Conv3dProblemSize problem_size( + cutlass::Tensor5DCoord(i32_batch, i32_in_d, i32_in_h, i32_in_w, i32_in_ch), + cutlass::Tensor5DCoord(i32_out_ch, i32_kernel_d, i32_kernel_h, i32_kernel_w, i32_in_ch), + cutlass::make_Coord(pad_d, pad_h, pad_w), + cutlass::make_Coord(stride_d, stride_h, stride_w), + cutlass::make_Coord(dilation_d, dilation_h, dilation_w), + cutlass::conv::Mode::kCrossCorrelation, + 1, + 1 + ); + + {{exec_paths}} + throw std::runtime_error( + "Unsupported workload for this conv3d 
specialization." + ); +} +""" +) + + +PROFILER_TEMPLATE = jinja2.Template( + """ +size_t GLOBAL_WORKSPACE_SIZE = 0; + +{{op_func}} + +int main(int argc, char** argv) { + int64_t batch = std::stoi(argv[1]); + int64_t in_d = std::stoi(argv[2]); + int64_t in_h = std::stoi(argv[3]); + int64_t in_w = std::stoi(argv[4]); + int64_t in_ch = std::stoi(argv[5]); + int64_t kernel_d = std::stoi(argv[6]); + int64_t kernel_h = std::stoi(argv[7]); + int64_t kernel_w = std::stoi(argv[8]); + int64_t out_ch = std::stoi(argv[9]); + int stride_d = std::stoi(argv[10]); + int stride_h = std::stoi(argv[11]); + int stride_w = std::stoi(argv[12]); + int pad_d = std::stoi(argv[13]); + int pad_h = std::stoi(argv[14]); + int pad_w = std::stoi(argv[15]); + int dilation_d = std::stoi(argv[16]); + int dilation_h = std::stoi(argv[17]); + int dilation_w = std::stoi(argv[18]); + {{shape_func}} + using ElementOutput = typename {{name}}::ElementC; + using ElementInputA = typename {{name}}::ElementA; + using ElementInputB = typename {{name}}::ElementB; + + uint8_t* global_workspace = nullptr; + cudaStream_t stream = nullptr; + + cutlass::HostTensor x({NI, DI, HI, WI, CI}); + cutlass::HostTensor w({CO, KD, KH, KW, CI}); + cutlass::HostTensor y({NO, DO, HO, WO, CO}); + + // + // warmup + conv(x.device_data(), + w.device_data(), + y.device_data(), + global_workspace, + &NI, + &CO, + &CI, + &KD, + &KH, + &KW, + &DI, + &HI, + &WI, + &NO, + &DO, + &HO, + &WO, + stride_d, + stride_h, + stride_w, + dilation_d, + dilation_h, + dilation_w, + pad_d, + pad_h, + pad_w, + stream); + cudaEvent_t events[2]; + for (auto & event : events) { + cudaEventCreate(&event); + } + cudaEventRecord(events[0]); + for (int i = 0; i < 5; ++i) { + conv(x.device_data(), + w.device_data(), + y.device_data(), + global_workspace, + &NI, + &CO, + &CI, + &KD, + &KH, + &KW, + &DI, + &HI, + &WI, + &NO, + &DO, + &HO, + &WO, + stride_d, + stride_h, + stride_w, + dilation_d, + dilation_h, + dilation_w, + pad_d, + pad_h, + pad_w, + stream); + } 
+ cudaEventRecord(events[1]); + cudaEventSynchronize(events[1]); + float runtime_ms = 0; + cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + for (auto event : events) { + (void)cudaEventDestroy(event); + } + // TODO: output workspace + if (runtime_ms < 0.00001) { + throw std::runtime_error( + "OOB in cutlass." + ); + } + std::cout << "TIME:" << runtime_ms << std::endl; + std::cout << "WS:" << GLOBAL_WORKSPACE_SIZE << std::endl; +} + +""" +) + +FUNC_DECL_TEMPLATE = jinja2.Template( + """ +void {{func_name}}( + void*, + void*, + void*, + uint8_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int64_t*, + int, + int, + int, + int, + int, + int, + int, + int, + int, + cudaStream_t +); +""" +) + +FUNC_CALL_TEMPLATE = jinja2.Template( + """ +{{indent}}{{func_name}}( +{{indent}} {{in_ptr}}, +{{indent}} {{weight_ptr}}, +{{indent}} {{out_ptr}}, +{{indent}} global_workspace_, +{{indent}} {{p_batch}}, +{{indent}} {{p_out_ch}}, +{{indent}} {{p_in_ch}}, +{{indent}} {{p_kernel_d}}, +{{indent}} {{p_kernel_h}}, +{{indent}} {{p_kernel_w}}, +{{indent}} {{p_in_d}}, +{{indent}} {{p_in_h}}, +{{indent}} {{p_in_w}}, +{{indent}} {{p_out_batch}}, +{{indent}} {{p_out_d}}, +{{indent}} {{p_out_h}}, +{{indent}} {{p_out_w}}, +{{indent}} {{stride_d}}, +{{indent}} {{stride_h}}, +{{indent}} {{stride_w}}, +{{indent}} {{dilation_d}}, +{{indent}} {{dilation_h}}, +{{indent}} {{dilation_w}}, +{{indent}} {{pad_d}}, +{{indent}} {{pad_h}}, +{{indent}} {{pad_w}}, +{{indent}} stream +{{indent}}); +""" +) + + +@registry.reg("cuda.conv3d.config") +def conv3d_config(func_attrs, dtype="float16"): + """Populates conv3d cutlass configs into 'op_instance' field.""" + func_attrs["op_instance"] = common.extract_config(func_attrs) + + +@registry.reg("cuda.conv3d.gen_profiler") +def gen_profiler(func_attrs, workdir, shape_template): + """Codegen for conv3d profiler.""" + op_type = func_attrs["op"] + op_instance 
= func_attrs["op_instance"] + # shape func + shape_func = shape_template.render( + indent=" ", + dtype="int64_t ", + div="/", + x_dim0="batch", + x_dim1="in_d", + x_dim2="in_h", + x_dim3="in_w", + x_dim4="in_ch", + w_dim0="out_ch", + w_dim1="kernel_d", + w_dim2="kernel_h", + w_dim3="kernel_w", + stride_d="stride_d", + stride_h="stride_h", + stride_w="stride_w", + dilate_d="dilation_d", + dilate_h="dilation_h", + dilate_w="dilation_w", + pad_d="pad_d", + pad_h="pad_h", + pad_w="pad_w", + ) + backend_spec = CUDASpec() + dtype = backend_spec.dtype_to_lib_type(func_attrs["inputs"][0]._attrs["dtype"]) + file_pairs = [] + for op_name, op in op_instance.items(): + config = common.emit_instance(op) + config_name = common.extract_config_name(config) + name = "DeviceConvFwdInstance" + instance = INSTANCE_TEMPLATE.render( + config_name=config_name, name=name, config=config + ) + exec_program = EXEC_TEMPLATE.render( + indent=" ", is_profiler=True, instance=name, dtype=dtype + ) + op_func = SRC_TEMPLATE.render( + instances=instance, + function_name="conv", + shape_func="", + exec_paths=exec_program, + ) + code = PROFILER_TEMPLATE.render( + op_func=op_func, shape_func=shape_func, name=name + ) + common.add_profiler(file_pairs, workdir, op_type, op_name, code) + # build + return common.build_profiler(file_pairs) + + +@registry.reg("cuda.conv3d.gen_function") +def gen_function( + func_attrs, + exec_cond_remplate, + shape_eval_template, + shape_save_template, +): + """Codegen for conv3d function.""" + return common.gen_function( + func_attrs, + INSTANCE_TEMPLATE, + EXEC_TEMPLATE, + SRC_TEMPLATE, + exec_cond_remplate, + shape_eval_template, + shape_save_template, + ) + + +@registry.reg("cuda.conv3d.func_decl") +def conv3d_gen_function_decl(func_attrs): + """Codegen for conv3d function declaration.""" + func_name = func_attrs["name"] + return FUNC_DECL_TEMPLATE.render(func_name=func_name) + + +@registry.reg("cuda.conv3d.func_call") +def conv3d_gen_function_call(func_attrs, indent=" 
"): + """Codegen for conv3d function call.""" + x = func_attrs["inputs"][0] + xshape = x._attrs["shape"] + w = func_attrs["inputs"][1] + wshape = w._attrs["shape"] + y = func_attrs["outputs"][0] + yshape = y._attrs["shape"] + return FUNC_CALL_TEMPLATE.render( + func_name=func_attrs["name"], + in_ptr=x._attrs["name"], + weight_ptr=w._attrs["name"], + out_ptr=y._attrs["name"], + p_batch="&" + xshape[0]._attrs["name"], + p_out_ch="&" + wshape[0]._attrs["name"], + p_in_ch="&" + xshape[4]._attrs["name"], + p_kernel_d="&" + wshape[1]._attrs["name"], + p_kernel_h="&" + wshape[2]._attrs["name"], + p_kernel_w="&" + wshape[3]._attrs["name"], + p_in_d="&" + xshape[1]._attrs["name"], + p_in_h="&" + xshape[2]._attrs["name"], + p_in_w="&" + xshape[3]._attrs["name"], + p_out_batch="&" + yshape[0]._attrs["name"], + p_out_d="&" + yshape[1]._attrs["name"], + p_out_h="&" + yshape[2]._attrs["name"], + p_out_w="&" + yshape[3]._attrs["name"], + stride_d=func_attrs["stride"][0], + stride_h=func_attrs["stride"][1], + stride_w=func_attrs["stride"][2], + dilation_d=func_attrs["dilate"][0], + dilation_h=func_attrs["dilate"][1], + dilation_w=func_attrs["dilate"][2], + pad_d=func_attrs["pad"][0], + pad_h=func_attrs["pad"][1], + pad_w=func_attrs["pad"][2], + indent=indent, + ) + + +@registry.reg("cuda.conv3d.filter") +def conv3d_function_filter(cfg, func_attrs, x_shape): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + x_shape: + Input shapes. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, x_shape) diff --git a/python/aitemplate/backend/cuda/conv3d/depthwise_conv3d.py b/python/aitemplate/backend/cuda/conv3d/depthwise_conv3d.py new file mode 100644 index 000000000..92158b6ae --- /dev/null +++ b/python/aitemplate/backend/cuda/conv3d/depthwise_conv3d.py @@ -0,0 +1,331 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Codegen functions for depthwise_conv3d. +""" +import jinja2 + +from ... import registry +from . import common + +# pylint: disable=C0103,C0415,W0613,C0301,W0612 + +SRC_TEMPLATE = jinja2.Template( + """ +#include +#include +#include "cutlass/util/host_tensor.h" + +#include +#include +#include + +namespace { +#define CUDA_KERNEL_LOOP(i, n) \\ + int64_t _i_n_d_e_x = blockIdx.x * blockDim.x + threadIdx.x; \\ + for (int64_t i=_i_n_d_e_x; _i_n_d_e_x < (n); _i_n_d_e_x+=blockDim.x * gridDim.x, i=_i_n_d_e_x) + +template +__global__ void conv_depthwise3d_cuda_kernel( + const scalar_t * input, + const half* kernel, + scalar_t * output, + int _kT, int _kH, int _kW, + int strideT, int strideH, int strideW, + int paddingT, int paddingH, int paddingW, + int _dilationT, int _dilationH, int _dilationW, + int iC, int iT, int iH, int iW, + int oT, int oH, int oW, + int num_outputs) +{ + int kT = kernel_k > 0? kernel_k: _kT; + int kH = kernel_k > 0? kernel_k: _kH; + int kW = kernel_k > 0? kernel_k: _kW; + + int dilationT = dil_d > 0? dil_d: _dilationT; + int dilationH = dil_d > 0? dil_d: _dilationH; + int dilationW = dil_d > 0? 
dil_d: _dilationW; + + const int oC = iC; + const int channel_multiplier = 1; + + CUDA_KERNEL_LOOP(index, num_outputs) { + const int out_channel = index % oC; + const int out_col = (index / oC) % oW; + const int out_row = (index / oC / oW) % oH; + const int out_frame = (index / oC / oW / oH) % oT; + const int batch = index / oC / oW / oH / oT; + + const int in_channel = out_channel / channel_multiplier; + + const int in_col_start = out_col * strideW - paddingW; + const int in_row_start = out_row * strideH - paddingH; + const int in_frame_start = out_frame * strideT - paddingT; + + const int in_offset = in_channel + iC * (in_col_start + iW * (in_row_start + iH * (in_frame_start + iT* batch))); + const int out_offset = out_channel + oC * (out_col + oW * (out_row + oH * (out_frame + oT* batch))); + + accscalar_t sum[8]; + for (int tk = 0; tk < element_in_Tio; tk++){ + sum[tk] = 0; + } + const half *kernel_ptr = kernel + out_channel * element_in_Tio * kT * kH * kW; + const scalar_t *input_ptr = input + in_offset; + for (int k_frame = 0; k_frame < kT; ++k_frame) { + const int in_frame = in_frame_start + k_frame * dilationT; + for (int k_row = 0; k_row < kH; ++k_row) { + const int in_row = in_row_start + k_row * dilationH; + for (int k_col = 0; k_col < kW; ++k_col) { + const int in_col = in_col_start + k_col * dilationW; + if (in_frame >= 0 && in_row >= 0 && in_col >= 0 && + in_frame < iT && in_row < iH && in_col < iW) { + scalar_t input_val = __ldg(input_ptr); + Telement* pack_input = reinterpret_cast(&input_val); + + for (int tk = 0; tk < element_in_Tio; tk++){ + accscalar_t op1 = __half2float(pack_input[tk]); + sum[tk] += op1 * __half2float(kernel_ptr[tk*kT*kH*kW]); + } + } + kernel_ptr += 1; + input_ptr += dilationW * iC; + } + input_ptr += iC * (iW * dilationH - kW * dilationW); + } + input_ptr += iC * iW * (iH * dilationT - kH * dilationH); + } + + scalar_t output_val; + Telement* pack_output = reinterpret_cast(&output_val); + for (int tk = 0; tk < element_in_Tio; 
tk++){ + pack_output[tk] = __float2half(sum[tk]); + } + output[out_offset] = output_val; + } +} + +#define NODEF_OR_EQUAL(x, y) ((y) < 0 || (x) == (y)) +#define NODEF_OR_EQUAL_3(x, y1, y2, y3) \\ + (NODEF_OR_EQUAL(x, y1) && \\ + NODEF_OR_EQUAL(x, y2) && \\ + NODEF_OR_EQUAL(x, y3)) + + +#define DWCONV3D_FORWARD_DISPATCH_SPECIALIZATION(kernel_k, dil_d) \\ + if (NODEF_OR_EQUAL_3(kernel_k, (kernel_t), (kernel_h), (kernel_w)) && \\ + NODEF_OR_EQUAL_3(dil_d, (dilation_t), (dilation_h), (dilation_w))) { \\ + conv_depthwise3d_cuda_kernel \\ + \\ + <<>>( \\ + (const scalar_t *)input, \\ + weight, \\ + (scalar_t *)output, \\ + kernel_t, kernel_h, kernel_w, \\ + stride_t, stride_h, stride_w, \\ + padding_t, padding_h, padding_w, \\ + dilation_t, dilation_h, dilation_w, \\ + c, t, h, w, \\ + to, ho, wo, \\ + num_outputs); \\ + } else \\ + +#define DWCONV3D_FORWARD_DISPATCH_OTHERS \\ + { \\ + conv_depthwise3d_cuda_kernel \\ + \\ + <<>>( \\ + (const scalar_t *)input, \\ + weight, \\ + (scalar_t *)output, \\ + kernel_t, kernel_h, kernel_w, \\ + stride_t, stride_h, stride_w, \\ + padding_t, padding_h, padding_w, \\ + dilation_t, dilation_h, dilation_w, \\ + c, t, h, w, \\ + to, ho, wo, \\ + num_outputs);} \\ + + +void conv_depthwise3d_launcher( + const half * input, + const half * weight, + half * output, + int kernel_t, + int kernel_h, + int kernel_w, + int stride_t, + int stride_h, + int stride_w, + int padding_t, + int padding_h, + int padding_w, + int dilation_t, + int dilation_h, + int dilation_w, + int n, + int c, + int t, + int h, + int w, + int to, + int ho, + int wo, + cudaStream_t stream + ) { + + assert(to > 0); + assert(ho > 0); + assert(wo > 0); + + int64_t num_outputs = n * to * ho * wo * c; + int64_t block = 256; + int64_t grid = std::min((num_outputs - 1) / block + 1, (int64_t)65536); + + int64_t num_inputs = n * t * h * w * c; + int64_t num_weights = c * kernel_t * kernel_h * kernel_w; + int64_t smem = 0; + + // Range check to avoid overflow in CUDA kernels. 
+ assert((num_inputs <= std::numeric_limits::max()) && + "Input tensor is too large."); + assert((num_outputs <= std::numeric_limits::max()) && + "Output tensor is too large."); + assert((num_weights <= 1024*8) && + "Weight tensor is too large."); + + assert((padding_t * 2 + t <= std::numeric_limits::max()) && + "Padded input tensor is too large."); + assert((padding_h * 2 + h <= std::numeric_limits::max()) && + "Padded input tensor is too large."); + assert((padding_w * 2 + w <= std::numeric_limits::max()) && + "Padded input tensor is too large."); + + + using accscalar_t = float; + using Telement = half; + {% if csize == 0 %} + using scalar_t = float4; + c = c/8; + num_outputs = num_outputs/8; + #define element_in_Tio 8 + {% elif csize == 2 %} + using scalar_t = half2; + c =c/2; + num_outputs = num_outputs/2; + #define element_in_Tio 2 + {% else %} + using scalar_t = half; + #define element_in_Tio 1 + {% endif %} + + DWCONV3D_FORWARD_DISPATCH_SPECIALIZATION(3, 1) + DWCONV3D_FORWARD_DISPATCH_SPECIALIZATION(-1, 1) + DWCONV3D_FORWARD_DISPATCH_OTHERS +} + +#undef DWCONV3D_FORWARD_DISPATCH_SPECIALIZATION +#undef DWCONV3D_FORWARD_DISPATCH_OTHERS +#undef CUDA_KERNEL_LOOP +} // namespace + +void {{function_name}} ( + void* in_ptr, + void* weight_ptr, + void* out_ptr, + int64_t* p_kt, + int64_t* p_kh, + int64_t* p_kw, + int stride_t, + int stride_h, + int stride_w, + int padding_t, + int padding_h, + int padding_w, + int dilation_t, + int dilation_h, + int dilation_w, + int64_t* p_batch, + int64_t* p_in_ch, + int64_t* p_in_t, + int64_t* p_in_h, + int64_t* p_in_w, + int64_t* p_out_ch, + int64_t* p_out_t, + int64_t* p_out_h, + int64_t* p_out_w, + cudaStream_t stream +) { + int kt = *p_kt; + int kh = *p_kh; + int kw = *p_kw; + int batch = *p_batch; + int in_ch = *p_in_ch; + int in_t = *p_in_t; + int in_h = *p_in_h; + int in_w = *p_in_w; + int out_ch = *p_out_ch; + int out_t = *p_out_t; + int out_h = *p_out_h; + int out_w = *p_out_w; + + conv_depthwise3d_launcher( + (const 
@registry.reg("cuda.depthwise_conv3d.gen_function")
def gen_function(func_attrs):
    """Render the CUDA source for a depthwise conv3d op.

    The template receives the op name and ``csize``, which is the op's
    ``group`` attribute modulo 8.
    """
    return SRC_TEMPLATE.render(
        function_name=func_attrs["name"],
        csize=func_attrs["group"] % 8,
    )


@registry.reg("cuda.depthwise_conv3d.func_decl")
def gen_function_decl(func_attrs):
    """Return the declaration produced by ``common.gen_function_decl`` for this op's name."""
    return common.gen_function_decl(func_attrs["name"])


@registry.reg("cuda.depthwise_conv3d.func_call")
def gen_function_call(func_attrs, indent=" "):
    """Return the call site produced by ``common.gen_function_call``, forwarding ``indent``."""
    call_site = common.gen_function_call(func_attrs, indent)
    return call_site
// pow impl
//
// NOTE(review): the static_cast calls and the `one()` / GELU functor uses in
// this section carry no explicit type arguments in this copy of the patch —
// confirm against the upstream header before relying on the exact types.

// Forward declaration: h2pow below calls hpow before its definition.
__device__ half hpow(const half a, const half b);

// pow for a packed half2. When the two exponents differ, each lane is
// computed independently through hpow. When they are equal, a few common
// exponents are handled with cheaper packed operations before falling back
// to a per-lane pow.
__device__ half2 h2pow(const half2 a, const half2 b) {
  half b1 = __low2half(b);
  half b2 = __high2half(b);
  if (b1 != b2) {
    half a1 = __low2half(a);
    half a2 = __high2half(a);
    half c1 = hpow(a1, b1);
    half c2 = hpow(a2, b2);
    return __halves2half2(c1, c2);
  }

  // New special cases can be added if needed, such as
  // a powi for cases where b is an integer
  if (__hbeq2(b, half2(0.0, 0.0))) {
    return half2(1.0, 1.0);
  }
  if (__hbeq2(b, half2(1.0, 1.0))) {
    return a;
  }
  if (__hbeq2(b, half2(2.0, 2.0))) {
    return __hmul2(a, a);
  }
  if (__hbeq2(b, half2(3.0, 3.0))) {
    return __hmul2(__hmul2(a, a), a);
  }
  if (__hbeq2(b, half2(0.5, 0.5))) {
    return h2sqrt(a);
  }
  if (__hbeq2(b, half2(-0.5, -0.5))) {
    return h2rsqrt(a);
  }
  if (__hbeq2(b, half2(-1.0, -1.0))) {
    return __h2div(half2(1.0, 1.0), a);
  }
  if (__hbeq2(b, half2(-2.0, -2.0))) {
    return __h2div(half2(1.0, 1.0), __hmul2(a, a));
  }

  // General case: no shortcut matched, evaluate pow lane by lane.
  half a1 = __low2half(a);
  half a2 = __high2half(a);

  // low 16 bits
  half c1 =
      static_cast(pow(static_cast(a1), static_cast(b1)));
  // high 16 bits
  half c2 =
      static_cast(pow(static_cast(a2), static_cast(b2)));
  return __halves2half2(c1, c2);
}

// Scalar half pow with the same exponent shortcuts as h2pow
// (0, 1, 2, 3, 0.5, -0.5, -1, -2); any other exponent goes through pow.
__device__ half hpow(const half a, const half b) {
  if (b == half(0.0)) {
    return half(1.0);
  }
  if (b == half(1.0)) {
    return a;
  }
  if (b == half(2.0)) {
    return a * a;
  }
  if (b == half(3.0)) {
    return a * a * a;
  }
  if (b == half(0.5)) {
    return hsqrt(a);
  }
  if (b == half(-0.5)) {
    return hrsqrt(a);
  }
  if (b == half(-1.0)) {
    return half(1.0) / a;
  }
  if (b == half(-2.0)) {
    return half(1.0) / (a * a);
  }
  return static_cast(pow(static_cast(a), static_cast(b)));
}

// float counterpart of hpow: same exponent shortcuts, pow as the fallback.
__device__ float fpow(const float a, const float b) {
  if (b == float(0.0)) {
    return float(1.0);
  }
  if (b == float(1.0)) {
    return a;
  }
  if (b == float(2.0)) {
    return a * a;
  }
  if (b == float(3.0)) {
    return a * a * a;
  }
  if (b == float(0.5)) {
    return sqrt(a);
  }
  if (b == float(-0.5)) {
    return rsqrt(a);
  }
  if (b == float(-1.0)) {
    return float(1.0) / a;
  }
  if (b == float(-2.0)) {
    return float(1.0) / (a * a);
  }
  return static_cast(
      pow(static_cast(a), static_cast(b)));
}

//
// GELU function definitions implemented as described by
// Hendrycks, D., and Gimpel, K. in
// "Gaussian Error Linear Units (GELUs)." (2020)
// https://arxiv.org/pdf/1606.08415.pdf
//
// These wrappers contain no coefficients themselves: the math is delegated
// to the cutlass::epilogue::thread GELU and GELU_taylor functors.
//
__device__ half hgelu(const half a) {
  cutlass::epilogue::thread::GELU gelu_op;
  return static_cast(gelu_op(static_cast(a)));
}

__device__ float fgelu(const float a) {
  cutlass::epilogue::thread::GELU gelu_op;
  return gelu_op(a);
}

// "fast" variants use the GELU_taylor functor instead of GELU.
__device__ half h_fast_gelu(const half a) {
  cutlass::epilogue::thread::GELU_taylor gelu_op;
  return static_cast(gelu_op(static_cast(a)));
}

__device__ float f_fast_gelu(const float a) {
  cutlass::epilogue::thread::GELU_taylor gelu_op;
  return gelu_op(a);
}

// softplus(a) = log(1 + exp(a * beta)) / beta.
// When a * beta exceeds threshold the input is returned unchanged.
__device__ float fsoftplus(
    const float a,
    const float beta,
    const float threshold) {
  return (a * beta > threshold) ? a : log1pf(expf(a * beta)) / beta;
}

// half version of fsoftplus, written with half intrinsics
// (hlog(1 + hexp(..)) rather than log1pf).
__device__ half hsoftplus(const half a, const half beta, const half threshold) {
  half one_val = one();
  return __hgt(__hmul(a, beta), threshold)
      ? a
      : __hdiv(hlog(__hadd(one_val, hexp(__hmul(a, beta)))), beta);
}

// half2 version: applies hsoftplus to the .x and .y lanes separately.
__device__ half2
h2softplus(const half2 a, const half2 beta, const half2 threshold) {
  return half2(
      hsoftplus(a.x, beta.x, threshold.x), hsoftplus(a.y, beta.y, threshold.y));
}
@registry.reg("cuda.int_elementwise.func_call")
def int_elementwise_gen_function_call(func_attrs, indent):
    """Generates int_elementwise function call.

    The generated statement has the form ``<output> = <in0><op><in1>...;``,
    where ``<op>`` is the operator string that ``CPUBackendSpec`` maps the
    op's ``func`` enum to, and each ``<inN>`` is the name of the int var
    behind the corresponding ``IntVarTensor`` input.

    Parameters
    ----------
    func_attrs : dict
        Op attributes; reads ``"func"``, ``"inputs"`` and ``"outputs"``.
    indent : str
        Not used by this function.
        # NOTE(review): presumably kept to match the func_call hook signature.

    Returns
    -------
    str
        The rendered assignment statement.

    Raises
    ------
    AssertionError
        If there is not exactly one output, or an input is not an
        ``IntVarTensor``.
    NotImplementedError
        If the backend spec has no operator string for ``func_attrs["func"]``.
    """
    func_enum = func_attrs["func"]
    inputs = func_attrs["inputs"]
    outputs = func_attrs["outputs"]
    assert (
        len(outputs) == 1
    ), f"Elementwise op for IntVarTensor should only generate 1 output, got {len(outputs)}"

    input_params_vec = []
    for inp in inputs:
        # Report the offending type, not the object itself.
        assert isinstance(
            inp, IntVarTensor
        ), f"only inputs of IntVarTensor are allowed for OP with output of IntVarTensor, got type {type(inp)}"
        input_params_vec.append(inp._attrs["int_var"]._attrs["name"])

    backend_spec = CPUBackendSpec()
    op = backend_spec.func_enum_to_func_name.get(func_enum)
    if op is None:
        # Without this guard the failure surfaces as
        # "'NoneType' object has no attribute 'join'", which hides the cause.
        raise NotImplementedError(
            f"int_elementwise has no operator mapping for func {func_enum}"
        )

    # The operator string is used as the separator between operand names.
    rhs = op.join(input_params_vec)
    lhs = outputs[0]._attrs["name"]
    return INT_VAR_FUNC_TEMPLATE.render(
        lhs=lhs,
        rhs=rhs,
    )
shared[lane] : (T)(0.0f); warpReduceSum(val); @@ -79,7 +82,7 @@ return (cutlass::fast_tanh(val * 0.5f) + 1.0f) * 0.5f; } -template +template __global__ void bert_embeddings_kernel( uint4* output, INDEX_T* input_ids, @@ -95,9 +98,10 @@ const int64_t type_vocab_size, const int64_t max_position_embeddings, const float eps) { + constexpr int num_elems_in_uint4 = sizeof(uint4) / sizeof(ElemT); const int tid = threadIdx.x; const int bid = blockIdx.x; - const int embedding_dim_div_8 = embedding_dim / 8; + const int embedding_dim_div_n = embedding_dim / num_elems_in_uint4; const int64_t input_id = input_ids[bid]; const int64_t token_type_id = token_type_ids[bid]; @@ -110,37 +114,37 @@ return; } - word_embeddings = word_embeddings + input_id * embedding_dim_div_8; + word_embeddings = word_embeddings + input_id * embedding_dim_div_n; token_type_embeddings = - token_type_embeddings + token_type_id * embedding_dim_div_8; - position_embeddings = position_embeddings + position_id * embedding_dim_div_8; + token_type_embeddings + token_type_id * embedding_dim_div_n; + position_embeddings = position_embeddings + position_id * embedding_dim_div_n; uint4 word_embedding{0, 0, 0, 0}; uint4 token_type_embedding{0, 0, 0, 0}; uint4 position_embedding{0, 0, 0, 0}; - if (tid < embedding_dim_div_8) { + if (tid < embedding_dim_div_n) { word_embedding = word_embeddings[tid]; token_type_embedding = token_type_embeddings[tid]; position_embedding = position_embeddings[tid]; } uint4 embedding{0, 0, 0, 0}; - half* word_emb_vec = reinterpret_cast(&word_embedding); - half* token_emb_vec = reinterpret_cast(&token_type_embedding); - half* pos_emb_vec = reinterpret_cast(&position_embedding); + ElemT* word_emb_vec = reinterpret_cast(&word_embedding); + ElemT* token_emb_vec = reinterpret_cast(&token_type_embedding); + ElemT* pos_emb_vec = reinterpret_cast(&position_embedding); - half* emb_vec = reinterpret_cast(&embedding); + ElemT* emb_vec = reinterpret_cast(&embedding); // layernorm __shared__ float 
s_mean, s_variance; float local_sums[1] = {0.0f}; #pragma unroll - for (int i = 0; i < 8; i++) { + for (int i = 0; i < num_elems_in_uint4; i++) { float sum = word_emb_vec[i] + token_emb_vec[i] + pos_emb_vec[i]; local_sums[0] += sum; - emb_vec[i] = (half)sum; + emb_vec[i] = static_cast(sum); } if (blockDim.x <= 32) { @@ -155,9 +159,9 @@ local_sums[0] = 0.0f; - if (tid < embedding_dim_div_8) { + if (tid < embedding_dim_div_n) { #pragma unroll - for (int i = 0; i < 8; i++) { + for (int i = 0; i < num_elems_in_uint4; i++) { float val = emb_vec[i]; local_sums[0] += (val - s_mean) * (val - s_mean); } @@ -173,13 +177,13 @@ } __syncthreads(); - if (tid < embedding_dim_div_8) { + if (tid < embedding_dim_div_n) { uint4 local_gamma = gamma[tid]; - half* gamma_vec = reinterpret_cast(&local_gamma); + ElemT* gamma_vec = reinterpret_cast(&local_gamma); uint4 local_beta = beta[tid]; - half* beta_vec = reinterpret_cast(&local_beta); + ElemT* beta_vec = reinterpret_cast(&local_beta); #pragma unroll - for (int i = 0; i < 8; i++) { + for (int i = 0; i < num_elems_in_uint4; i++) { emb_vec[i] = normalize( (float)emb_vec[i], s_mean, @@ -190,23 +194,23 @@ } // write to output - if (tid < embedding_dim_div_8) { - output = output + bid * embedding_dim_div_8; + if (tid < embedding_dim_div_n) { + output = output + bid * embedding_dim_div_n; output[tid] = embedding; } } -template +template void bert_embeddings_launcher( - half* output, + ElemT* output, INDEX_T* input_ids, INDEX_T* token_type_ids, INDEX_T* position_ids, - half* word_embeddings, - half* token_type_embeddings, - half* position_embeddings, - half* gamma, - half* beta, + ElemT* word_embeddings, + ElemT* token_type_embeddings, + ElemT* position_embeddings, + ElemT* gamma, + ElemT* beta, const int64_t indices_num, const int64_t embedding_dim, const int64_t vocab_size, @@ -214,17 +218,21 @@ const int64_t max_position_embeddings, const float eps, cudaStream_t stream) { - if (embedding_dim % 8 != 0) { - throw 
std::runtime_error("embedding dim must be multiple of 8"); + constexpr int num_elems_in_uint4 = sizeof(uint4) / sizeof(ElemT); + if (embedding_dim % num_elems_in_uint4 != 0) { + throw std::runtime_error( + "embedding dim must be multiple of num_elems_in_uint4: " + + std::to_string(num_elems_in_uint4) + ); } dim3 grid(indices_num); // round up to multiple of 32 - int64_t num_threads = embedding_dim / 8; + int64_t num_threads = embedding_dim / num_elems_in_uint4; num_threads = (num_threads + 31) / 32 * 32; dim3 block(num_threads); - bert_embeddings_kernel<<>>( + bert_embeddings_kernel<{{elem_input_type}}, INDEX_T><<>>( reinterpret_cast(output), input_ids, token_type_ids, @@ -245,16 +253,16 @@ {{func_signature}} { - bert_embeddings_launcher<{{index_type}}>( - output, + bert_embeddings_launcher<{{elem_input_type}}, {{index_type}}>( + static_cast<{{elem_input_type}}*>(output), input_ids, token_type_ids, position_ids, - word_embeddings, - token_type_embeddings, - position_embeddings, - gamma, - beta, + static_cast<{{elem_input_type}}*>(word_embeddings), + static_cast<{{elem_input_type}}*>(token_type_embeddings), + static_cast<{{elem_input_type}}*>(position_embeddings), + static_cast<{{elem_input_type}}*>(gamma), + static_cast<{{elem_input_type}}*>(beta), indices_num, embedding_dim, vocab_size, @@ -270,15 +278,15 @@ FUNC_SIGNATURE = jinja2.Template( """ -void {{func_name}}(half* output, +void {{func_name}}(void* output, {{index_type}}* input_ids, {{index_type}}* token_type_ids, {{index_type}}* position_ids, - half* word_embeddings, - half* token_type_embeddings, - half* position_embeddings, - half* gamma, - half* beta, + void* word_embeddings, + void* token_type_embeddings, + void* position_embeddings, + void* gamma, + void* beta, const int64_t indices_num, const int64_t embedding_dim, const int64_t vocab_size, @@ -342,9 +350,14 @@ def python_int_dtype_to_c_dtype(dtype): @registry.reg("cuda.bert_embeddings.gen_function") def bert_embeddings_gen_function(func_attrs: 
Dict[str, Any]) -> str: + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_backend_type( + func_attrs["inputs"][3]._attrs["dtype"] + ) dtype = python_int_dtype_to_c_dtype(func_attrs["inputs"][0]._attrs["dtype"]) return FUNC_TEMPLATE.render( index_type=dtype, + elem_input_type=elem_input_type, func_signature=FUNC_SIGNATURE.render( func_name=func_attrs["name"], index_type=dtype, @@ -363,10 +376,6 @@ def bert_embeddings_gen_function_decl(func_attrs: Dict[str, Any]) -> str: ) -FUNC_CALL_FP16_PARAM_TEMPLATE = jinja2.Template( - "reinterpret_cast(&({{name}}->raw()))" -) - FUNC_CALL_INT64_PARAM_TEMPLATE = jinja2.Template("reinterpret_cast({{name}})") FUNC_CALL_INT32_PARAM_TEMPLATE = jinja2.Template("reinterpret_cast({{name}})") @@ -405,26 +414,18 @@ def bert_embeddings_gen_function_call(func_attrs: Dict[str, Any], indent=" ") - max_position_embeddings = position_embeddings._size(0).value() eps = func_attrs["eps"] - output_str = FUNC_CALL_FP16_PARAM_TEMPLATE.render( - name=func_attrs["outputs"][0]._attrs["name"] - ) + output_str = func_attrs["outputs"][0]._attrs["name"] input_ids_str = get_int_param_template(input_ids) token_type_ids_str = get_int_param_template(token_type_ids) position_ids_str = get_int_param_template(position_ids) - word_embeddings_str = FUNC_CALL_FP16_PARAM_TEMPLATE.render( - name=word_embeddings._attrs["name"] - ) - token_type_embeddings_str = FUNC_CALL_FP16_PARAM_TEMPLATE.render( - name=token_type_embeddings._attrs["name"] - ) - position_embeddings_str = FUNC_CALL_FP16_PARAM_TEMPLATE.render( - name=position_embeddings._attrs["name"] - ) + word_embeddings_str = word_embeddings._attrs["name"] + token_type_embeddings_str = token_type_embeddings._attrs["name"] + position_embeddings_str = position_embeddings._attrs["name"] - gamma_str = FUNC_CALL_FP16_PARAM_TEMPLATE.render(name=gamma._attrs["name"]) - beta_str = FUNC_CALL_FP16_PARAM_TEMPLATE.render(name=beta._attrs["name"]) + gamma_str = gamma._attrs["name"] + beta_str = 
beta._attrs["name"] return FUNC_CALL_TEMPLATE.render( func_name=func_attrs["name"], diff --git a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/__init__.py b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/__init__.py index 604984059..3c3873c83 100644 --- a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/__init__.py +++ b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/__init__.py @@ -13,6 +13,18 @@ # limitations under the License. # -from . import bmm_rcr_softmax, gemm_rcr_bias_softmax, gemm_rcr_softmax +from . import ( + bmm_rcr_softmax, + dual_gemm_rcr_fast_gelu, + dual_gemm_rcr_silu, + gemm_rcr_bias_softmax, + gemm_rcr_softmax, +) -__all__ = ["bmm_rcr_softmax", "gemm_rcr_bias_softmax", "gemm_rcr_softmax"] +__all__ = [ + "bmm_rcr_softmax", + "gemm_rcr_bias_softmax", + "gemm_rcr_softmax", + "dual_gemm_rcr_silu", + "dual_gemm_rcr_fast_gelu", +] diff --git a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/bmm_common_softmax.py b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/bmm_common_softmax.py index 4a63ff1fc..af5753b3a 100644 --- a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/bmm_common_softmax.py +++ b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/bmm_common_softmax.py @@ -67,7 +67,7 @@ {{indent}} {{d_ptr}}, {{indent}} {{n_ptr}}, {{indent}} {{soft_ptr}}, -{{indent}} global_workspace, +{{indent}} global_workspace_, {{indent}} {{a_dim0_ptr}}, {{indent}} {{a_dim1_ptr}}, {{indent}} {{a_dim2_ptr}}, @@ -182,7 +182,7 @@ def gen_profiler( ) common.add_profiler(file_pairs, workdir, op_type, op_name, code) # build - common.build_profiler(file_pairs) + return common.build_profiler(file_pairs) def gen_function_decl(func_attrs): diff --git a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/bmm_rcr_softmax.py b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/bmm_rcr_softmax.py index 751a19a84..4a4b745a9 100644 --- a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/bmm_rcr_softmax.py +++ 
b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/bmm_rcr_softmax.py @@ -101,7 +101,7 @@ def bmm_rcr_softmax_config(func_attrs, dtype="float16"): @registry.reg("cuda.bmm_rcr_softmax.gen_profiler") def gen_profiler(func_attrs, workdir, dim_info_dict): """Generate code for profiling""" - bmm_common.gen_profiler( + return bmm_common.gen_profiler( func_attrs, workdir, dim_info_dict, diff --git a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/common_dual_gemm.py b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/common_dual_gemm.py new file mode 100644 index 000000000..fdcf0e741 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/common_dual_gemm.py @@ -0,0 +1,458 @@ +# Copyright (c) Meta Platform, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Common codegen functions for dual gemm. 
+D0 = epilogue0(X @ B0, C0) +D1 = epilogue0(X @ B1, C1) +D2 = element_wise(D0, D1) +""" + +from functools import partial +from hashlib import sha1 +from typing import Any, Dict + +import jinja2 + +from ...backend_spec import CUDASpec +from ...common import gemm_common +from ...target import Target +from ..gemm_universal import common + +# pylint: disable=C0301,C0415,R1705 + +EXTRA_CODE = jinja2.Template( + """ +#include "device/dual_gemm.h" +#include "thread/left_silu_and_mul.h" + +typename cutlass::TensorRef nullptr_ref{}; +decltype(nullptr_ref) ref_B0, ref_B1; + +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::ColumnMajor; +using LayoutC = cutlass::layout::RowMajor; + +""" +) + +# HACK: we don't record different permutation shape, +# because it has little impact on execution time compared. +# Therefore, no matter what permutation shape it is, +# we will use the same kernel, i.e. the first generated perm_shape +# At runtime, the kernel will be regenerated and thus the correctness will not be affected. 
+KERNEL_KEY_TEMPLATE = jinja2.Template( + """ +cutlass_{{opcode_class_name}}_{{extended_name}}_{{threadblock}}_{{layout}}_align_{{align_ab}}_{{align_c}} +""" +) + +TENSOR_DECL_TEMPLATE = jinja2.Template( + """ + int64_t a_ptr_sz = a_dim0 * a_dim1; + int64_t b_ptr_sz = b_dim0 * b_dim1; + int64_t c_ptr_sz = c_dim0 * c_dim1; + + // The value 1 is used to force ptr_max_sz to be non-zero + int64_t ptr_max_sz = std::max({1, a_ptr_sz, b_ptr_sz, c_ptr_sz}); + // TODO: special pool size for A100 L2 cache 40M + // need to tune it for other devices + int64_t mem_pool_sz = std::max(2, std::min(64, int((1 << 25) / ptr_max_sz))); + + memory_pool->AllocateTensor(a_ptr_sz, mem_pool_sz); // a_ptr: index 0 + memory_pool->AllocateTensor(b_ptr_sz, mem_pool_sz); // b_ptr: index 1 + memory_pool->AllocateTensor(c_ptr_sz, mem_pool_sz); // c_ptr: index 2 + +{% if has_bias %} + memory_pool->AllocateTensor(b_ptr_sz, mem_pool_sz); // b_ptr: index 3 +{% endif %} + +""" +) + +EXEC_TEMPLATE = jinja2.Template( + """ +// TODO: cast to right dtype +//{{indent}}using ElementComputeEpilogue = typename {{instance}}::ElementAccumulator; +{{indent}}using ElementCompute = typename {{instance}}::DualGemmKernel::Epilogue0::OutputOp::ElementCompute; + +{{indent}}typename {{instance}}::Arguments arguments{ + +{{problem_args}} + +{{indent}}}; +{% if is_profiler %} +{{indent}}// https://youtu.be/-Rp7UPbhErE +{{indent}}size_t workspace_size = gemm_op.get_workspace_size(arguments); +{{indent}}cutlass::device_memory::allocation local_workspace(workspace_size); + +{{indent}}workspace = local_workspace.get(); +{{indent}}GLOBAL_WORKSPACE_SIZE = workspace_size; +{% else %} +{{indent}}{{instance}} gemm_op; +{% endif %} + +{{indent}} auto status = gemm_op.can_implement(arguments); +{{indent}}CUTLASS_CHECK(status); +{{indent}}status = gemm_op.initialize(arguments, workspace, stream); +{{indent}}CUTLASS_CHECK(status); +{{indent}}status = gemm_op(stream); +{{indent}}CUTLASS_CHECK(status); + +{{indent}}return; +""" +) + + 
def kernel_name(op, func_attrs):
    """Build the single-line kernel key for a cutlass op instance.

    ``func_attrs`` is accepted but not read here.
    """
    from cutlass_lib import library

    tile = op.tile_description
    rendered = KERNEL_KEY_TEMPLATE.render(
        threadblock=tile.procedural_name(),
        extended_name=op.extended_name(),
        opcode_class_name=library.OpcodeClassNames[
            tile.math_instruction.opcode_class
        ],
        layout=op.layout_name(),
        align_ab=op.A.alignment,
        align_c=op.C.alignment,
    )
    # The template is written across several lines; the key must not contain newlines.
    return "".join(rendered.split("\n"))


def extract_config(f_proc_op, func_attrs):
    """Delegate to ``common.extract_config`` with ``kernel_name`` bound to ``func_attrs``."""
    name_fn = partial(kernel_name, func_attrs=func_attrs)
    return common.extract_config(f_proc_op, name_fn)


def dual_gemm_instance(
    op_def: str, func_attrs: Dict[str, Any], for_profiler: bool
) -> str:
    """Rewrite an emitted op definition to use threadblock swizzle 1 instead of 8.

    ``func_attrs`` and ``for_profiler`` are part of the converter signature
    but are not used by this converter.
    """
    swizzle_8 = "GemmIdentityThreadblockSwizzle<8>"
    swizzle_1 = "GemmIdentityThreadblockSwizzle<1>"
    return op_def.replace(swizzle_8, swizzle_1)


def emit_instance(
    op,
    for_profiler,
    f_instance_convertor=dual_gemm_instance,
    emit_kernel=False,
    func_attrs=None,
):
    """Emit the dual-gemm instance for ``op`` and pass it through the converter.

    ``emit_kernel`` is accepted but not read here.
    """
    import cutlass_lib

    raw_def = cutlass_lib.gemm_operation.EmitDualGemmInstance().emit(op)
    return f_instance_convertor(raw_def, func_attrs, for_profiler)


def default_fproc_f16(
    *,
    op,
    a_layout,
    b_layout,
    c_layout,
    epiligue_name,
    epiligue2_name,
    permute_layout=None,
):
    """Expand one f16 cutlass op into variants with C alignment 8, 4, 2 and 1.

    Returns an empty list when ``op`` does not have f16 A/B/C elements, the
    expected accumulator type, and the requested A/B layouts.
    """
    import copy

    import cutlass_lib

    lib = cutlass_lib.library
    f16 = lib.DataType.f16

    # fp32 accumulation unless the current target explicitly enables fp16 acc
    target_kwargs = Target.current()._kwargs
    wants_fp16_acc = (
        "use_fp16_acc" in target_kwargs and target_kwargs["use_fp16_acc"]
    )
    acc_type = lib.DataType.f16 if wants_fp16_acc else lib.DataType.f32

    compatible = (
        op.A.element == f16
        and op.B.element == f16
        and op.C.element == f16
        and op.accumulator_type() == acc_type
        and op.A.layout == a_layout
        and op.B.layout == b_layout
    )
    if not compatible:
        return []

    # Work on a copy so the caller's op is left untouched.
    current = copy.deepcopy(op)
    current.C.layout = c_layout
    current.epilogue_functor = lib.EpilogueFunctorName[epiligue_name]
    current.epilogue_functor2 = lib.EpilogueFunctorName[epiligue2_name]
    current.element_epilogue = acc_type
    if permute_layout is not None:
        current.permute_layout = lib.EpiloguePermuteLayoutName[permute_layout]

    # One variant per C alignment; each is a copy of the previous variant.
    variants = []
    for alignment in [8, 4, 2, 1]:
        current = copy.deepcopy(current)
        current.C.alignment = alignment
        variants.append(current)
    return variants


def make_fproc_f16(func_attrs, layout):
    """
    Install the op-filtering callback for the kernel described by func_attrs
    and store the extracted configs in func_attrs["op_instance"].
    """

    def fproc_f16(op):
        layout_a, layout_b, layout_c = layout.cutlass_lib_layouts()
        return default_fproc_f16(
            op=op,
            a_layout=layout_a,
            b_layout=layout_b,
            c_layout=layout_c,
            epiligue_name=func_attrs["epilogue"],
            epiligue2_name=func_attrs["epilogue2"],
        )

    func_attrs["op_instance"] = extract_config(fproc_f16, func_attrs)
def gen_profiler(
    func_attrs,
    workdir,
    profiler_filename,
    dim_info_dict,
    src_template,
    problem_args_template,
    args_parser_template,
    emit_kernel=False,
    support_split_k=False,
    output_addr_calculator="",
    bias_ptr_arg=None,
    extra_code="",
):
    """Generate and build the profiler source for a dual-gemm op.

    One instance declaration and one benchmark call are rendered for every
    entry in ``func_attrs["op_instance"]``; they are combined with the op
    function, argument parsing and tensor declarations into a single profiler
    source, which is registered via ``common.add_profiler`` and built.

    Parameters
    ----------
    func_attrs : dict
        Op attributes; reads ``"inputs"``, ``"outputs"``, ``"op"`` and
        ``"op_instance"``.
    workdir
        Forwarded to ``common.add_profiler``.
    profiler_filename
        Forwarded to ``common.add_profiler``.
    dim_info_dict
        Forwarded to ``gemm_common.gen_shape_eval_code``.
    src_template
        Template rendered into the op function (``.render(...)`` is called).
    problem_args_template
        Template rendered with the input/output element types to produce the
        problem arguments.
    args_parser_template
        Either a ready string or an object with ``.render()``.
    emit_kernel : bool
        Forwarded to ``emit_instance``.
    support_split_k : bool
        Forwarded to the rendered templates.
    output_addr_calculator : str
        Forwarded to ``src_template``.
    bias_ptr_arg
        When not ``None``, the profiler is rendered with ``has_bias=True`` and
        this value is passed as ``bias_ptr`` to each benchmark instance.
    extra_code : str
        Forwarded to ``src_template``.

    Returns
    -------
    Whatever ``common.build_profiler`` returns for the generated file pairs.
    """
    backend_spec = CUDASpec()
    # Library-level element types for the problem args, backend-level type
    # for the profiler itself; all derived from the first input / output dtype.
    elem_input_type = backend_spec.dtype_to_lib_type(
        func_attrs["inputs"][0]._attrs["dtype"]
    )
    elem_output_type = backend_spec.dtype_to_lib_type(
        func_attrs["outputs"][0]._attrs["dtype"]
    )
    elem_type = backend_spec.dtype_to_backend_type(
        func_attrs["inputs"][0]._attrs["dtype"]
    )
    op_type = func_attrs["op"]
    op_instance = func_attrs["op_instance"]

    # All three operands are treated as rank-2 here; the dims are passed by
    # address ("&a_dim0", ...) to the benchmark instances.
    ndims = 2
    adims = ["&a_dim" + str(i) for i in range(ndims)]
    bdims = ["&b_dim" + str(i) for i in range(ndims)]
    cdims = ["&c_dim" + str(i) for i in range(ndims)]
    shape_func = gemm_common.gen_shape_eval_code(
        indent=2, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True
    )

    has_bias = bias_ptr_arg is not None
    instance_name_base = "GemmInstance"
    # Execution body shared by all instances, rendered in profiler mode.
    exec_program = EXEC_TEMPLATE.render(
        indent=" ",
        instance=instance_name_base,
        is_profiler=True,
        support_split_k=support_split_k,
        problem_args=problem_args_template.render(
            elem_input_type=elem_input_type,
            elem_output_type=elem_output_type,
        ),
    )
    input_output_checks = common.INPUT_OUTPUT_CHECKS_TEMPLATE.render(
        input_ndims=ndims,
        weight_ndims=ndims,
        output_ndims=ndims,
    )

    function_name = "gemm"
    instances = []
    benchmark_instances = []
    # One instance declaration plus one benchmark invocation per candidate op.
    for instance_idx, (op_name, op) in enumerate(op_instance.items()):
        config = emit_instance(
            op, for_profiler=True, emit_kernel=emit_kernel, func_attrs=func_attrs
        )
        config_name = common.extract_config_name(config)
        instance_name = f"{instance_name_base}_{instance_idx}"
        gemm_op = f"gemm_op_{instance_idx}"
        instance = common.INSTANCE_TEMPLATE.render(
            config_name=config_name, name=instance_name, config=config
        )
        # Tensors are requested from the memory pool by index: 0 = a, 1 = b, 2 = c.
        benchmark_instance = common.BENCHMARK_INSTANCE_TEMPLATE.render(
            indent=" ",
            instance_name=instance_name,
            gemm_op=gemm_op,
            gemm_op_name=op_name,
            func_name=f"benchmark_{function_name}",
            a_ptr="memory_pool->RequestTensorByIdx(0)",
            b_ptr="memory_pool->RequestTensorByIdx(1)",
            has_bias=has_bias,
            bias_ptr=bias_ptr_arg,
            c_ptr="memory_pool->RequestTensorByIdx(2)",
            support_split_k=support_split_k,
            split_k="split_k",
            adims=adims,
            bdims=bdims,
            cdims=cdims,
        )
        instances.append(instance)
        benchmark_instances.append(benchmark_instance)
    op_func = src_template.render(
        is_profiler=True,
        instances="\n".join(instances),
        function_name=function_name,
        input_ndims=ndims,
        weight_ndims=ndims,
        output_ndims=ndims,
        shape_eval=shape_func,
        input_output_checks=input_output_checks,
        exec_paths=exec_program,
        output_addr_calculator=output_addr_calculator,
        support_split_k=support_split_k,
        extra_code=extra_code,
    )
    # Inside the benchmark function the dims are referenced by name, not by address.
    benchmark_adims = ["a_dim" + str(i) for i in range(ndims)]
    benchmark_bdims = ["b_dim" + str(i) for i in range(ndims)]
    benchmark_cdims = ["c_dim" + str(i) for i in range(ndims)]
    func_call = common.FUNC_CALL_TEMPLATE.render(
        is_profiler=True,
        func_name=function_name,
        a_ptr="a_ptr",
        b_ptr="b_ptr",
        has_bias=has_bias,
        bias_ptr="bias_ptr",
        c_ptr="c_ptr",
        split_k="split_k",
        adims=benchmark_adims,
        bdims=benchmark_bdims,
        cdims=benchmark_cdims,
    )
    # TODO: Render args_parse by caller.
    # Accept either an already-rendered string or a template object.
    args_parse = (
        args_parser_template
        if isinstance(args_parser_template, str)
        else args_parser_template.render()
    )
    code = common.PROFILER_TEMPLATE.render(
        op_func=op_func,
        has_bias=has_bias,
        support_split_k=support_split_k,
        args_parse=args_parse,
        function_name=function_name,
        input_ndims=ndims,
        weight_ndims=ndims,
        output_ndims=ndims,
        func_call=func_call,
        tensor_decl=TENSOR_DECL_TEMPLATE.render(has_bias=has_bias),
        benchmark_instances="\n".join(benchmark_instances),
        elem_type=elem_type,
    )
    # FIXME: remove file_pairs once we have make -j ready for building
    # an entire graph
    file_pairs = []
    common.add_profiler(file_pairs, workdir, op_type, profiler_filename, code)
    # build
    return common.build_profiler(file_pairs)
{{indent}} {{dim}}, @@ -342,11 +342,11 @@ for (auto & event : events) { cudaEventCreate(&event); } - cudaEventRecord(events[0]); + cudaEventRecord(events[0], stream); for (int i = 0; i < 5; ++i) { {{func_call}} } - cudaEventRecord(events[1]); + cudaEventRecord(events[1], stream); cudaEventSynchronize(events[1]); float runtime_ms = 0; cudaEventElapsedTime(&runtime_ms, events[0], events[1]); @@ -535,4 +535,4 @@ def gen_profiler( ) common.add_profiler(file_pairs, workdir, op_type, op_name, code) # build - common.build_profiler(file_pairs) + return common.build_profiler(file_pairs) diff --git a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/dual_gemm_rcr_fast_gelu.py b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/dual_gemm_rcr_fast_gelu.py new file mode 100644 index 000000000..5c626cbb8 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/dual_gemm_rcr_fast_gelu.py @@ -0,0 +1,348 @@ +# Copyright (c) Meta Platform, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for +C = FAST_GELU(GEMM_RCR(A, B)) * GEMM_RCR(A, B1) +where A[RowMajor][M, K], B[ColMajor][N, K], B1[RowMajor][N, K] +""" +import jinja2 + +from ... import registry +from ...backend_spec import CUDASpec +from ..gemm_universal import common, common_bias +from ..gemm_universal.layout import RCR +from . 
import common_dual_gemm + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +# used for real execution +PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + cutlass::gemm::GemmCoord{M, N, K}, + {({{elem_input_type}}*)a_ptr, LayoutA(K)}, + {({{elem_input_type}}*)b_ptr, LayoutB(K)}, + ref_B0, + nullptr_ref, // D0 + {({{elem_input_type}}*)bias_ptr, LayoutB(K)}, // B1 + ref_B1, + nullptr_ref, // D1 + {({{elem_output_type}}*)c_ptr, LayoutC(N)}, // D2 + {ElementCompute(1), ElementCompute(0)}, + {ElementCompute(1), ElementCompute(0)}, + {}, + 1 // kSplitKSerial +""" +) + + +ARGS_PARSER_TEMPLATE = jinja2.Template( + """ + int64_t M = std::atoi(argv[1]); + int64_t N = std::atoi(argv[2]); + int64_t K = std::atoi(argv[3]); + int64_t split_k = std::atoi(argv[4]); + + int64_t a_dim0 = M; + int64_t a_dim1 = K; + int64_t b_dim0 = N; + int64_t b_dim1 = K; + int64_t c_dim0 = M; + int64_t c_dim1 = N; +""" +) + +# for profiler, no need to include TensorAccessor +PROFILER_PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + cutlass::gemm::GemmCoord{M, N, K}, + {({{elem_input_type}}*)a_ptr, LayoutA(K)}, + {({{elem_input_type}}*)b_ptr, LayoutB(K)}, + ref_B0, + nullptr_ref, // D0 + {({{elem_input_type}}*)bias_ptr, LayoutB(K)}, // B1 + ref_B1, + nullptr_ref, // D1 + {({{elem_output_type}}*)c_ptr, LayoutC(N)}, // D2 + {ElementCompute(1), ElementCompute(0)}, + {ElementCompute(1), ElementCompute(0)}, + {}, + 1 // kSplitKSerial +""" +) + + +EXTRA_CODE = jinja2.Template( + """ +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/thread/linear_combination_params.h" +#include "device/dual_gemm.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. +/// +/// D = alpha * accumulator + beta * source + uniform +/// +template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation. + ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data to store + typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type + typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest +> +class LeftFastGeluAndMul { +public: + + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + + static FloatRoundStyle const kRound = Round; + + struct Params{}; + +private: + + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + +public: + + /// Constructs the function object, possibly loading from pointers in host memory + CUTLASS_HOST_DEVICE + LeftFastGeluAndMul(Params const &/*params*/) {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return true; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { + assert(false); + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &lhs, + FragmentAccumulator const &rhs) const { + + // Convert source to interal compute numeric type + NumericArrayConverter accumulator_to_compute; + + // Convert to destination 
numeric type + NumericArrayConverter compute_to_output; + + ComputeFragment converted_lhs = accumulator_to_compute(lhs); + ComputeFragment converted_rhs = accumulator_to_compute(rhs); + + cutlass::epilogue::thread::GELU_taylor gelu; + cutlass::multiplies mul; + auto gelu_lhs = gelu(converted_lhs); + return compute_to_output(mul(gelu_lhs, converted_rhs)); + } + + CUTLASS_HOST_DEVICE + ElementOutput operator()( + ElementAccumulator const& lhs, + ElementAccumulator const& rhs + ) const { + ElementCompute convert_lhs(lhs); + ElementCompute convert_rhs(rhs); + cutlass::epilogue::thread::GELU_taylor gelu; + cutlass::multiplies mul; + auto gelu_lhs = gelu(convert_lhs); + return ElementOutput(mul(gelu_lhs, convert_rhs)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + + +typename cutlass::TensorRef nullptr_ref{}; +decltype(nullptr_ref) ref_B0, ref_B1; + +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::ColumnMajor; +using LayoutC = cutlass::layout::RowMajor; + +""" +) + + +@registry.reg("cuda.dual_gemm_rcr_fast_gelu.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + common_dual_gemm.make_fproc_f16(func_attrs, RCR) + + +def common_gen_profiler( + func_attrs, + workdir, + profiler_filename, + dim_info_dict, + src_template, + problem_args_template, + bias_ptr_arg=None, + extra_code="", +): + output_addr_calculator = common.DEFAULT_OUTPUT_ADDR_CALCULATOR.render( + stride_dim="*b_dim0" + ) + return common_dual_gemm.gen_profiler( + func_attrs, + workdir, + profiler_filename, + dim_info_dict, + src_template, + problem_args_template, + ARGS_PARSER_TEMPLATE, + emit_kernel=True, + support_split_k=True, + output_addr_calculator=output_addr_calculator, + bias_ptr_arg=bias_ptr_arg, + extra_code=extra_code, + ) + + +@registry.reg("cuda.dual_gemm_rcr_fast_gelu.gen_profiler") +def gen_profiler(func_attrs, workdir, 
profiler_filename, dim_info_dict): + return common_gen_profiler( + func_attrs, + workdir, + profiler_filename, + dim_info_dict, + common_bias.SRC_TEMPLATE, + PROFILER_PROBLEM_ARGS_TEMPLATE, + bias_ptr_arg="memory_pool->RequestTensorByIdx(3)", + extra_code=EXTRA_CODE.render(), + ) + + +@registry.reg("cuda.dual_gemm_rcr_fast_gelu.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, + problem_args_template=None, +): + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) + + if problem_args_template is None: + problem_args = PROBLEM_ARGS_TEMPLATE.render( + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, + ) + else: + problem_args = problem_args_template.render( + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, + ) + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + output_ndims = len(func_attrs["output_accessors"][0].original_shapes) + return common_dual_gemm.gen_function( + func_attrs, + common_bias.SRC_TEMPLATE, + exec_cond_template, + problem_args, + input_ndims, + weight_ndims, + output_ndims, + dim_info_dict, + emit_kernel=True, + support_split_k=True, + output_addr_calculator=common.OUTPUT_ADDR_CALCULATOR.render( + stride_dim="N", output_accessor=func_attrs["output_accessors"][0] + ), + extra_code=EXTRA_CODE.render(), + ) + + +@registry.reg("cuda.dual_gemm_rcr_fast_gelu.func_decl") +def gen_function_decl(func_attrs): + func_name = func_attrs["name"] + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + return common_bias.FUNC_DECL_TEMPLATE.render( + func_name=func_name, + input_ndims=input_ndims, + weight_ndims=weight_ndims, + 
support_split_k=True, + ) + + +@registry.reg("cuda.dual_gemm_rcr_fast_gelu.func_call") +def gen_function_call(func_attrs, indent=" "): + bias = func_attrs["inputs"][2] + return common.gen_function_call( + func_attrs, indent, bias_ptr_arg=bias._attrs["name"] + ) + + +@registry.reg("cuda.dual_gemm_rcr_fast_gelu.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/dual_gemm_rcr_silu.py b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/dual_gemm_rcr_silu.py new file mode 100644 index 000000000..211259e9e --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/dual_gemm_rcr_silu.py @@ -0,0 +1,220 @@ +# Copyright (c) Meta Platform, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for +C = SILU(GEMM_RCR(A, B)) * GEMM_RCR(A, B1) +where A[RowMajor][M, K], B[ColMajor][N, K], B1[RowMajor][N, K] +""" +import jinja2 + +from ... import registry +from ...backend_spec import CUDASpec +from ..gemm_universal import common, common_bias +from ..gemm_universal.layout import RCR +from . 
import common_dual_gemm + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + + +# used for real execution +PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + cutlass::gemm::GemmCoord{M, N, K}, + {({{elem_input_type}}*)a_ptr, LayoutA(K)}, + {({{elem_input_type}}*)b_ptr, LayoutB(K)}, + ref_B0, + nullptr_ref, // D0 + {({{elem_input_type}}*)bias_ptr, LayoutB(K)}, // B1 + ref_B1, + nullptr_ref, // D1 + {({{elem_output_type}}*)c_ptr, LayoutC(N)}, // D2 + {ElementCompute(1), ElementCompute(0)}, + {ElementCompute(1), ElementCompute(0)}, + {}, + 1 // kSplitKSerial +""" +) + + +ARGS_PARSER_TEMPLATE = jinja2.Template( + """ + int64_t M = std::atoi(argv[1]); + int64_t N = std::atoi(argv[2]); + int64_t K = std::atoi(argv[3]); + int64_t split_k = std::atoi(argv[4]); + + int64_t a_dim0 = M; + int64_t a_dim1 = K; + int64_t b_dim0 = N; + int64_t b_dim1 = K; + int64_t c_dim0 = M; + int64_t c_dim1 = N; +""" +) + +# for profiler, no need to include TensorAccessor +PROFILER_PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + cutlass::gemm::GemmCoord{M, N, K}, + {({{elem_input_type}}*)a_ptr, LayoutA(K)}, + {({{elem_input_type}}*)b_ptr, LayoutB(K)}, + ref_B0, + nullptr_ref, // D0 + {({{elem_input_type}}*)bias_ptr, LayoutB(K)}, // B1 + ref_B1, + nullptr_ref, // D1 + {({{elem_output_type}}*)c_ptr, LayoutC(N)}, // D2 + {ElementCompute(1), ElementCompute(0)}, + {ElementCompute(1), ElementCompute(0)}, + {}, + 1 // kSplitKSerial +""" +) + + +@registry.reg("cuda.dual_gemm_rcr_silu.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + common_dual_gemm.make_fproc_f16(func_attrs, RCR) + + +def common_gen_profiler( + func_attrs, + workdir, + profiler_filename, + dim_info_dict, + src_template, + problem_args_template, + bias_ptr_arg=None, + extra_code="", +): + output_addr_calculator = common.DEFAULT_OUTPUT_ADDR_CALCULATOR.render( + stride_dim="*b_dim0" + ) + return common_dual_gemm.gen_profiler( + func_attrs, + workdir, + profiler_filename, + dim_info_dict, + src_template, + 
problem_args_template, + ARGS_PARSER_TEMPLATE, + emit_kernel=True, + support_split_k=True, + output_addr_calculator=output_addr_calculator, + bias_ptr_arg=bias_ptr_arg, + extra_code=extra_code, + ) + + +@registry.reg("cuda.dual_gemm_rcr_silu.gen_profiler") +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): + return common_gen_profiler( + func_attrs, + workdir, + profiler_filename, + dim_info_dict, + common_bias.SRC_TEMPLATE, + PROFILER_PROBLEM_ARGS_TEMPLATE, + bias_ptr_arg="memory_pool->RequestTensorByIdx(3)", + extra_code=common_dual_gemm.EXTRA_CODE.render(), + ) + + +@registry.reg("cuda.dual_gemm_rcr_silu.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, + problem_args_template=None, +): + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) + + if problem_args_template is None: + problem_args = PROBLEM_ARGS_TEMPLATE.render( + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, + ) + else: + problem_args = problem_args_template.render( + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, + ) + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + output_ndims = len(func_attrs["output_accessors"][0].original_shapes) + return common_dual_gemm.gen_function( + func_attrs, + common_bias.SRC_TEMPLATE, + exec_cond_template, + problem_args, + input_ndims, + weight_ndims, + output_ndims, + dim_info_dict, + emit_kernel=True, + support_split_k=True, + output_addr_calculator=common.OUTPUT_ADDR_CALCULATOR.render( + stride_dim="N", output_accessor=func_attrs["output_accessors"][0] + ), + extra_code=common_dual_gemm.EXTRA_CODE.render(), + ) + + +@registry.reg("cuda.dual_gemm_rcr_silu.func_decl") +def 
gen_function_decl(func_attrs): + func_name = func_attrs["name"] + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + return common_bias.FUNC_DECL_TEMPLATE.render( + func_name=func_name, + input_ndims=input_ndims, + weight_ndims=weight_ndims, + support_split_k=True, + ) + + +@registry.reg("cuda.dual_gemm_rcr_silu.func_call") +def gen_function_call(func_attrs, indent=" "): + bias = func_attrs["inputs"][2] + return common.gen_function_call( + func_attrs, indent, bias_ptr_arg=bias._attrs["name"] + ) + + +@registry.reg("cuda.dual_gemm_rcr_silu.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. + """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/gemm_rcr_softmax.py b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/gemm_rcr_softmax.py index eb3fcde49..45d69ac00 100644 --- a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/gemm_rcr_softmax.py +++ b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/gemm_rcr_softmax.py @@ -87,7 +87,7 @@ def common_gen_profiler( output_addr_calculator = common.DEFAULT_OUTPUT_ADDR_CALCULATOR.render( stride_dim="*b_dim0" ) - common_softmax.gen_profiler( + return common_softmax.gen_profiler( func_attrs, workdir, dim_info_dict, diff --git a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/include/gemm_with_softmax.h b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/include/gemm_with_softmax.h index 3b168b3d8..d5e7351a9 100644 --- a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/include/gemm_with_softmax.h +++ 
b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/include/gemm_with_softmax.h @@ -1,17 +1,3 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// namespace cutlass { template < diff --git a/python/aitemplate/backend/cuda/gemm_special/bmm_rcr_n1.py b/python/aitemplate/backend/cuda/gemm_special/bmm_rcr_n1.py index 5582ee24e..42e203069 100644 --- a/python/aitemplate/backend/cuda/gemm_special/bmm_rcr_n1.py +++ b/python/aitemplate/backend/cuda/gemm_special/bmm_rcr_n1.py @@ -43,9 +43,9 @@ FUNC_DECL_TEMPLATE = jinja2.Template( """ void {{func_name}}( - {{elem_input_type}}*, - {{elem_input_type}}*, - {{elem_input_type}}*, + void*, + void*, + void*, {% for i in range(3) %} int64_t*, {% endfor %} @@ -92,9 +92,9 @@ EXEC_TEMPLATE = jinja2.Template( """ {{indent}}bmm_rcr_n1_launcher<{{elem_input_type}}, {{read_vec_type}}, {{K}}>( -{{indent}} a_ptr, -{{indent}} b_ptr, -{{indent}} c_ptr, +{{indent}} ({{elem_input_type}}*)a_ptr, +{{indent}} ({{elem_input_type}}*)b_ptr, +{{indent}} ({{elem_input_type}}*)c_ptr, {{indent}} B, {{indent}} M, {{indent}} alpha, @@ -447,9 +447,9 @@ } // namespace void {{function_name}} ( - {{elem_input_type}}* a_ptr, - {{elem_input_type}}* b_ptr, - {{elem_input_type}}* c_ptr, + void* a_ptr, + void* b_ptr, + void* c_ptr, {% for i in range(3) %} int64_t *a_dim{{loop.index0}}, {% endfor %} @@ -496,8 +496,10 @@ def _get_original_dim_val(func_attrs, input_idx, dim): bk = 
_get_original_dim_val(func_attrs, 1, 2) assert ak == bk, f"ak is not equal to bk. ak: {ak}, bk: {bk}" - elem_input_type = "cutlass::half_t" backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_backend_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) vec_lens = list(zip(*backend_spec.read_num_elements_to_backend_type))[0][:-1] alignment = tensor_accessor_codegen.find_max_alignment( ak, func_attrs["input_accessors"] @@ -560,8 +562,17 @@ def _get_original_dim_val(func_attrs, input_idx, dim): @registry.reg("cuda.bmm_rcr_n1.func_decl") def gen_function_decl(func_attrs): func_name = func_attrs["name"] + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) return FUNC_DECL_TEMPLATE.render( - func_name=func_name, elem_input_type="cutlass::half_t" + func_name=func_name, + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, ) diff --git a/python/aitemplate/backend/cuda/gemm_special/bmm_rrr_k1_tanh.py b/python/aitemplate/backend/cuda/gemm_special/bmm_rrr_k1_tanh.py index de29a6ab7..eb5cfe109 100644 --- a/python/aitemplate/backend/cuda/gemm_special/bmm_rrr_k1_tanh.py +++ b/python/aitemplate/backend/cuda/gemm_special/bmm_rrr_k1_tanh.py @@ -23,6 +23,7 @@ import jinja2 from ... 
import registry +from ...backend_spec import CUDASpec from ...common import gemm_common from ..gemm_universal import common @@ -31,9 +32,9 @@ FUNC_DECL_TEMPLATE = jinja2.Template( """ void {{func_name}}( - cutlass::half_t*, - cutlass::half_t*, - cutlass::half_t*, + void*, + void*, + void*, {% for i in range(3) %} int64_t*, {% endfor %} @@ -71,10 +72,10 @@ EXEC_TEMPLATE = jinja2.Template( """ -{{indent}}bmm_rrr_k1_tanh_launcher( -{{indent}} a_ptr, -{{indent}} b_ptr, -{{indent}} c_ptr, +{{indent}}bmm_rrr_k1_tanh_launcher<{{elem_input_type}}>( +{{indent}} ({{elem_input_type}}*)a_ptr, +{{indent}} ({{elem_input_type}}*)b_ptr, +{{indent}} ({{elem_input_type}}*)c_ptr, {{indent}} B, {{indent}} M, {{indent}} N, @@ -86,6 +87,7 @@ SRC_TEMPLATE = jinja2.Template( """ +#include #include #include #include "cutlass/util/host_tensor.h" @@ -97,6 +99,10 @@ namespace { +template +__device__ T fast_tanh(T x); + +template <> __device__ half fast_tanh(half x) { #if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750) @@ -108,7 +114,7 @@ #endif } -template +template __global__ void bmm_rrr_k1_tanh_kernel(const float4* a_ptr, const float4* b_ptr, float4* c_ptr, @@ -116,58 +122,75 @@ const int M, const int N) { // TODO: check boundary - half tmp[64]; + constexpr int num_elems_in_float4 = sizeof(float4) / sizeof(ElemT); + ElemT tmp[num_elems_in_float4 * num_elems_in_float4]; int idx = blockIdx.x * num_thread + threadIdx.x; int m = idx % M; int b = idx / M; int a_idx_base = b * M + m; float4 a_vec = __ldg(a_ptr + a_idx_base); - half* a_vec_ptr = (half*)(&a_vec); + ElemT* a_vec_ptr = (ElemT*)(&a_vec); for (int n = 0; n < N; ++n) { int b_idx_base = b * N + n; float4 b_vec = __ldg(b_ptr + b_idx_base); - half* b_vec_ptr = (half*)(&b_vec); - for (int i = 0; i < 8; ++i) { + ElemT* b_vec_ptr = (ElemT*)(&b_vec); + for (int i = 0; i < num_elems_in_float4; ++i) { CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < 8; ++j) { - tmp[i * 8 + j] = fast_tanh(__hmul(a_vec_ptr[i], 
b_vec_ptr[j])); + for (int j = 0; j < num_elems_in_float4; ++j) { + tmp[i * num_elems_in_float4 + j] = fast_tanh(__hmul(a_vec_ptr[i], b_vec_ptr[j])); } } CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 8; ++i) { - int c_idx = (b * M * 8 + m * 8 + i) * N + n; - c_ptr[c_idx] = *((const float4*)(tmp + i * 8)); + for (int i = 0; i < num_elems_in_float4; ++i) { + int c_idx = (b * M * num_elems_in_float4 + m * num_elems_in_float4 + i) * N + n; + c_ptr[c_idx] = *((const float4*)(tmp + i * num_elems_in_float4)); } } } -void bmm_rrr_k1_tanh_launcher(cutlass::half_t* a_ptr, - cutlass::half_t* b_ptr, - cutlass::half_t* c_ptr, +template +void bmm_rrr_k1_tanh_launcher(ElemT* a_ptr, + ElemT* b_ptr, + ElemT* c_ptr, int B, int M, int N, cudaStream_t stream) { + constexpr int num_elems_in_float4 = sizeof(float4) / sizeof(ElemT); + if (M % num_elems_in_float4 != 0) { + auto msg = std::string("Got error: ") + std::to_string(M) + "%" + + std::to_string(num_elems_in_float4) + " != 0 " + + " at " + __FILE__ + ": " + std::to_string(__LINE__); + std::cerr << msg << std::endl; + throw std::runtime_error(msg); + } + if (N % num_elems_in_float4 != 0) { + auto msg = std::string("Got error: ") + std::to_string(N) + "%" + + std::to_string(num_elems_in_float4) + " != 0 " + + " at " + __FILE__ + ": " + std::to_string(__LINE__); + std::cerr << msg << std::endl; + throw std::runtime_error(msg); + } const int nthread = 256; dim3 thread_block(nthread); - dim3 grid(B * M / nthread / 8); - bmm_rrr_k1_tanh_kernel<<>>( + dim3 grid(B * M / nthread / num_elems_in_float4); + bmm_rrr_k1_tanh_kernel<<>>( (const float4*)a_ptr, (const float4*)b_ptr, (float4*) c_ptr, B, - M / 8, - N / 8 + M / num_elems_in_float4, + N / num_elems_in_float4 ); } } // namespace void {{function_name}} ( - cutlass::half_t* a_ptr, - cutlass::half_t* b_ptr, - cutlass::half_t* c_ptr, + void* a_ptr, + void* b_ptr, + void* c_ptr, {% for i in range(3) %} int64_t *a_dim{{loop.index0}}, {% endfor %} @@ -199,7 +222,11 @@ def 
gen_function(func_attrs, exec_cond_template, dim_info_dict): weight_ndims=3, output_ndims=3, ) - exec_paths = EXEC_TEMPLATE.render() + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_backend_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + exec_paths = EXEC_TEMPLATE.render(elem_input_type=elem_input_type) return SRC_TEMPLATE.render( function_name=func_name, shape_function=shape_func, diff --git a/python/aitemplate/backend/cuda/gemm_special/gemm_rrr_small_nk.py b/python/aitemplate/backend/cuda/gemm_special/gemm_rrr_small_nk.py index 81ed764e8..b53b74f37 100644 --- a/python/aitemplate/backend/cuda/gemm_special/gemm_rrr_small_nk.py +++ b/python/aitemplate/backend/cuda/gemm_special/gemm_rrr_small_nk.py @@ -28,6 +28,7 @@ import jinja2 from ... import registry +from ...backend_spec import CUDASpec from ...common import gemm_common from ...target import Target from ..gemm_universal import common @@ -38,9 +39,9 @@ FUNC_DECL_TEMPLATE = jinja2.Template( """ void {{func_name}}( - cutlass::half_t*, - cutlass::half_t*, - cutlass::half_t*, + void*, + void*, + void*, {% for i in range(a_ndim) %} int64_t*, {% endfor %} @@ -81,10 +82,10 @@ EXEC_TEMPLATE = jinja2.Template( """ -{{indent}}gemm_rrr_small_nk_launcher<{{N}}, {{K}}>( -{{indent}} a_ptr, -{{indent}} b_ptr, -{{indent}} c_ptr, +{{indent}}gemm_rrr_small_nk_launcher<{{elem_input_type}}, {{N}}, {{K}}>( +{{indent}} ({{elem_input_type}}*)a_ptr, +{{indent}} ({{elem_input_type}}*)b_ptr, +{{indent}} ({{elem_input_type}}*)c_ptr, {{indent}} M, {{indent}} use_fp16_acc, {{indent}} stream @@ -96,6 +97,8 @@ SRC_TEMPLATE = jinja2.Template( """ +#include +#include #include #include #include "cutlass/util/host_tensor.h" @@ -107,10 +110,8 @@ // B matrix: K x N // C tile: 8 x N template -__global__ void gemm_rrr_small_nk_kernel(float4* a_ptr, - float4* b_ptr, - float4* c_ptr, - int M) { +__global__ void gemm_rrr_small_nk_kernel_half( + float4* a_ptr, float4* b_ptr, float4* c_ptr, int M) { int idx = blockIdx.x * 
num_thread + threadIdx.x; if (idx >= (M + 7) / 8) { @@ -223,40 +224,48 @@ } // N <= 8, K <= 8 -template -void gemm_rrr_small_nk_launcher(cutlass::half_t* a_ptr, - cutlass::half_t* b_ptr, - cutlass::half_t* c_ptr, +template +void gemm_rrr_small_nk_launcher(ElemT* a_ptr, + ElemT* b_ptr, + ElemT* c_ptr, int M, bool use_fp16_acc, cudaStream_t stream) { + constexpr int num_elems_in_float4 = sizeof(float4) / sizeof(ElemT); const int nthread = 256; dim3 thread_block(nthread); - const int n_element_per_t = nthread * 8; + constexpr int n_element_per_t = nthread * num_elems_in_float4; dim3 grid((M + n_element_per_t - 1) / n_element_per_t); - if(use_fp16_acc) { - gemm_rrr_small_nk_kernel<<>>( - (float4*)a_ptr, - (float4*)b_ptr, - (float4*)c_ptr, - M - ); + if constexpr (std::is_same::value) { + if(use_fp16_acc) { + gemm_rrr_small_nk_kernel_half<<>>( + (float4*)a_ptr, + (float4*)b_ptr, + (float4*)c_ptr, + M + ); + } else { + gemm_rrr_small_nk_kernel_half<<>>( + (float4*)a_ptr, + (float4*)b_ptr, + (float4*)c_ptr, + M + ); + } } else { - gemm_rrr_small_nk_kernel<<>>( - (float4*)a_ptr, - (float4*)b_ptr, - (float4*)c_ptr, - M - ); + auto msg = std::string("Got error: unsupported elem type ") + + " at " + __FILE__ + ": " + std::to_string(__LINE__); + std::cerr << msg << std::endl; + throw std::runtime_error(msg); } } } // namespace void {{function_name}} ( - cutlass::half_t* a_ptr, - cutlass::half_t* b_ptr, - cutlass::half_t* c_ptr, + void* a_ptr, + void* b_ptr, + void* c_ptr, {% for i in range(a_ndim) %} int64_t *a_dim{{loop.index0}}, {% endfor %} @@ -299,11 +308,17 @@ def gen_function(func_attrs, exec_cond_template, dim_info_dict): weight_ndims=2, output_ndims=c_ndim, ) + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_backend_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) if n == 0 or k == 0: # avoid "zero-sized variable not allowed in device code" error exec_paths = "" else: - exec_paths = EXEC_TEMPLATE.render(indent=" ", N=n, K=k) + exec_paths = 
EXEC_TEMPLATE.render( + indent=" ", elem_input_type=elem_input_type, N=n, K=k + ) return SRC_TEMPLATE.render( function_name=func_name, shape_function=shape_func, diff --git a/python/aitemplate/backend/cuda/gemm_universal/__init__.py b/python/aitemplate/backend/cuda/gemm_universal/__init__.py index c07983128..9d04403bc 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/__init__.py +++ b/python/aitemplate/backend/cuda/gemm_universal/__init__.py @@ -42,6 +42,7 @@ gemm_rcr_bias_sigmoid_mul_tanh, gemm_rcr_bias_swish, gemm_rcr_bias_tanh, + gemm_rcr_fast_gelu, gemm_rcr_permute, gemm_rrr, gemm_rrr_permute, diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_ccr.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_ccr.py index 25ad9e9a8..b8e3fa6c1 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/bmm_ccr.py +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_ccr.py @@ -44,22 +44,30 @@ def _get_problem_info(**kwargs): @registry.reg("cuda.bmm_ccr.config") def bmm_ccr_config(func_attrs, dtype="float16"): - def fproc_f16(op): + def fproc(op): import cutlass_lib - return common.default_fproc_f16( + from ...backend_spec import CUDASpec + + backend_spec = CUDASpec() + elem_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + + return common.default_fproc( op=op, a_layout=cutlass_lib.library.LayoutType.ColumnMajor, b_layout=cutlass_lib.library.LayoutType.ColumnMajor, c_layout=cutlass_lib.library.LayoutType.RowMajor, + elem_type=elem_type, epiligue_name=func_attrs["epilogue"], ) - func_attrs["op_instance"] = common.extract_config(fproc_f16) + func_attrs["op_instance"] = common.extract_config(fproc) @registry.reg("cuda.bmm_ccr.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): a_dims = bmm_common.reverse_dim_info_mapping( dim_info_dict, gemm_common.Source.INPUT, 0 ) @@ -74,16 +82,21 @@ def gen_profiler(func_attrs, workdir, 
dim_info_dict): a_dims=a_dims, b_dims=b_dims, c_dims=c_dims ) - mm_info = _get_problem_info(alpha_value=func_attrs.get("alpha", 1)) + mm_info = _get_problem_info( + alpha_value=func_attrs.get("alpha", 1), + ) a_shapes = func_attrs["input_accessors"][0].original_shapes b_shapes = func_attrs["input_accessors"][1].original_shapes bmm_common._update_stride_info(mm_info, a_shapes, b_shapes) - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=mm_info, + ) - bmm_common.gen_profiler( + return bmm_common.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, common.SRC_TEMPLATE, problem_args, @@ -97,12 +110,16 @@ def gen_function( exec_cond_template, dim_info_dict, ): - mm_info = _get_problem_info(alpha_value=func_attrs.get("alpha", 1)) + mm_info = _get_problem_info( + alpha_value=func_attrs.get("alpha", 1), + ) a_shapes = func_attrs["input_accessors"][0].original_shapes b_shapes = func_attrs["input_accessors"][1].original_shapes bmm_common._update_stride_info(mm_info, a_shapes, b_shapes) - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=mm_info, + ) return bmm_common.gen_function( func_attrs, exec_cond_template, diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_ccr_add.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_ccr_add.py index ea9ff0510..fe8e605f0 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/bmm_ccr_add.py +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_ccr_add.py @@ -29,7 +29,7 @@ def bmm_ccr_add_config(func_attrs, dtype="float16"): @registry.reg("cuda.bmm_ccr_add.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): a_dims = bmm_common.reverse_dim_info_mapping( dim_info_dict, gemm_common.Source.INPUT, 0 ) @@ -45,7 +45,7 @@ def 
gen_profiler(func_attrs, workdir, dim_info_dict): ) mm_info = bmm_ccr._get_problem_info( - bias_ptr="d_ptr", + bias_ptr="(d_ptr)", alpha_value=func_attrs.get("alpha", 1), beta_value=1, ) @@ -54,11 +54,14 @@ def gen_profiler(func_attrs, workdir, dim_info_dict): d_shapes = func_attrs["input_accessors"][2].original_shapes bmm_common._update_stride_info(mm_info, a_shapes, b_shapes, d_shapes) - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=mm_info, + ) - bmm_common.gen_profiler( + return bmm_common.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, common.SRC_TEMPLATE, problem_args, @@ -73,14 +76,18 @@ def gen_function( dim_info_dict, ): mm_info = bmm_ccr._get_problem_info( - bias_ptr="d_ptr", alpha_value=func_attrs.get("alpha", 1), beta_value=1 + bias_ptr="(d_ptr)", + alpha_value=func_attrs.get("alpha", 1), + beta_value=1, ) a_shapes = func_attrs["input_accessors"][0].original_shapes b_shapes = func_attrs["input_accessors"][1].original_shapes d_shapes = func_attrs["input_accessors"][2].original_shapes bmm_common._update_stride_info(mm_info, a_shapes, b_shapes, d_shapes) - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=mm_info, + ) return bmm_common.gen_function( func_attrs, exec_cond_template, diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_common.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_common.py index 7b22806e3..6a00b0fc5 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/bmm_common.py +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_common.py @@ -19,6 +19,7 @@ import jinja2 +from ...backend_spec import CUDASpec from ...common import gemm_common from . 
import common @@ -55,12 +56,12 @@ FUNC_DECL_TEMPLATE = jinja2.Template( """ void {{func_name}}( - cutlass::half_t*, - cutlass::half_t*, + void*, + void*, {% if has_d %} - cutlass::half_t*, + void*, {% endif %} - cutlass::half_t*, + void*, uint8_t*, {% if support_split_k %} int, @@ -85,6 +86,9 @@ {{indent}}{ {{indent}}{{local_dim_defs}} {{indent}}{{func_name}}( +{% if is_profiler %} +{{indent}} gemm_op, +{% endif %} {{indent}} {{a_ptr}}, {{indent}} {{b_ptr}}, {% if has_d %} @@ -94,7 +98,7 @@ {{indent}} {{bias_ptr}}, {% endif %} {{indent}} {{c_ptr}}, -{{indent}} global_workspace, +{{indent}} global_workspace_, {% for dim in a_dims_ptr %} {{indent}} {{dim}}, {% endfor %} @@ -135,14 +139,14 @@ // need to tune it for other devices int64_t mem_pool_sz = std::max(2, std::min(64, int((1 << 25) / ptr_max_sz))); - memory_pool->AllocateHalfTensor(a_ptr_sz, mem_pool_sz); // a_ptr: index 0 - memory_pool->AllocateHalfTensor(b_ptr_sz, mem_pool_sz); // b_ptr: index 1 - memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // c_ptr: index 2 + memory_pool->AllocateTensor(a_ptr_sz, mem_pool_sz); // a_ptr: index 0 + memory_pool->AllocateTensor(b_ptr_sz, mem_pool_sz); // b_ptr: index 1 + memory_pool->AllocateTensor(c_ptr_sz, mem_pool_sz); // c_ptr: index 2 {% if has_bias %} - memory_pool->AllocateHalfTensor(c_dim2, mem_pool_sz); // bias_ptr: index 3 + memory_pool->AllocateTensor(c_dim2, mem_pool_sz); // bias_ptr: index 3 {% endif %} {% if has_d %} - memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // d_ptr: index 3 (no bias) or 4 + memory_pool->AllocateTensor(c_ptr_sz, mem_pool_sz); // d_ptr: index 3 (no bias) or 4 {% endif %} """ ) @@ -189,10 +193,10 @@ def _update_stride_info(mm_info, a_shapes, b_shapes, bias_shapes=None): {{mm_info.problem_size}}, {{mm_info.batch_size}}, {ElementComputeEpilogue({{mm_info.alpha_value}}), ElementComputeEpilogue({{mm_info.beta_value}})}, - (void*) {{mm_info.a_ptr}}, - (void*) {{mm_info.b_ptr}}, - (void*) {{mm_info.bias_ptr}}, - (void*) 
{{mm_info.c_ptr}}, + {{mm_info.a_ptr}}, + {{mm_info.b_ptr}}, + {{mm_info.bias_ptr}}, + {{mm_info.c_ptr}}, {{mm_info.a_batch_stride}}, {{mm_info.b_batch_stride}}, {{mm_info.bias_batch_stride}}, @@ -232,6 +236,7 @@ def _fill(arr, idx, val): def gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, src_template, problem_args, @@ -240,6 +245,10 @@ def gen_profiler( ): op_type = func_attrs["op"] op_instance = func_attrs["op_instance"] + backend_spec = CUDASpec() + elem_type = backend_spec.dtype_to_backend_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) has_d = False if "has_d" in func_attrs: has_d = func_attrs["has_d"] @@ -247,75 +256,114 @@ def gen_profiler( a_ndims = len(func_attrs["input_accessors"][0].original_shapes) b_ndims = len(func_attrs["input_accessors"][1].original_shapes) c_ndims = len(func_attrs["output_accessors"][0].original_shapes) + a_dims_ptr = [f"&a_dim{idx}" for idx in range(a_ndims)] + b_dims_ptr = [f"&b_dim{idx}" for idx in range(b_ndims)] + c_dims_ptr = [f"&c_dim{idx}" for idx in range(c_ndims)] shape_func = gemm_common.gen_shape_eval_code( indent=2, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True ) - file_pairs = [] has_bias = bias_ptr_arg is not None assert not (has_d and has_bias) - for op_name, op in op_instance.items(): + instance_name_base = "GemmInstance" + exec_program = common.EXEC_TEMPLATE.render( + indent=" ", + instance=instance_name_base, + is_profiler=True, + problem_args=problem_args, + ) + input_output_checks = common.INPUT_OUTPUT_CHECKS_TEMPLATE.render( + input_ndims=a_ndims, + weight_ndims=b_ndims, + output_ndims=c_ndims, + ) + + function_name = "bmm" + instances = [] + benchmark_instances = [] + for instance_idx, (op_name, op) in enumerate(op_instance.items()): config = common.emit_instance(op, for_profiler=True) config_name = common.extract_config_name(config) - name = "GemmInstance" + instance_name = f"{instance_name_base}_{instance_idx}" + gemm_op = f"gemm_op_{instance_idx}" instance = 
common.INSTANCE_TEMPLATE.render( - config_name=config_name, name=name, config=config + config_name=config_name, name=instance_name, config=config ) - exec_program = common.EXEC_TEMPLATE.render( + benchmark_instance = common.BENCHMARK_INSTANCE_TEMPLATE.render( indent=" ", - instance=name, - is_profiler=True, - problem_args=problem_args, - ) - input_output_checks = common.INPUT_OUTPUT_CHECKS_TEMPLATE.render( - input_ndims=a_ndims, - weight_ndims=b_ndims, - output_ndims=c_ndims, - ) - op_func = src_template.render( - instances=instance, - function_name="bmm", - input_ndims=a_ndims, - weight_ndims=b_ndims, - output_ndims=c_ndims, - shape_eval=shape_func, - input_output_checks=input_output_checks, - exec_paths=exec_program, - has_d=has_d, - ) - a_dims_ptr = [f"&a_dim{idx}" for idx in range(a_ndims)] - b_dims_ptr = [f"&b_dim{idx}" for idx in range(b_ndims)] - c_dims_ptr = [f"&c_dim{idx}" for idx in range(c_ndims)] - func_call = FUNC_CALL_TEMPLATE.render( - func_name="bmm", - a_ptr="memory_pool->RequestHalfTensorByIdx(0)", - b_ptr="memory_pool->RequestHalfTensorByIdx(1)", + instance_name=instance_name, + gemm_op=gemm_op, + gemm_op_name=op_name, + func_name=f"benchmark_{function_name}", + a_ptr="memory_pool->RequestTensorByIdx(0)", + b_ptr="memory_pool->RequestTensorByIdx(1)", has_bias=has_bias, bias_ptr=bias_ptr_arg, - c_ptr="memory_pool->RequestHalfTensorByIdx(2)", - d_ptr="memory_pool->RequestHalfTensorByIdx(%d)" % (4 if has_bias else 3), + c_ptr="memory_pool->RequestTensorByIdx(2)", + d_ptr="memory_pool->RequestTensorByIdx(%d)" % (4 if has_bias else 3), has_d=has_d, - a_dims_ptr=a_dims_ptr, - b_dims_ptr=b_dims_ptr, - c_dims_ptr=c_dims_ptr, - ) - code = common.PROFILER_TEMPLATE.render( - op_func=op_func, - args_parse=args_parser, - func_call=func_call, - name=name, - tensor_decl=TENSOR_DECL_TEMPLATE.render( - name=name, - a_ndims=a_ndims, - b_ndims=b_ndims, - c_ndims=c_ndims, - has_d=has_d, - has_bias=has_bias, - ), + adims=a_dims_ptr, + bdims=b_dims_ptr, + 
cdims=c_dims_ptr, ) - common.add_profiler(file_pairs, workdir, op_type, op_name, code) + instances.append(instance) + benchmark_instances.append(benchmark_instance) + op_func = src_template.render( + is_profiler=True, + instances="\n".join(instances), + function_name=function_name, + input_ndims=a_ndims, + weight_ndims=b_ndims, + output_ndims=c_ndims, + shape_eval=shape_func, + input_output_checks=input_output_checks, + exec_paths=exec_program, + has_d=has_d, + ) + benchmark_adims = [f"a_dim{idx}" for idx in range(a_ndims)] + benchmark_bdims = [f"b_dim{idx}" for idx in range(b_ndims)] + benchmark_cdims = [f"c_dim{idx}" for idx in range(c_ndims)] + func_call = FUNC_CALL_TEMPLATE.render( + is_profiler=True, + func_name=function_name, + a_ptr="a_ptr", + b_ptr="b_ptr", + has_bias=has_bias, + bias_ptr="bias_ptr", + c_ptr="c_ptr", + d_ptr="d_ptr", + has_d=has_d, + a_dims_ptr=benchmark_adims, + b_dims_ptr=benchmark_bdims, + c_dims_ptr=benchmark_cdims, + ) + code = common.PROFILER_TEMPLATE.render( + op_func=op_func, + has_bias=has_bias, + has_d=has_d, + args_parse=args_parser, + function_name=function_name, + func_call=func_call, + name=instance_name_base, + input_ndims=a_ndims, + weight_ndims=b_ndims, + output_ndims=c_ndims, + tensor_decl=TENSOR_DECL_TEMPLATE.render( + a_ndims=a_ndims, + b_ndims=b_ndims, + c_ndims=c_ndims, + has_d=has_d, + has_bias=has_bias, + ), + benchmark_instances="\n".join(benchmark_instances), + elem_type=elem_type, + ) + # FIXME: remove file_pairs once we have make -j ready for building + # an entire graph + file_pairs = [] + common.add_profiler(file_pairs, workdir, op_type, profiler_filename, code) # build - common.build_profiler(file_pairs) + return common.build_profiler(file_pairs) def gen_function_decl(func_attrs): diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_crr.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_crr.py index 62d6eee96..213234342 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/bmm_crr.py +++ 
b/python/aitemplate/backend/cuda/gemm_universal/bmm_crr.py @@ -46,22 +46,30 @@ def _get_problem_info(**kwargs): @registry.reg("cuda.bmm_crr.config") def bmm_crr_config(func_attrs, dtype="float16"): - def fproc_f16(op): + def fproc(op): import cutlass_lib - return common.default_fproc_f16( + from ...backend_spec import CUDASpec + + backend_spec = CUDASpec() + elem_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + + return common.default_fproc( op=op, a_layout=cutlass_lib.library.LayoutType.ColumnMajor, b_layout=cutlass_lib.library.LayoutType.RowMajor, c_layout=cutlass_lib.library.LayoutType.RowMajor, + elem_type=elem_type, epiligue_name=func_attrs["epilogue"], ) - func_attrs["op_instance"] = common.extract_config(fproc_f16) + func_attrs["op_instance"] = common.extract_config(fproc) @registry.reg("cuda.bmm_crr.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): a_dims = bmm_common.reverse_dim_info_mapping( dim_info_dict, gemm_common.Source.INPUT, 0 ) @@ -76,16 +84,21 @@ def gen_profiler(func_attrs, workdir, dim_info_dict): a_dims=a_dims, b_dims=b_dims, c_dims=c_dims ) - mm_info = _get_problem_info(alpha_value=func_attrs.get("alpha", 1)) + mm_info = _get_problem_info( + alpha_value=func_attrs.get("alpha", 1), + ) a_shapes = func_attrs["input_accessors"][0].original_shapes b_shapes = func_attrs["input_accessors"][1].original_shapes bmm_common._update_stride_info(mm_info, a_shapes, b_shapes) - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=mm_info, + ) - bmm_common.gen_profiler( + return bmm_common.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, common.SRC_TEMPLATE, problem_args, @@ -99,12 +112,16 @@ def gen_function( exec_cond_template, dim_info_dict, ): - mm_info = _get_problem_info(alpha_value=func_attrs.get("alpha", 1)) + 
mm_info = _get_problem_info( + alpha_value=func_attrs.get("alpha", 1), + ) a_shapes = func_attrs["input_accessors"][0].original_shapes b_shapes = func_attrs["input_accessors"][1].original_shapes bmm_common._update_stride_info(mm_info, a_shapes, b_shapes) - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=mm_info, + ) return bmm_common.gen_function( func_attrs, exec_cond_template, diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_crr_add.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_crr_add.py index 2767af9b0..ce62a6a1e 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/bmm_crr_add.py +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_crr_add.py @@ -31,7 +31,7 @@ def bmm_crr_add_config(func_attrs, dtype="float16"): @registry.reg("cuda.bmm_crr_add.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): a_dims = bmm_common.reverse_dim_info_mapping( dim_info_dict, gemm_common.Source.INPUT, 0 ) @@ -47,18 +47,23 @@ def gen_profiler(func_attrs, workdir, dim_info_dict): ) mm_info = bmm_crr._get_problem_info( - bias_ptr="d_ptr", alpha_value=func_attrs.get("alpha", 1), beta_value=1 + bias_ptr="d_ptr", + alpha_value=func_attrs.get("alpha", 1), + beta_value=1, ) a_shapes = func_attrs["input_accessors"][0].original_shapes b_shapes = func_attrs["input_accessors"][1].original_shapes d_shapes = func_attrs["input_accessors"][2].original_shapes bmm_common._update_stride_info(mm_info, a_shapes, b_shapes, d_shapes) - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=mm_info, + ) - bmm_common.gen_profiler( + return bmm_common.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, common.SRC_TEMPLATE, problem_args, @@ -73,14 +78,18 @@ def gen_function( dim_info_dict, ): 
mm_info = bmm_crr._get_problem_info( - bias_ptr="d_ptr", alpha_value=func_attrs.get("alpha", 1), beta_value=1 + bias_ptr="d_ptr", + alpha_value=func_attrs.get("alpha", 1), + beta_value=1, ) a_shapes = func_attrs["input_accessors"][0].original_shapes b_shapes = func_attrs["input_accessors"][1].original_shapes d_shapes = func_attrs["input_accessors"][2].original_shapes bmm_common._update_stride_info(mm_info, a_shapes, b_shapes, d_shapes) - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=mm_info, + ) return bmm_common.gen_function( func_attrs, exec_cond_template, diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_permute_common.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_permute_common.py index 582bfd38e..222522396 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/bmm_permute_common.py +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_permute_common.py @@ -15,6 +15,7 @@ """ Common functions and templates for bmm_permute-family ops """ +from ...backend_spec import CUDASpec from ...common import gemm_common from ..gemm_universal import common, common_bias @@ -26,6 +27,7 @@ def gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, src_template, problem_args, @@ -37,6 +39,10 @@ def gen_profiler( """Generate code for profiling""" op_type = func_attrs["op"] op_instance = func_attrs["op_instance"] + backend_spec = CUDASpec() + elem_type = backend_spec.dtype_to_backend_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) has_d = False if "has_d" in func_attrs: has_d = func_attrs["has_d"] @@ -44,14 +50,32 @@ def gen_profiler( a_ndims = len(func_attrs["input_accessors"][0].original_shapes) b_ndims = len(func_attrs["input_accessors"][1].original_shapes) c_ndims = len(func_attrs["output_accessors"][0].original_shapes) + a_dims_ptr = [f"&a_dim{idx}" for idx in range(a_ndims)] + b_dims_ptr = [f"&b_dim{idx}" for idx in range(b_ndims)] + 
c_dims_ptr = [f"&c_dim{idx}" for idx in range(c_ndims)] shape_func = gemm_common.gen_shape_eval_code( indent=2, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True ) - file_pairs = [] has_bias = bias_ptr_arg is not None assert not (has_d and has_bias) - for op_name, op in op_instance.items(): + instance_name_base = "GemmInstance" + exec_program = common.EXEC_TEMPLATE.render( + indent=" ", + instance=instance_name_base, + is_profiler=True, + problem_args=problem_args, + ) + input_output_checks = common.INPUT_OUTPUT_CHECKS_TEMPLATE.render( + input_ndims=a_ndims, + weight_ndims=b_ndims, + output_ndims=c_ndims, + ) + + function_name = "bmm" + instances = [] + benchmark_instances = [] + for instance_idx, (op_name, op) in enumerate(op_instance.items()): config = common_permute.emit_instance( op, for_profiler=True, @@ -59,66 +83,87 @@ def gen_profiler( func_attrs=func_attrs, ) config_name = common.extract_config_name(config) - name = "GemmInstance" + instance_name = f"{instance_name_base}_{instance_idx}" + gemm_op = f"gemm_op_{instance_idx}" instance = common.INSTANCE_TEMPLATE.render( - config_name=config_name, name=name, config=config + config_name=config_name, name=instance_name, config=config ) - exec_program = common.EXEC_TEMPLATE.render( + benchmark_instance = common.BENCHMARK_INSTANCE_TEMPLATE.render( indent=" ", - instance=name, - is_profiler=True, - problem_args=problem_args, - ) - input_output_checks = common.INPUT_OUTPUT_CHECKS_TEMPLATE.render( - input_ndims=a_ndims, - weight_ndims=b_ndims, - output_ndims=c_ndims, - ) - op_func = src_template.render( - instances=instance, - function_name="bmm", - input_ndims=a_ndims, - weight_ndims=b_ndims, - output_ndims=c_ndims, - shape_eval=shape_func, - input_output_checks=input_output_checks, - exec_paths=exec_program, - has_d=has_d, - extra_code=extra_code, - ) - a_dims_ptr = [f"&a_dim{idx}" for idx in range(a_ndims)] - b_dims_ptr = [f"&b_dim{idx}" for idx in range(b_ndims)] - c_dims_ptr = [f"&c_dim{idx}" for idx in 
range(c_ndims)] - func_call = bmm_common.FUNC_CALL_TEMPLATE.render( - func_name="bmm", - a_ptr="memory_pool->RequestHalfTensorByIdx(0)", - b_ptr="memory_pool->RequestHalfTensorByIdx(1)", + instance_name=instance_name, + gemm_op=gemm_op, + gemm_op_name=op_name, + func_name=f"benchmark_{function_name}", + a_ptr="memory_pool->RequestTensorByIdx(0)", + b_ptr="memory_pool->RequestTensorByIdx(1)", has_bias=has_bias, bias_ptr=bias_ptr_arg, - c_ptr="memory_pool->RequestHalfTensorByIdx(2)", - d_ptr="memory_pool->RequestHalfTensorByIdx(%d)" % (4 if has_bias else 3), + c_ptr="memory_pool->RequestTensorByIdx(2)", + d_ptr="memory_pool->RequestTensorByIdx(%d)" % (4 if has_bias else 3), has_d=has_d, - a_dims_ptr=a_dims_ptr, - b_dims_ptr=b_dims_ptr, - c_dims_ptr=c_dims_ptr, - ) - code = common.PROFILER_TEMPLATE.render( - op_func=op_func, - args_parse=args_parser, - func_call=func_call, - name=name, - tensor_decl=bmm_common.TENSOR_DECL_TEMPLATE.render( - name=name, - a_ndims=a_ndims, - b_ndims=b_ndims, - c_ndims=c_ndims, - has_d=has_d, - has_bias=has_bias, - ), + adims=a_dims_ptr, + bdims=b_dims_ptr, + cdims=c_dims_ptr, ) - common.add_profiler(file_pairs, workdir, op_type, op_name, code) + instances.append(instance) + benchmark_instances.append(benchmark_instance) + op_func = src_template.render( + is_profiler=True, + instances="\n".join(instances), + function_name=function_name, + input_ndims=a_ndims, + weight_ndims=b_ndims, + output_ndims=c_ndims, + shape_eval=shape_func, + input_output_checks=input_output_checks, + exec_paths=exec_program, + has_d=has_d, + extra_code=extra_code, + ) + benchmark_adims = [f"a_dim{idx}" for idx in range(a_ndims)] + benchmark_bdims = [f"b_dim{idx}" for idx in range(b_ndims)] + benchmark_cdims = [f"c_dim{idx}" for idx in range(c_ndims)] + func_call = bmm_common.FUNC_CALL_TEMPLATE.render( + is_profiler=True, + func_name=function_name, + a_ptr="a_ptr", + b_ptr="b_ptr", + has_bias=has_bias, + bias_ptr=bias_ptr_arg, + c_ptr="c_ptr", + d_ptr="d_ptr", + 
has_d=has_d, + a_dims_ptr=benchmark_adims, + b_dims_ptr=benchmark_bdims, + c_dims_ptr=benchmark_cdims, + ) + code = common.PROFILER_TEMPLATE.render( + op_func=op_func, + has_bias=has_bias, + has_d=has_d, + args_parse=args_parser, + function_name=function_name, + func_call=func_call, + name=instance_name_base, + input_ndims=a_ndims, + weight_ndims=b_ndims, + output_ndims=c_ndims, + tensor_decl=bmm_common.TENSOR_DECL_TEMPLATE.render( + a_ndims=a_ndims, + b_ndims=b_ndims, + c_ndims=c_ndims, + has_d=has_d, + has_bias=has_bias, + ), + benchmark_instances="\n".join(benchmark_instances), + elem_type=elem_type, + ) + # FIXME: remove file_pairs once we have make -j ready for building + # an entire graph + file_pairs = [] + common.add_profiler(file_pairs, workdir, op_type, profiler_filename, code) # build - common.build_profiler(file_pairs) + return common.build_profiler(file_pairs) def gen_function_decl(func_attrs): diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_rcr.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_rcr.py index d660f3c61..c8afa49aa 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/bmm_rcr.py +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_rcr.py @@ -19,6 +19,7 @@ """ from ... import registry +from ...backend_spec import CUDASpec from ...common import gemm_common from . 
import bmm_common, common from .layout import RCR @@ -47,11 +48,11 @@ def _get_default_problem_info(**kwargs): @registry.reg("cuda.bmm_rcr.config") def bmm_rcr_config(func_attrs, dtype="float16"): - common.make_fproc_f16(func_attrs, RCR) + common.make_fproc(func_attrs, RCR) @registry.reg("cuda.bmm_rcr.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): a_dims = bmm_common.reverse_dim_info_mapping( dim_info_dict, gemm_common.Source.INPUT, 0 ) @@ -66,16 +67,21 @@ def gen_profiler(func_attrs, workdir, dim_info_dict): a_dims=a_dims, b_dims=b_dims, c_dims=c_dims ) - mm_info = _get_default_problem_info(alpha_value=func_attrs.get("alpha", 1)) + mm_info = _get_default_problem_info( + alpha_value=func_attrs.get("alpha", 1), + ) a_shapes = func_attrs["input_accessors"][0].original_shapes b_shapes = func_attrs["input_accessors"][1].original_shapes bmm_common._update_stride_info(mm_info, a_shapes, b_shapes) - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=mm_info, + ) - bmm_common.gen_profiler( + return bmm_common.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, common.SRC_TEMPLATE, problem_args, @@ -89,6 +95,14 @@ def gen_function( exec_cond_template, dim_info_dict, ): + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) + input_a_batch_stride_dim = "M * K" input_a_stride_k_dim = "K" input_a_offset = 0 @@ -151,10 +165,10 @@ def gen_function( bmm_problem_info = bmm_common.Bmm_problem_info( alpha_value=func_attrs.get("alpha", 1), - a_ptr="(a_ptr + input_a_offset)", - b_ptr="(b_ptr + input_b_offset)", - bias_ptr="(c_ptr + output_offset)", - c_ptr="(c_ptr + output_offset)", + a_ptr="(" + elem_input_type + 
"*)(a_ptr) + input_a_offset", + b_ptr="(" + elem_input_type + "*)(b_ptr) + input_b_offset", + bias_ptr="(" + elem_output_type + "*)(c_ptr) + output_offset", + c_ptr="(" + elem_output_type + "*)(c_ptr) + output_offset", a_batch_stride="input_a_batch_stride", b_batch_stride="input_b_batch_stride", bias_batch_stride="output_batch_stride", @@ -168,7 +182,9 @@ def gen_function( b_shapes = func_attrs["input_accessors"][1].original_shapes bmm_common._update_stride_info(bmm_problem_info, a_shapes, b_shapes) - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=bmm_problem_info) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=bmm_problem_info, + ) return bmm_common.gen_function( func_attrs, diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_rcr_permute.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_rcr_permute.py index 2dc737be5..17574b62e 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/bmm_rcr_permute.py +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_rcr_permute.py @@ -19,6 +19,7 @@ """ from ... import registry +from ...backend_spec import CUDASpec from ...common import gemm_common from . 
import bmm_common, bmm_permute_common, common, common_permute @@ -27,23 +28,31 @@ @registry.reg("cuda.bmm_rcr_permute.config") def bmm_rcr_permute_config(func_attrs, dtype="float16"): - def fproc_f16(op): + def fproc(op): import cutlass_lib - return common_permute.default_fproc_f16( + from ...backend_spec import CUDASpec + + backend_spec = CUDASpec() + elem_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + + return common.default_fproc( op=op, a_layout=cutlass_lib.library.LayoutType.RowMajor, b_layout=cutlass_lib.library.LayoutType.ColumnMajor, c_layout=cutlass_lib.library.LayoutType.RowMajor, + elem_type=elem_type, epiligue_name=func_attrs["epilogue"], permute_layout=func_attrs["layout"], ) - func_attrs["op_instance"] = common_permute.extract_config(fproc_f16, func_attrs) + func_attrs["op_instance"] = common_permute.extract_config(fproc, func_attrs) @registry.reg("cuda.bmm_rcr_permute.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): a_dims = bmm_common.reverse_dim_info_mapping( dim_info_dict, gemm_common.Source.INPUT, 0 ) @@ -78,9 +87,10 @@ def gen_profiler(func_attrs, workdir, dim_info_dict): mm_info=bmm_problem_info, ) - bmm_permute_common.gen_profiler( + return bmm_permute_common.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, common.SRC_TEMPLATE, problem_args, @@ -96,6 +106,14 @@ def gen_function( exec_cond_template, dim_info_dict, ): + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) + input_a_batch_stride_dim = "M * K" input_a_stride_k_dim = "K" input_a_offset = 0 @@ -148,10 +166,10 @@ def gen_function( bmm_problem_info = bmm_common.Bmm_problem_info( alpha_value=func_attrs.get("alpha", 1), - a_ptr="(a_ptr + input_a_offset)", - 
b_ptr="(b_ptr + input_b_offset)", - bias_ptr="(c_ptr + output_offset)", - c_ptr="(c_ptr + output_offset)", + a_ptr="(" + elem_input_type + "*)(a_ptr) + input_a_offset", + b_ptr="(" + elem_input_type + "*)(b_ptr) + input_b_offset", + bias_ptr="(" + elem_output_type + "*)(c_ptr) + output_offset", + c_ptr="(" + elem_output_type + "*)(c_ptr) + output_offset", a_batch_stride="input_a_batch_stride", b_batch_stride="input_b_batch_stride", bias_batch_stride="output_batch_stride", diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_rrr.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_rrr.py index bc752b1bb..489059f31 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/bmm_rrr.py +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_rrr.py @@ -46,22 +46,30 @@ def _get_problem_info(**kwargs): @registry.reg("cuda.bmm_rrr.config") def bmm_rrr_config(func_attrs, dtype="float16"): - def fproc_f16(op): + def fproc(op): import cutlass_lib - return common.default_fproc_f16( + from ...backend_spec import CUDASpec + + backend_spec = CUDASpec() + elem_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + + return common.default_fproc( op=op, a_layout=cutlass_lib.library.LayoutType.RowMajor, b_layout=cutlass_lib.library.LayoutType.RowMajor, c_layout=cutlass_lib.library.LayoutType.RowMajor, + elem_type=elem_type, epiligue_name=func_attrs["epilogue"], ) - func_attrs["op_instance"] = common.extract_config(fproc_f16) + func_attrs["op_instance"] = common.extract_config(fproc) @registry.reg("cuda.bmm_rrr.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): a_dims = bmm_common.reverse_dim_info_mapping( dim_info_dict, gemm_common.Source.INPUT, 0 ) @@ -76,16 +84,21 @@ def gen_profiler(func_attrs, workdir, dim_info_dict): a_dims=a_dims, b_dims=b_dims, c_dims=c_dims ) - mm_info = _get_problem_info(alpha_value=func_attrs.get("alpha", 1)) + mm_info = 
_get_problem_info( + alpha_value=func_attrs.get("alpha", 1), + ) a_shapes = func_attrs["input_accessors"][0].original_shapes b_shapes = func_attrs["input_accessors"][1].original_shapes bmm_common._update_stride_info(mm_info, a_shapes, b_shapes) - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=mm_info, + ) - bmm_common.gen_profiler( + return bmm_common.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, common.SRC_TEMPLATE, problem_args, @@ -99,12 +112,16 @@ def gen_function( exec_cond_template, dim_info_dict, ): - mm_info = _get_problem_info(alpha_value=func_attrs.get("alpha", 1)) + mm_info = _get_problem_info( + alpha_value=func_attrs.get("alpha", 1), + ) a_shapes = func_attrs["input_accessors"][0].original_shapes b_shapes = func_attrs["input_accessors"][1].original_shapes bmm_common._update_stride_info(mm_info, a_shapes, b_shapes) - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=mm_info, + ) return bmm_common.gen_function( func_attrs, diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_rrr_add.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_rrr_add.py index bb8201291..44fbda070 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/bmm_rrr_add.py +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_rrr_add.py @@ -31,7 +31,7 @@ def bmm_rrr_add_config(func_attrs, dtype="float16"): @registry.reg("cuda.bmm_rrr_add.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): a_dims = bmm_common.reverse_dim_info_mapping( dim_info_dict, gemm_common.Source.INPUT, 0 ) @@ -47,18 +47,23 @@ def gen_profiler(func_attrs, workdir, dim_info_dict): ) mm_info = bmm_rrr._get_problem_info( - bias_ptr="d_ptr", alpha_value=func_attrs.get("alpha", 1), beta_value=1 + 
bias_ptr="d_ptr", + alpha_value=func_attrs.get("alpha", 1), + beta_value=1, ) a_shapes = func_attrs["input_accessors"][0].original_shapes b_shapes = func_attrs["input_accessors"][1].original_shapes d_shapes = func_attrs["input_accessors"][2].original_shapes bmm_common._update_stride_info(mm_info, a_shapes, b_shapes, d_shapes) - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=mm_info, + ) - bmm_common.gen_profiler( + return bmm_common.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, common.SRC_TEMPLATE, problem_args, @@ -73,14 +78,18 @@ def gen_function( dim_info_dict, ): mm_info = bmm_rrr._get_problem_info( - bias_ptr="d_ptr", alpha_value=func_attrs.get("alpha", 1), beta_value=1 + bias_ptr="d_ptr", + alpha_value=func_attrs.get("alpha", 1), + beta_value=1, ) a_shapes = func_attrs["input_accessors"][0].original_shapes b_shapes = func_attrs["input_accessors"][1].original_shapes d_shapes = func_attrs["input_accessors"][2].original_shapes bmm_common._update_stride_info(mm_info, a_shapes, b_shapes, d_shapes) - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=mm_info, + ) return bmm_common.gen_function( func_attrs, diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_rrr_permute.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_rrr_permute.py index d1d17ee8d..40a69bd28 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/bmm_rrr_permute.py +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_rrr_permute.py @@ -19,6 +19,7 @@ """ from ... import registry +from ...backend_spec import CUDASpec from ...common import gemm_common from . 
import bmm_common, bmm_permute_common, common, common_permute @@ -27,23 +28,31 @@ @registry.reg("cuda.bmm_rrr_permute.config") def bmm_rrr_permute_config(func_attrs, dtype="float16"): - def fproc_f16(op): + def fproc(op): import cutlass_lib - return common_permute.default_fproc_f16( + from ...backend_spec import CUDASpec + + backend_spec = CUDASpec() + elem_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + + return common.default_fproc( op=op, a_layout=cutlass_lib.library.LayoutType.RowMajor, b_layout=cutlass_lib.library.LayoutType.RowMajor, c_layout=cutlass_lib.library.LayoutType.RowMajor, + elem_type=elem_type, epiligue_name=func_attrs["epilogue"], permute_layout=func_attrs["layout"], ) - func_attrs["op_instance"] = common_permute.extract_config(fproc_f16, func_attrs) + func_attrs["op_instance"] = common_permute.extract_config(fproc, func_attrs) @registry.reg("cuda.bmm_rrr_permute.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): a_dims = bmm_common.reverse_dim_info_mapping( dim_info_dict, gemm_common.Source.INPUT, 0 ) @@ -78,9 +87,10 @@ def gen_profiler(func_attrs, workdir, dim_info_dict): mm_info=bmm_problem_info, ) - bmm_permute_common.gen_profiler( + return bmm_permute_common.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, common.SRC_TEMPLATE, problem_args, @@ -96,6 +106,14 @@ def gen_function( exec_cond_template, dim_info_dict, ): + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) + input_a_batch_stride_dim = "M * K" input_a_stride_k_dim = "K" input_a_offset = 0 @@ -158,10 +176,10 @@ def gen_function( bmm_problem_info = bmm_common.Bmm_problem_info( alpha_value=func_attrs.get("alpha", 1), - a_ptr="(a_ptr + input_a_offset)", - 
b_ptr="(b_ptr + input_b_offset)", - bias_ptr="(c_ptr + output_offset)", - c_ptr="(c_ptr + output_offset)", + a_ptr="(" + elem_input_type + "*)(a_ptr) + input_a_offset", + b_ptr="(" + elem_input_type + "*)(b_ptr) + input_b_offset", + bias_ptr="(" + elem_output_type + "*)(c_ptr) + output_offset", + c_ptr="(" + elem_output_type + "*)(c_ptr) + output_offset", a_batch_stride="input_a_batch_stride", b_batch_stride="input_b_batch_stride", bias_batch_stride="output_batch_stride", @@ -175,7 +193,9 @@ def gen_function( b_shapes = func_attrs["input_accessors"][1].original_shapes bmm_common._update_stride_info(bmm_problem_info, a_shapes, b_shapes) - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=bmm_problem_info) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=bmm_problem_info, + ) return bmm_permute_common.gen_function( func_attrs, diff --git a/python/aitemplate/backend/cuda/gemm_universal/common.py b/python/aitemplate/backend/cuda/gemm_universal/common.py index 199311035..9c18ab765 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/common.py +++ b/python/aitemplate/backend/cuda/gemm_universal/common.py @@ -27,7 +27,8 @@ from ....compiler.base import IntImm -from ... 
import builder +from ...backend_spec import CUDASpec + from ...common import gemm_common, tensor_accessor_codegen from ...target import Target @@ -153,13 +154,19 @@ {{instances}} +{% if is_profiler %} +template +void {{function_name}} ( + GemmInstance& gemm_op, +{% else %} void {{function_name}} ( - cutlass::half_t* a_ptr, - cutlass::half_t* b_ptr, +{% endif %} + void* a_ptr, + void* b_ptr, {% if has_d %} - cutlass::half_t* d_ptr, + void* d_ptr, {% endif %} - cutlass::half_t* c_ptr, + void* c_ptr, uint8_t* workspace, {% if support_split_k %} int split_k, @@ -211,13 +218,14 @@ {{problem_args}} {{indent}}}; -{{indent}}{{instance}} gemm_op; {% if is_profiler %} {{indent}}// https://www.youtube.com/watch?v=rRwxfYlgG-M {{indent}}size_t workspace_size = gemm_op.get_workspace_size(arguments); {{indent}}cutlass::device_memory::allocation local_workspace(workspace_size); {{indent}}workspace = local_workspace.get(); {{indent}}GLOBAL_WORKSPACE_SIZE = workspace_size; +{% else %} +{{indent}}{{instance}} gemm_op; {% endif %} {{indent}}auto status = gemm_op.can_implement(arguments); {{indent}}CUTLASS_CHECK(status); @@ -233,9 +241,9 @@ FUNC_DECL_TEMPLATE = jinja2.Template( """ void {{func_name}}( - cutlass::half_t*, - cutlass::half_t*, - cutlass::half_t*, + void*, + void*, + void*, uint8_t*, {% if support_split_k %} int, @@ -260,13 +268,16 @@ {{indent}}{ {{indent}}{{local_dim_defs}} {{indent}}{{func_name}}( +{% if is_profiler %} +{{indent}} gemm_op, +{% endif %} {{indent}} {{a_ptr}}, {{indent}} {{b_ptr}}, {% if has_bias %} {{indent}} {{bias_ptr}}, {% endif %} {{indent}} {{c_ptr}}, -{{indent}} global_workspace, +{{indent}} global_workspace_, {{indent}} {{split_k}}, {% for dim in adims %} {{indent}} {{dim}}, @@ -284,6 +295,53 @@ ) +BENCHMARK_INSTANCE_TEMPLATE = jinja2.Template( + """ +{{indent}}{ +{{indent}} +{{indent}}{{instance_name}} {{gemm_op}}; +{{indent}}const char *gemm_op_name = "{{gemm_op_name}}"; +{{indent}}int ret = 0; +{{indent}}try { +{{indent}}ret = {{func_name}}( 
+{{indent}} {{gemm_op}}, +{{indent}} gemm_op_name, +{{indent}} {{a_ptr}}, +{{indent}} {{b_ptr}}, +{% if has_bias %} +{{indent}} {{bias_ptr}}, +{% endif %} +{% if has_d %} +{{indent}} {{d_ptr}}, +{% endif %} +{% if has_d1 %} +{{indent}} {{d1_ptr}}, +{% endif %} +{{indent}} {{c_ptr}}, +{{indent}} global_workspace_, +{% if support_split_k %} +{{indent}} {{split_k}}, +{% endif %} +{% for dim in adims %} +{{indent}} {{dim}}, +{% endfor %} +{% for dim in bdims %} +{{indent}} {{dim}}, +{% endfor %} +{% for dim in cdims %} +{{indent}} {{dim}}, +{% endfor %} +{{indent}} stream +{{indent}}); +{{indent}}} catch (...) {} +{{indent}}if (ret != 0) +{{indent}} return ret; +{{indent}} +{{indent}}} +""" +) + + TENSOR_DECL_TEMPLATE = jinja2.Template( """ int64_t a_ptr_sz = a_dim0 * a_dim1; @@ -296,12 +354,12 @@ // need to tune it for other devices int64_t mem_pool_sz = std::max(2, std::min(64, int((1 << 25) / ptr_max_sz))); - memory_pool->AllocateHalfTensor(a_ptr_sz, mem_pool_sz); // a_ptr: index 0 - memory_pool->AllocateHalfTensor(b_ptr_sz, mem_pool_sz); // b_ptr: index 1 - memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // c_ptr: index 2 + memory_pool->AllocateTensor(a_ptr_sz, mem_pool_sz); // a_ptr: index 0 + memory_pool->AllocateTensor(b_ptr_sz, mem_pool_sz); // b_ptr: index 1 + memory_pool->AllocateTensor(c_ptr_sz, mem_pool_sz); // c_ptr: index 2 {% if has_bias %} - memory_pool->AllocateHalfTensor(c_dim1, mem_pool_sz); // bias_ptr: index 3 + memory_pool->AllocateTensor(c_dim1, mem_pool_sz); // bias_ptr: index 3 {% endif %} """ @@ -315,6 +373,95 @@ {{op_func}} +template +int benchmark_{{function_name}} ( +{% if is_group_gemm %} + GemmInstance &gemm_op, + const char *gemm_op_name, + int sharedMemPerMultiprocessor, + int multiProcessorCount, + uint8_t* global_workspace_, + int problem_count, + cutlass::gemm::GemmCoord* problem_sizes_device, + void **ptr_A, + void **ptr_B, + void **ptr_C, +{% if has_bias %} + void **ptr_bias, +{% endif %} + int64_t* lda, + int64_t* ldb, + 
int64_t* ldc, +{% if has_bias %} + int64_t* ldd, +{% endif %} + int occupancy, + cudaStream_t stream + +{% else %} + + GemmInstance &gemm_op, + const char *gemm_op_name, + void* a_ptr, + void* b_ptr, +{% if has_bias %} + void* bias_ptr, +{% endif %} +{% if has_d %} + void* d_ptr, +{% endif %} +{% if has_d1 %} + void* d1_ptr, +{% endif %} + void* c_ptr, + uint8_t* global_workspace_, +{% if support_split_k %} + int split_k, +{% endif %} +{% for idx in range(input_ndims) %} + int64_t* a_dim{{idx}}, +{% endfor %} +{% for idx in range(weight_ndims) %} + int64_t* b_dim{{idx}}, +{% endfor %} +{% for idx in range(output_ndims) %} + int64_t* c_dim{{idx}}, +{% endfor %} + cudaStream_t stream +{% endif %} + ) { + // warmup + for (int i = 0; i < 5; ++i) { + {{func_call}} + } + cudaEvent_t events[2]; + for (auto & event : events) { + cudaEventCreate(&event); + } + cudaEventRecord(events[0], stream); + for (int i = 0; i < 10; ++i) { + {{func_call}} + } + cudaEventRecord(events[1], stream); + cudaEventSynchronize(events[1]); + float runtime_ms = 0; + cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + for (auto event : events) { + (void)cudaEventDestroy(event); + } + // TODO: output workspace + if (runtime_ms < 0.00001) { + throw std::runtime_error( + "OOB in cutlass." 
+ ); + } + std::cout << "OP:" << gemm_op_name << ","; + std::cout << "TIME:" << runtime_ms << ","; + std::cout << "WS:" << GLOBAL_WORKSPACE_SIZE << std::endl; + return 0; +} + +template struct ProfilerMemoryPool { ProfilerMemoryPool() { std::random_device rd; @@ -328,7 +475,6 @@ } ~ProfilerMemoryPool() {} - template DType* AllocateGaussianTensor(int64_t size) { size_t length = size * sizeof(DType); blobs.emplace_back(length); @@ -345,25 +491,20 @@ } - cutlass::half_t* AllocateHalfGaussianTensor(int64_t size) { - return reinterpret_cast( - AllocateGaussianTensor<__half>(size)); - } - - int AllocateHalfTensor(int64_t size, int64_t copy) { + int AllocateTensor(int64_t size, int64_t copy) { offsets.push_back(0); strides.push_back(size); copies.push_back(copy); - auto ptr = AllocateHalfGaussianTensor(size * copy); + auto ptr = AllocateGaussianTensor(size * copy); ptrs.push_back(reinterpret_cast(ptr)); return ptrs.size() - 1; } - cutlass::half_t* RequestHalfTensorByIdx(int idx) { + DType* RequestTensorByIdx(int idx) { auto copy = copies.at(idx); auto offset = offsets.at(idx); auto stride = strides.at(idx); - cutlass::half_t* ptr = reinterpret_cast(ptrs.at(idx)); + DType* ptr = reinterpret_cast(ptrs.at(idx)); ptr += offset; offset += stride; if (offset == copy * stride) { @@ -387,7 +528,7 @@ int device_idx; cudaDeviceProp device_properties; cudaError_t result = cudaGetDevice(&device_idx); - auto memory_pool = std::make_unique(); + auto memory_pool = std::make_unique>(); if (result != cudaSuccess) { throw std::runtime_error("cudaGetDevice() API call failed."); } @@ -400,41 +541,12 @@ {{args_parse}} - using ElementOutput = typename {{name}}::ElementC; - using ElementInputA = typename {{name}}::ElementA; - using ElementInputB = typename {{name}}::ElementB; - uint8_t* global_workspace = nullptr; + uint8_t* global_workspace_ = nullptr; cudaStream_t stream = nullptr; {{tensor_decl}} - // warmup - for (int i = 0; i < 5; ++i) { - {{func_call}} - } - cudaEvent_t events[2]; - for 
(auto & event : events) { - cudaEventCreate(&event); - } - cudaEventRecord(events[0]); - for (int i = 0; i < 10; ++i) { - {{func_call}} - } - cudaEventRecord(events[1]); - cudaEventSynchronize(events[1]); - float runtime_ms = 0; - cudaEventElapsedTime(&runtime_ms, events[0], events[1]); - for (auto event : events) { - (void)cudaEventDestroy(event); - } - // TODO: output workspace - if (runtime_ms < 0.00001) { - throw std::runtime_error( - "OOB in cutlass." - ); - } - std::cout << "TIME:" << runtime_ms << std::endl; - std::cout << "WS:" << GLOBAL_WORKSPACE_SIZE << std::endl; + {{benchmark_instances}} return 0; } """ @@ -512,6 +624,11 @@ def update_alignments_in_gemm_instance( epilogue_alignment = tensor_accessor_codegen.find_max_alignment_for_accessor( output_accessor ) + + # if the last dim is dynamic, force align=1 + if not isinstance(output_accessor.original_shapes[-1], IntImm): + epilogue_alignment = 1 + gemm_params = get_gemm_instance_template_params(op_def, kernel_config) epilogue_align_idx = 11 a_align_idx = 17 @@ -592,7 +709,7 @@ def emit_instance( return op_def -def extract_config(f_proc_op): +def extract_config(f_proc_op, f_kernel_name=kernel_name): import cutlass_lib op_kind = cutlass_lib.library.OperationKind.Gemm @@ -606,7 +723,7 @@ def extract_config(f_proc_op): ret = f_proc_op(op) if len(ret) > 0: for op_inst in ret: - key = kernel_name(op_inst) + key = f_kernel_name(op_inst) gemm_ops[key] = op_inst return gemm_ops @@ -636,6 +753,13 @@ def gen_function( output_addr_calculator="", extra_code="", ): + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) func_name = func_attrs["name"] exec_path = func_attrs["exec_path"] op_instance = func_attrs["op_instance"] @@ -697,6 +821,8 @@ def gen_function( has_d=has_d(func_attrs), has_d1=has_d1(func_attrs), extra_code=extra_code, + 
elem_input_type=elem_input_type, + elem_output_type=elem_output_type, ) @@ -705,11 +831,10 @@ def build_profiler(file_pairs): if target.disable_profiler_codegen(): file_pairs = [] elif target.use_dummy_profiling_results(): - # if it is circle CI only random build 2 profiler + # if it is circle CI only random build 2 profilers random.shuffle(file_pairs) file_pairs = file_pairs[:2] - compile_engine = builder.Builder() - compile_engine.build_objs(file_pairs, target.compile_cmd(executable=True)) + return file_pairs def add_profiler(file_pairs, workdir, op_type, output_name, code): @@ -728,6 +853,7 @@ def add_profiler(file_pairs, workdir, op_type, output_name, code): def gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, src_template, problem_args_template, @@ -739,6 +865,16 @@ def gen_profiler( ): op_type = func_attrs["op"] op_instance = func_attrs["op_instance"] + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) + elem_type = backend_spec.dtype_to_backend_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) ndims = 2 adims = ["&a_dim" + str(i) for i in range(ndims)] bdims = ["&b_dim" + str(i) for i in range(ndims)] @@ -747,68 +883,117 @@ def gen_profiler( indent=2, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True ) - file_pairs = [] has_bias = bias_ptr_arg is not None - for op_name, op in op_instance.items(): + instance_name_base = "GemmInstance" + exec_program = EXEC_TEMPLATE.render( + indent=" ", + instance=instance_name_base, + is_profiler=True, + support_split_k=support_split_k, + problem_args=problem_args_template.render( + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, + ), + ) + input_output_checks = INPUT_OUTPUT_CHECKS_TEMPLATE.render( + input_ndims=ndims, + weight_ndims=ndims, + output_ndims=ndims, + ) + + function_name = "gemm" + 
instances = [] + benchmark_instances = [] + for instance_idx, (op_name, op) in enumerate(op_instance.items()): config = emit_instance(op, for_profiler=True) config_name = extract_config_name(config) - name = "GemmInstance" + instance_name = f"{instance_name_base}_{instance_idx}" + gemm_op = f"gemm_op_{instance_idx}" instance = INSTANCE_TEMPLATE.render( - config_name=config_name, name=name, config=config + config_name=config_name, name=instance_name, config=config ) - exec_program = EXEC_TEMPLATE.render( + benchmark_instance = BENCHMARK_INSTANCE_TEMPLATE.render( indent=" ", - instance=name, - is_profiler=True, - support_split_k=support_split_k, - problem_args=problem_args_template.render(), - ) - input_output_checks = INPUT_OUTPUT_CHECKS_TEMPLATE.render( - input_ndims=ndims, - weight_ndims=ndims, - output_ndims=ndims, - ) - op_func = src_template.render( - instances=instance, - function_name="gemm", - input_ndims=ndims, - weight_ndims=ndims, - output_ndims=ndims, - shape_eval=shape_func, - input_output_checks=input_output_checks, - exec_paths=exec_program, - output_addr_calculator=output_addr_calculator, - support_split_k=support_split_k, - extra_code=extra_code, - ) - func_call = FUNC_CALL_TEMPLATE.render( - func_name="gemm", - a_ptr="memory_pool->RequestHalfTensorByIdx(0)", - b_ptr="memory_pool->RequestHalfTensorByIdx(1)", + instance_name=instance_name, + gemm_op=gemm_op, + gemm_op_name=op_name, + func_name=f"benchmark_{function_name}", + a_ptr="memory_pool->RequestTensorByIdx(0)", + b_ptr="memory_pool->RequestTensorByIdx(1)", has_bias=has_bias, bias_ptr=bias_ptr_arg, - c_ptr="memory_pool->RequestHalfTensorByIdx(2)", + c_ptr="memory_pool->RequestTensorByIdx(2)", + support_split_k=support_split_k, split_k="split_k", adims=adims, bdims=bdims, cdims=cdims, ) - # TODO: Render args_parse by caller. 
- args_parse = ( - args_parser_template - if isinstance(args_parser_template, str) - else args_parser_template.render() - ) - code = PROFILER_TEMPLATE.render( - op_func=op_func, - args_parse=args_parse, - func_call=func_call, - name=name, - tensor_decl=TENSOR_DECL_TEMPLATE.render(name=name, has_bias=has_bias), - ) - add_profiler(file_pairs, workdir, op_type, op_name, code) + instances.append(instance) + benchmark_instances.append(benchmark_instance) + # TODO: Render args_parse by caller. + args_parse = ( + args_parser_template + if isinstance(args_parser_template, str) + else args_parser_template.render() + ) + op_func = src_template.render( + is_profiler=True, + instances="\n".join(instances), + function_name=function_name, + input_ndims=ndims, + weight_ndims=ndims, + output_ndims=ndims, + shape_eval=shape_func, + input_output_checks=input_output_checks, + exec_paths=exec_program, + output_addr_calculator=output_addr_calculator, + support_split_k=support_split_k, + extra_code=extra_code, + ) + benchmark_adims = ["a_dim" + str(i) for i in range(ndims)] + benchmark_bdims = ["b_dim" + str(i) for i in range(ndims)] + benchmark_cdims = ["c_dim" + str(i) for i in range(ndims)] + func_call = FUNC_CALL_TEMPLATE.render( + is_profiler=True, + func_name=function_name, + a_ptr="a_ptr", + b_ptr="b_ptr", + has_bias=has_bias, + bias_ptr="bias_ptr", + c_ptr="c_ptr", + split_k="split_k", + adims=benchmark_adims, + bdims=benchmark_bdims, + cdims=benchmark_cdims, + ) + tensor_decl = TENSOR_DECL_TEMPLATE.render( + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, + has_bias=has_bias, + ) + code = PROFILER_TEMPLATE.render( + op_func=op_func, + has_bias=has_bias, + has_d=has_d(func_attrs), + support_split_k=support_split_k, + args_parse=args_parse, + function_name=function_name, + input_ndims=ndims, + weight_ndims=ndims, + output_ndims=ndims, + func_call=func_call, + name=instance_name_base, + tensor_decl=tensor_decl, + 
benchmark_instances="\n".join(benchmark_instances), + elem_type=elem_type, + ) + # FIXME: remove file_pairs once we have make -j ready for building + # an entire graph + file_pairs = [] + add_profiler(file_pairs, workdir, op_type, profiler_filename, code) # build - build_profiler(file_pairs) + return build_profiler(file_pairs) def gen_local_dim_defs(func_attrs, indent=" "): @@ -864,22 +1049,24 @@ def gen_function_call(func_attrs, indent=" ", bias_ptr_arg=None): ) -def default_fproc_f16(*, op, a_layout, b_layout, c_layout, epiligue_name): +def default_fproc( + *, op, a_layout, b_layout, c_layout, elem_type, epiligue_name, permute_layout=None +): import copy import cutlass_lib ret = [] - data_type = cutlass_lib.library.DataType.f16 + data_type = elem_type acc_type = cutlass_lib.library.DataType.f32 # check target use fp16 acc - if "use_fp16_acc" in Target.current()._kwargs: + if "use_fp16_acc" in Target.current()._kwargs and data_type == "cutlass::half_t": if Target.current()._kwargs["use_fp16_acc"]: acc_type = cutlass_lib.library.DataType.f16 if ( - op.A.element == data_type - and op.B.element == data_type - and op.C.element == data_type + cutlass_lib.library.DataTypeTag[op.A.element] == data_type + and cutlass_lib.library.DataTypeTag[op.B.element] == data_type + and cutlass_lib.library.DataTypeTag[op.C.element] == data_type and op.accumulator_type() == acc_type and op.A.layout == a_layout and op.B.layout == b_layout @@ -890,6 +1077,10 @@ def default_fproc_f16(*, op, a_layout, b_layout, c_layout, epiligue_name): # set epilogue op.epilogue_functor = cutlass_lib.library.EpilogueFunctorName[epiligue_name] op.element_epilogue = acc_type + if permute_layout is not None: + op.permute_layout = cutlass_lib.library.EpiloguePermuteLayoutName[ + permute_layout + ] # set C alignment for i in [8, 4, 2, 1]: op = copy.deepcopy(op) @@ -898,23 +1089,27 @@ def default_fproc_f16(*, op, a_layout, b_layout, c_layout, epiligue_name): return ret -def make_fproc_f16(func_attrs, layout): 
+def make_fproc(func_attrs, layout): """ This function sets a callback for processing the epilogue of the kernel associated with func_attrs. """ - def fproc_f16(op): + backend_spec = CUDASpec() + elem_type = backend_spec.dtype_to_lib_type(func_attrs["inputs"][0]._attrs["dtype"]) + + def fproc(op): a_layout, b_layout, c_layout = layout.cutlass_lib_layouts() - return default_fproc_f16( + return default_fproc( op=op, a_layout=a_layout, b_layout=b_layout, c_layout=c_layout, + elem_type=elem_type, epiligue_name=func_attrs["epilogue"], ) - func_attrs["op_instance"] = extract_config(fproc_f16) + func_attrs["op_instance"] = extract_config(fproc) def function_filter(cfg, func_attrs, ab_alignment): diff --git a/python/aitemplate/backend/cuda/gemm_universal/common_bias.py b/python/aitemplate/backend/cuda/gemm_universal/common_bias.py index 98d8e979c..2d4e7f05a 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/common_bias.py +++ b/python/aitemplate/backend/cuda/gemm_universal/common_bias.py @@ -58,11 +58,17 @@ {{instances}} +{% if is_profiler %} +template void {{function_name}} ( - cutlass::half_t* a_ptr, - cutlass::half_t* b_ptr, - cutlass::half_t* bias_ptr, - cutlass::half_t* c_ptr, + GemmInstance& gemm_op, +{% else %} +void {{function_name}} ( +{% endif %} + void* a_ptr, + void* b_ptr, + void* bias_ptr, + void* c_ptr, uint8_t* workspace, {% if support_split_k %} int split_k, @@ -111,10 +117,10 @@ FUNC_DECL_TEMPLATE = jinja2.Template( """ void {{func_name}}( - cutlass::half_t*, - cutlass::half_t*, - cutlass::half_t*, - cutlass::half_t*, + void*, + void*, + void*, + void*, uint8_t*, {% if support_split_k %} int, diff --git a/python/aitemplate/backend/cuda/gemm_universal/common_bias_activation.py b/python/aitemplate/backend/cuda/gemm_universal/common_bias_activation.py index 843230243..bd7e437e4 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/common_bias_activation.py +++ b/python/aitemplate/backend/cuda/gemm_universal/common_bias_activation.py @@ -17,6 
+17,7 @@ Common codegen functions for gemm_bias_activation. """ +from ...backend_spec import CUDASpec from . import common, common_bias, gemm_rcr from .layout import RCR @@ -24,23 +25,25 @@ def gemm_rcr_config(func_attrs, dtype="float16"): - common.make_fproc_f16(func_attrs, RCR) + common.make_fproc(func_attrs, RCR) def gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, problem_args_template, extra_code="", ): - gemm_rcr.common_gen_profiler( + return gemm_rcr.common_gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, common_bias.SRC_TEMPLATE, problem_args_template, - bias_ptr_arg="memory_pool->RequestHalfTensorByIdx(3)", + bias_ptr_arg="memory_pool->RequestTensorByIdx(3)", extra_code=extra_code, ) @@ -55,7 +58,17 @@ def gen_function( input_ndims = len(func_attrs["input_accessors"][0].original_shapes) weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) output_ndims = len(func_attrs["output_accessors"][0].original_shapes) - problem_args = problem_args_template.render() + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) + problem_args = problem_args_template.render( + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, + ) return common.gen_function( func_attrs, common_bias.SRC_TEMPLATE, diff --git a/python/aitemplate/backend/cuda/gemm_universal/common_bias_broadcast.py b/python/aitemplate/backend/cuda/gemm_universal/common_bias_broadcast.py index 5c46b3cc5..42564bc0c 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/common_bias_broadcast.py +++ b/python/aitemplate/backend/cuda/gemm_universal/common_bias_broadcast.py @@ -22,6 +22,7 @@ import jinja2 +from ...backend_spec import CUDASpec from ...common import gemm_common from ...target import Target @@ -70,16 +71,16 @@ 1, {% endif %} {ElementComputeEpilogue(1), 
ElementComputeEpilogue(1)}, - (void*) (a_ptr + input_a_offset), - (void*) (b_ptr + input_b_offset), - (void*) d0_ptr, + ({{elem_input_type}}*)(a_ptr) + input_a_offset, + ({{elem_input_type}}*)(b_ptr) + input_b_offset, + ({{elem_output_type}}*)(d0_ptr), {% if has_d1 %} - (void*) d1_ptr, + ({{elem_output_type}}*)(d1_ptr), {% else %} nullptr, {% endif %} - (void*) (c_ptr + output_offset), - (void*) bias_ptr, + ({{elem_output_type}}*) (c_ptr) + output_offset, + ({{elem_input_type}}*) (bias_ptr), nullptr, /*batch_stride_A*/ input_a_batch_stride, /*batch_stride_B*/ input_b_batch_stride, @@ -113,16 +114,16 @@ 1, {% endif %} {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, - (void*) a_ptr, - (void*) b_ptr, - (void*) d0_ptr, + ({{elem_input_type}}*) a_ptr, + ({{elem_input_type}}*) b_ptr, + ({{elem_output_type}}*) d0_ptr, {% if has_d1 %} - (void*) d1_ptr, + ({{elem_output_type}}*) d1_ptr, {% else %} nullptr, {% endif %} - (void*) (c_ptr + output_offset), - (void*) bias_ptr, + ({{elem_output_type}}*) (c_ptr) + output_offset, + ({{elem_input_type}}*) bias_ptr, nullptr, /*batch_stride_A*/ 0, /*batch_stride_B*/ 0, @@ -173,15 +174,21 @@ {{instances}} +{% if is_profiler %} +template void {{function_name}} ( - cutlass::half_t* a_ptr, - cutlass::half_t* b_ptr, - cutlass::half_t* bias_ptr, - cutlass::half_t* d0_ptr, + GemmInstance& gemm_op, +{% else %} +void {{function_name}} ( +{% endif %} + void* a_ptr, + void* b_ptr, + void* bias_ptr, + void* d0_ptr, {% if has_d1 %} - cutlass::half_t* d1_ptr, + void* d1_ptr, {% endif %} - cutlass::half_t* c_ptr, + void* c_ptr, uint8_t* workspace, {% if support_split_k %} int split_k, @@ -229,14 +236,14 @@ FUNC_DECL_TEMPLATE = jinja2.Template( """ void {{func_name}}( - cutlass::half_t*, - cutlass::half_t*, - cutlass::half_t*, - cutlass::half_t*, + void*, + void*, + void*, + void*, {% if has_d1 %} - cutlass::half_t*, + void*, {% endif %} - cutlass::half_t*, + void*, uint8_t*, {% if support_split_k %} int, @@ -262,6 +269,9 @@ {{indent}}{ 
{{indent}}{{local_dim_defs}} {{indent}}{{func_name}}( +{% if is_profiler %} +{{indent}} gemm_op, +{% endif %} {{indent}} {{a_ptr}}, {{indent}} {{b_ptr}}, {{indent}} {{bias_ptr}}, @@ -270,7 +280,7 @@ {{indent}} {{d1_ptr}}, {% endif %} {{indent}} {{c_ptr}}, -{{indent}} global_workspace, +{{indent}} global_workspace_, {% if support_split_k %} {{indent}} {{split_k}}, {% endif %} @@ -313,13 +323,13 @@ // need to tune it for other devices int64_t mem_pool_sz = std::max(2, std::min(64, int((1 << 25) / ptr_max_sz))); - memory_pool->AllocateHalfTensor(a_ptr_sz, mem_pool_sz); // a_ptr: index 0 - memory_pool->AllocateHalfTensor(b_ptr_sz, mem_pool_sz); // b_ptr: index 1 - memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // c_ptr: index 2 - memory_pool->AllocateHalfTensor(c_dim1, mem_pool_sz); // bias_ptr: index 3 - memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // d0 ptr: index 4 + memory_pool->AllocateTensor(a_ptr_sz, mem_pool_sz); // a_ptr: index 0 + memory_pool->AllocateTensor(b_ptr_sz, mem_pool_sz); // b_ptr: index 1 + memory_pool->AllocateTensor(c_ptr_sz, mem_pool_sz); // c_ptr: index 2 + memory_pool->AllocateTensor(c_dim1, mem_pool_sz); // bias_ptr: index 3 + memory_pool->AllocateTensor(c_ptr_sz, mem_pool_sz); // d0 ptr: index 4 {% if has_d1 %} - memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // d1 ptr: index 5 + memory_pool->AllocateTensor(c_ptr_sz, mem_pool_sz); // d1 ptr: index 5 {% endif %} """ ) @@ -386,12 +396,13 @@ def gemm_bias_broadcast_instance( def gemm_bias_broadcast_config(func_attrs, layout, dtype="float16"): - common.make_fproc_f16(func_attrs, layout) + common.make_fproc(func_attrs, layout) def gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, layout, unary_op1, @@ -399,6 +410,16 @@ def gen_profiler( binary_op2, unary_op2, ): + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + elem_output_type = backend_spec.dtype_to_lib_type( + 
func_attrs["outputs"][0]._attrs["dtype"] + ) + elem_type = backend_spec.dtype_to_backend_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) op_type = func_attrs["op"] support_split_k = _support_split_k(func_attrs) op_instance = func_attrs["op_instance"] @@ -412,8 +433,29 @@ def gen_profiler( indent=2, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True ) - file_pairs = [] - for op_name, op in op_instance.items(): + instance_name_base = "GemmInstance" + exec_program = common.EXEC_TEMPLATE.render( + indent=" ", + instance=instance_name_base, + is_profiler=True, + problem_args=PROFILER_PROBLEM_ARGS_TEMPLATE.render( + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, + support_split_k=support_split_k, + layout=layout, + has_d1=has_d1, + ), + ) + input_output_checks = common.INPUT_OUTPUT_CHECKS_TEMPLATE.render( + input_ndims=ndims, + weight_ndims=ndims, + output_ndims=ndims, + ) + + function_name = "gemm" + instances = [] + benchmark_instances = [] + for instance_idx, (op_name, op) in enumerate(op_instance.items()): config = common.emit_instance( op, for_profiler=True, @@ -427,64 +469,95 @@ def gen_profiler( ), ) config_name = common.extract_config_name(config) - name = "GemmInstance" + instance_name = f"{instance_name_base}_{instance_idx}" + gemm_op = f"gemm_op_{instance_idx}" instance = common.INSTANCE_TEMPLATE.render( - config_name=config_name, name=name, config=config + config_name=config_name, name=instance_name, config=config ) - exec_program = common.EXEC_TEMPLATE.render( + benchmark_instance = common.BENCHMARK_INSTANCE_TEMPLATE.render( indent=" ", - instance=name, - is_profiler=True, - problem_args=PROFILER_PROBLEM_ARGS_TEMPLATE.render( - support_split_k=support_split_k, layout=layout, has_d1=has_d1 - ), - ) - input_output_checks = common.INPUT_OUTPUT_CHECKS_TEMPLATE.render( - input_ndims=ndims, - weight_ndims=ndims, - output_ndims=ndims, - ) - op_func = SRC_TEMPLATE.render( - instances=instance, - function_name="gemm", - 
input_ndims=ndims, - weight_ndims=ndims, - shape_eval=shape_func, - input_output_checks=input_output_checks, - exec_paths=exec_program, - output_addr_calculator=common.DEFAULT_OUTPUT_ADDR_CALCULATOR.render( - stride_dim="N" - ), - support_split_k=support_split_k, - has_d1=has_d1, - ) - func_call = FUNC_CALL_TEMPLATE.render( - func_name="gemm", - a_ptr="memory_pool->RequestHalfTensorByIdx(0)", - b_ptr="memory_pool->RequestHalfTensorByIdx(1)", - c_ptr="memory_pool->RequestHalfTensorByIdx(2)", - d0_ptr="memory_pool->RequestHalfTensorByIdx(4)", - d1_ptr="memory_pool->RequestHalfTensorByIdx(5)", - bias_ptr="memory_pool->RequestHalfTensorByIdx(3)", + instance_name=instance_name, + gemm_op=gemm_op, + gemm_op_name=op_name, + func_name=f"benchmark_{function_name}", + a_ptr="memory_pool->RequestTensorByIdx(0)", + b_ptr="memory_pool->RequestTensorByIdx(1)", + c_ptr="memory_pool->RequestTensorByIdx(2)", + d_ptr="memory_pool->RequestTensorByIdx(4)", + d1_ptr="memory_pool->RequestTensorByIdx(5)", + bias_ptr="memory_pool->RequestTensorByIdx(3)", adims=adims, bdims=bdims, cdims=cdims, support_split_k=support_split_k, split_k="split_k", + has_bias=True, + has_d=True, has_d1=has_d1, ) - code = common.PROFILER_TEMPLATE.render( - op_func=op_func, - args_parse=ARGS_PARSER_TEMPLATE.render( - layout=layout, support_split_k=support_split_k - ), - func_call=func_call, - name=name, - tensor_decl=TENSOR_DECL_TEMPLATE.render(name=name, has_d1=has_d1), - ) - common.add_profiler(file_pairs, workdir, op_type, op_name, code) + instances.append(instance) + benchmark_instances.append(benchmark_instance) + op_func = SRC_TEMPLATE.render( + is_profiler=True, + instances="\n".join(instances), + function_name=function_name, + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, + input_ndims=ndims, + weight_ndims=ndims, + shape_eval=shape_func, + input_output_checks=input_output_checks, + exec_paths=exec_program, + output_addr_calculator=common.DEFAULT_OUTPUT_ADDR_CALCULATOR.render( + 
stride_dim="N" + ), + support_split_k=support_split_k, + has_d1=has_d1, + ) + benchmark_adims = ["a_dim" + str(i) for i in range(ndims)] + benchmark_bdims = ["b_dim" + str(i) for i in range(ndims)] + benchmark_cdims = ["c_dim" + str(i) for i in range(ndims)] + func_call = FUNC_CALL_TEMPLATE.render( + is_profiler=True, + func_name="gemm", + a_ptr="a_ptr", + b_ptr="b_ptr", + c_ptr="c_ptr", + d0_ptr="d_ptr", + d1_ptr="d1_ptr", + bias_ptr="bias_ptr", + adims=benchmark_adims, + bdims=benchmark_bdims, + cdims=benchmark_cdims, + support_split_k=support_split_k, + split_k="split_k", + has_d1=has_d1, + ) + code = common.PROFILER_TEMPLATE.render( + op_func=op_func, + has_bias=True, + has_d=True, + has_d1=has_d1, + support_split_k=support_split_k, + args_parse=ARGS_PARSER_TEMPLATE.render( + layout=layout, support_split_k=support_split_k + ), + function_name=function_name, + input_ndims=ndims, + weight_ndims=ndims, + output_ndims=ndims, + func_call=func_call, + name=instance_name_base, + tensor_decl=TENSOR_DECL_TEMPLATE.render(has_d1=has_d1), + benchmark_instances="\n".join(benchmark_instances), + elem_type=elem_type, + ) + # FIXME: remove file_pairs once we have make -j ready for building + # an entire graph + file_pairs = [] + common.add_profiler(file_pairs, workdir, op_type, profiler_filename, code) # build - common.build_profiler(file_pairs) + return common.build_profiler(file_pairs) def gen_function( @@ -497,6 +570,13 @@ def gen_function( binary_op2, unary_op2, ): + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) input_addr_calculator = gemm_rcr.get_input_addr_calculator(func_attrs) input_ndims = len(func_attrs["input_accessors"][0].original_shapes) weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) @@ -504,7 +584,11 @@ def gen_function( support_split_k = 
_support_split_k(func_attrs) has_d1 = common.has_d1(func_attrs) problem_args = PROBLEM_ARGS_TEMPLATE.render( - layout=layout, support_split_k=support_split_k, has_d1=has_d1 + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, + layout=layout, + support_split_k=support_split_k, + has_d1=has_d1, ) return common.gen_function( func_attrs, diff --git a/python/aitemplate/backend/cuda/gemm_universal/common_no_bias.py b/python/aitemplate/backend/cuda/gemm_universal/common_no_bias.py new file mode 100644 index 000000000..8c1e80cc3 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/common_no_bias.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Common codegen functions for gemm_activation but use nullptr for bias. 
+""" + +import jinja2 + + +SRC_TEMPLATE = jinja2.Template( + """ +#include +#include +#include +#include +#include +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/epilogue/thread/linear_combination_silu.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/device_memory.h" + +{{extra_code}} + +#define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\ + cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \\ + std::cerr << msg << std::endl; \\ + throw std::runtime_error(msg); \\ + } \\ + } + +{{instances}} + +{% if is_profiler %} +template +void {{function_name}} ( + GemmInstance& gemm_op, +{% else %} +void {{function_name}} ( +{% endif %} + void* a_ptr, + void* b_ptr, + void* bias_ptr, + void* c_ptr, + uint8_t* workspace, +{% if support_split_k %} + int split_k, +{% endif %} +{% for idx in range(input_ndims) %} + int64_t* a_dim{{idx}}, +{% endfor %} +{% for idx in range(weight_ndims) %} + int64_t* b_dim{{idx}}, +{% endfor %} +{% for idx in range(input_ndims) %} + int64_t* c_dim{{idx}}, +{% endfor %} + cudaStream_t stream + ) { + {{shape_eval}} + {{input_addr_calculator}} + {{output_addr_calculator}} + {{extra_shape}} + {{input_output_checks}} + + if (bias_ptr) { + throw std::runtime_error("bias_ptr is not null!"); + } + + {{exec_paths}} + {% for idx in range(input_ndims) %} + std::cout << "input_ndims{{idx}}: " << *a_dim{{idx}} << std::endl; + {% endfor %} + {% for idx in range(weight_ndims) %} + std::cout << "weight_ndims{{idx}}: " << *b_dim{{idx}} << std::endl; + {% endfor %} + {% for idx in range(input_ndims) %} + std::cout << "output_ndims{{idx}}: " << *c_dim{{idx}} << std::endl; + {% endfor %} + throw std::runtime_error( + "Unsupported 
workload for this {{function_name}} specialization." + ); +} +""", + trim_blocks=True, + lstrip_blocks=True, +) diff --git a/python/aitemplate/backend/cuda/gemm_universal/common_permute.py b/python/aitemplate/backend/cuda/gemm_universal/common_permute.py index 2f3f1e903..378911608 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/common_permute.py +++ b/python/aitemplate/backend/cuda/gemm_universal/common_permute.py @@ -17,13 +17,13 @@ """ import re -from collections import OrderedDict +from functools import partial from hashlib import sha1 import jinja2 +from ...backend_spec import CUDASpec from ...common import gemm_common -from ...target import Target from ..gemm_universal import common # pylint: disable=C0301,C0415,R1705 @@ -63,6 +63,9 @@ def kernel_name(op, func_attrs): if len(shape) == 1: perm_type = "perm4d" perm_shape = f"{shape[0]}" + elif len(shape) == 2: + perm_type = "perm4d" + perm_shape = f"{shape[0]}_{shape[1]}" elif len(shape) == 3: perm_type = "perm5d" perm_shape = f"{shape[0]}_{shape[1]}_{shape[2]}" @@ -83,63 +86,8 @@ def kernel_name(op, func_attrs): return name.replace("\n", "") -def default_fproc_f16( - *, op, a_layout, b_layout, c_layout, epiligue_name, permute_layout -): - """Generates new op_instances by adding alignment info, permute_layout, etc.""" - import copy - - import cutlass_lib - - ret = [] - data_type = cutlass_lib.library.DataType.f16 - acc_type = cutlass_lib.library.DataType.f32 - # check target use fp16 acc - if "use_fp16_acc" in Target.current()._kwargs: - if Target.current()._kwargs["use_fp16_acc"]: - acc_type = cutlass_lib.library.DataType.f16 - if ( - op.A.element == data_type - and op.B.element == data_type - and op.C.element == data_type - and op.accumulator_type() == acc_type - and op.A.layout == a_layout - and op.B.layout == b_layout - ): - op = copy.deepcopy(op) - # set output major - op.C.layout = c_layout - # set epilogue - op.epilogue_functor = cutlass_lib.library.EpilogueFunctorName[epiligue_name] - 
op.element_epilogue = acc_type - op.permute_layout = cutlass_lib.library.EpiloguePermuteLayoutName[ - permute_layout - ] - # set C alignment - for i in [8, 4, 2, 1]: - op = copy.deepcopy(op) - op.C.alignment = i - ret.append(op) - return ret - - def extract_config(f_proc_op, func_attrs): - import cutlass_lib - - op_kind = cutlass_lib.library.OperationKind.Gemm - gemm_kind = cutlass_lib.library.GemmKind.Universal - gemm_ops = OrderedDict() - extract_ops = list(Target.current()._operators[op_kind].items()) - - for _, value in extract_ops: - op = value[0] - if op.gemm_kind == gemm_kind: - ret = f_proc_op(op) - if len(ret) > 0: - for op_inst in ret: - key = kernel_name(op_inst, func_attrs) - gemm_ops[key] = op_inst - return gemm_ops + return common.extract_config(f_proc_op, partial(kernel_name, func_attrs=func_attrs)) def gemm_permute_instance(op_def, func_attrs, for_profiler): @@ -262,6 +210,7 @@ def gen_function( def gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, src_template, problem_args_template, @@ -272,6 +221,16 @@ def gen_profiler( bias_ptr_arg=None, extra_code="", ): + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) + elem_type = backend_spec.dtype_to_backend_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) op_type = func_attrs["op"] op_instance = func_attrs["op_instance"] @@ -283,69 +242,109 @@ def gen_profiler( indent=2, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True ) - file_pairs = [] has_bias = bias_ptr_arg is not None - for op_name, op in op_instance.items(): + instance_name_base = "GemmInstance" + exec_program = common.EXEC_TEMPLATE.render( + indent=" ", + instance=instance_name_base, + is_profiler=True, + support_split_k=support_split_k, + problem_args=problem_args_template.render( + elem_input_type=elem_input_type, + 
elem_output_type=elem_output_type, + ), + ) + input_output_checks = common.INPUT_OUTPUT_CHECKS_TEMPLATE.render( + input_ndims=ndims, + weight_ndims=ndims, + output_ndims=ndims, + ) + + function_name = "gemm" + instances = [] + benchmark_instances = [] + for instance_idx, (op_name, op) in enumerate(op_instance.items()): config = emit_instance( op, for_profiler=True, emit_kernel=emit_kernel, func_attrs=func_attrs ) config_name = common.extract_config_name(config) - name = "GemmInstance" + instance_name = f"{instance_name_base}_{instance_idx}" + gemm_op = f"gemm_op_{instance_idx}" instance = common.INSTANCE_TEMPLATE.render( - config_name=config_name, name=name, config=config + config_name=config_name, name=instance_name, config=config ) - exec_program = common.EXEC_TEMPLATE.render( + benchmark_instance = common.BENCHMARK_INSTANCE_TEMPLATE.render( indent=" ", - instance=name, - is_profiler=True, - support_split_k=support_split_k, - problem_args=problem_args_template.render(), - ) - input_output_checks = common.INPUT_OUTPUT_CHECKS_TEMPLATE.render( - input_ndims=ndims, - weight_ndims=ndims, - output_ndims=ndims, - ) - op_func = src_template.render( - instances=instance, - function_name="gemm", - input_ndims=2, - weight_ndims=2, - output_ndims=2, - shape_eval=shape_func, - input_output_checks=input_output_checks, - exec_paths=exec_program, - output_addr_calculator=output_addr_calculator, - support_split_k=support_split_k, - extra_code=extra_code, - ) - func_call = common.FUNC_CALL_TEMPLATE.render( - func_name="gemm", - a_ptr="memory_pool->RequestHalfTensorByIdx(0)", - b_ptr="memory_pool->RequestHalfTensorByIdx(1)", + instance_name=instance_name, + gemm_op=gemm_op, + gemm_op_name=op_name, + func_name=f"benchmark_{function_name}", + a_ptr="memory_pool->RequestTensorByIdx(0)", + b_ptr="memory_pool->RequestTensorByIdx(1)", has_bias=has_bias, bias_ptr=bias_ptr_arg, - c_ptr="memory_pool->RequestHalfTensorByIdx(2)", + c_ptr="memory_pool->RequestTensorByIdx(2)", + 
support_split_k=support_split_k, split_k="split_k", adims=adims, bdims=bdims, cdims=cdims, ) - # TODO: Render args_parse by caller. - args_parse = ( - args_parser_template - if isinstance(args_parser_template, str) - else args_parser_template.render() - ) - code = common.PROFILER_TEMPLATE.render( - op_func=op_func, - args_parse=args_parse, - func_call=func_call, - name=name, - tensor_decl=common.TENSOR_DECL_TEMPLATE.render( - name=name, has_bias=has_bias - ), - ) - common.add_profiler(file_pairs, workdir, op_type, op_name, code) + instances.append(instance) + benchmark_instances.append(benchmark_instance) + op_func = src_template.render( + is_profiler=True, + instances="\n".join(instances), + function_name=function_name, + input_ndims=ndims, + weight_ndims=ndims, + output_ndims=ndims, + shape_eval=shape_func, + input_output_checks=input_output_checks, + exec_paths=exec_program, + output_addr_calculator=output_addr_calculator, + support_split_k=support_split_k, + extra_code=extra_code, + ) + benchmark_adims = ["a_dim" + str(i) for i in range(ndims)] + benchmark_bdims = ["b_dim" + str(i) for i in range(ndims)] + benchmark_cdims = ["c_dim" + str(i) for i in range(ndims)] + func_call = common.FUNC_CALL_TEMPLATE.render( + is_profiler=True, + func_name=function_name, + a_ptr="a_ptr", + b_ptr="b_ptr", + has_bias=has_bias, + bias_ptr="bias_ptr", + c_ptr="c_ptr", + split_k="split_k", + adims=benchmark_adims, + bdims=benchmark_bdims, + cdims=benchmark_cdims, + ) + # TODO: Render args_parse by caller. 
+ args_parse = ( + args_parser_template + if isinstance(args_parser_template, str) + else args_parser_template.render() + ) + code = common.PROFILER_TEMPLATE.render( + op_func=op_func, + has_bias=has_bias, + support_split_k=support_split_k, + args_parse=args_parse, + function_name=function_name, + input_ndims=ndims, + weight_ndims=ndims, + output_ndims=ndims, + func_call=func_call, + tensor_decl=common.TENSOR_DECL_TEMPLATE.render(has_bias=has_bias), + benchmark_instances="\n".join(benchmark_instances), + elem_type=elem_type, + ) + # FIXME: remove file_pairs once we have make -j ready for building + # an entire graph + file_pairs = [] + common.add_profiler(file_pairs, workdir, op_type, profiler_filename, code) # build - common.build_profiler(file_pairs) + return common.build_profiler(file_pairs) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr.py index 0fb211cb0..44c85125c 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr.py @@ -20,6 +20,8 @@ import jinja2 from ... import registry + +from ...backend_spec import CUDASpec from . 
import common from .layout import RCR @@ -49,10 +51,10 @@ {M, N, K}, split_k, {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}, - (void*) (a_ptr + input_a_offset), - (void*) (b_ptr + input_b_offset), - (void*) (c_ptr + output_offset), - (void*) (c_ptr + output_offset), + ({{elem_input_type}}*)(a_ptr) + input_a_offset, + ({{elem_input_type}}*)(b_ptr) + input_b_offset, + ({{elem_output_type}}*)(c_ptr) + output_offset, + ({{elem_output_type}}*)(c_ptr) + output_offset, input_a_batch_stride, input_b_batch_stride, /*output_batch_stride*/ M * N, @@ -72,10 +74,10 @@ {M, N, K}, split_k, {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}, - (void*) a_ptr, - (void*) b_ptr, - (void*) c_ptr, - (void*) (c_ptr + output_offset), + ({{elem_input_type}}*)(a_ptr), + ({{elem_input_type}}*)(b_ptr), + ({{elem_output_type}}*)(c_ptr), + ({{elem_output_type}}*)(c_ptr) + output_offset, M * K, N * K, M * N, @@ -90,12 +92,13 @@ @registry.reg("cuda.gemm_rcr.config") def gemm_rcr_config(func_attrs, dtype="float16"): - common.make_fproc_f16(func_attrs, RCR) + common.make_fproc(func_attrs, RCR) def common_gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, src_template, problem_args_template, @@ -105,9 +108,10 @@ def common_gen_profiler( output_addr_calculator = common.DEFAULT_OUTPUT_ADDR_CALCULATOR.render( stride_dim="*b_dim0" ) - common.gen_profiler( + return common.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, src_template, problem_args_template, @@ -120,10 +124,11 @@ def common_gen_profiler( @registry.reg("cuda.gemm_rcr.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): return common_gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, common.SRC_TEMPLATE, PROFILER_PROBLEM_ARGS_TEMPLATE, @@ -172,7 +177,17 @@ def gen_function( input_ndims = len(func_attrs["input_accessors"][0].original_shapes) weight_ndims = 
len(func_attrs["input_accessors"][1].original_shapes) output_ndims = len(func_attrs["output_accessors"][0].original_shapes) - problem_args = PROBLEM_ARGS_TEMPLATE.render() + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) + problem_args = PROBLEM_ARGS_TEMPLATE.render( + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, + ) return common.gen_function( func_attrs, common.SRC_TEMPLATE, diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias.py index f54c0ed2c..7c06c7408 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias.py @@ -20,6 +20,8 @@ import jinja2 from ... import registry + +from ...backend_spec import CUDASpec from . import common, common_bias, gemm_rcr # pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 @@ -32,10 +34,10 @@ {M, N, K}, split_k, {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, - (void*) (a_ptr + input_a_offset), - (void*) (b_ptr + input_b_offset), - (void*) bias_ptr, - (void*) (c_ptr + output_offset), + ({{elem_input_type}}*)(a_ptr) + input_a_offset, + ({{elem_input_type}}*)(b_ptr) + input_b_offset, + ({{elem_input_type}}*)(bias_ptr), + ({{elem_output_type}}*)(c_ptr) + output_offset, input_a_batch_stride, input_b_batch_stride, /*bias_batch_stride*/ N, @@ -55,10 +57,10 @@ {M, N, K}, split_k, {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, - (void*) a_ptr, - (void*) b_ptr, - (void*) bias_ptr, - (void*) (c_ptr + output_offset), + ({{elem_input_type}}*)(a_ptr), + ({{elem_input_type}}*)(b_ptr), + ({{elem_input_type}}*)(bias_ptr), + ({{elem_output_type}}*)(c_ptr) + output_offset, M * K, N * K, N, @@ -77,14 +79,15 @@ def gemm_rcr_config(func_attrs, dtype="float16"): 
@registry.reg("cuda.gemm_rcr_bias.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): - gemm_rcr.common_gen_profiler( +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): + return gemm_rcr.common_gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, common_bias.SRC_TEMPLATE, PROFILER_PROBLEM_ARGS_TEMPLATE, - bias_ptr_arg="memory_pool->RequestHalfTensorByIdx(3)", + bias_ptr_arg="memory_pool->RequestTensorByIdx(3)", ) @@ -98,7 +101,17 @@ def gen_function( input_ndims = len(func_attrs["input_accessors"][0].original_shapes) weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) output_ndims = len(func_attrs["output_accessors"][0].original_shapes) - problem_args = PROBLEM_ARGS_TEMPLATE.render() + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) + problem_args = PROBLEM_ARGS_TEMPLATE.render( + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, + ) return common.gen_function( func_attrs, common_bias.SRC_TEMPLATE, diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add.py index c2fc67191..c556485f1 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add.py @@ -36,10 +36,11 @@ def gemm_rcr_config(func_attrs, dtype="float16"): @registry.reg("cuda.gemm_rcr_bias_add.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): - common_bias_broadcast.gen_profiler( +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): + return common_bias_broadcast.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, RCR, UNARY_OP1, diff --git 
a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_add.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_add.py index 56511dbc1..bd2988abf 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_add.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_add.py @@ -36,10 +36,11 @@ def gemm_rcr_config(func_attrs, dtype="float16"): @registry.reg("cuda.gemm_rcr_bias_add_add.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): - common_bias_broadcast.gen_profiler( +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): + return common_bias_broadcast.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, RCR, UNARY_OP1, diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_add_relu.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_add_relu.py index f823baab2..5d262712e 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_add_relu.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_add_relu.py @@ -36,10 +36,11 @@ def gemm_rcr_config(func_attrs, dtype="float16"): @registry.reg("cuda.gemm_rcr_bias_add_add_relu.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): - common_bias_broadcast.gen_profiler( +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): + return common_bias_broadcast.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, RCR, UNARY_OP1, diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_relu.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_relu.py index bd4f7da4b..212b01a74 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_relu.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_add_relu.py @@ -36,10 +36,11 @@ def gemm_rcr_config(func_attrs, dtype="float16"): 
@registry.reg("cuda.gemm_rcr_bias_add_relu.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): - common_bias_broadcast.gen_profiler( +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): + return common_bias_broadcast.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, RCR, UNARY_OP1, diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_fast_gelu.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_fast_gelu.py index f55e21cd8..12af54f6a 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_fast_gelu.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_fast_gelu.py @@ -66,10 +66,10 @@ {M, N, K}, split_k, {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, - (void*) a_ptr, - (void*) b_ptr, - (void*) bias_ptr, - (void*) (c_ptr + output_offset), + ({{elem_input_type}}*)(a_ptr), + ({{elem_input_type}}*)(b_ptr), + ({{elem_input_type}}*)(bias_ptr), + ({{elem_output_type}}*)(c_ptr) + output_offset, M * K, N * K, N, @@ -88,10 +88,11 @@ def gemm_rcr_config(func_attrs, dtype="float16"): @registry.reg("cuda.gemm_rcr_bias_fast_gelu.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): return common_bias_activation.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, PROBLEM_ARGS_TEMPLATE, extra_code=EXTRA_CODE.render(), diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_gelu.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_gelu.py index d16d769a1..b4617b9d6 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_gelu.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_gelu.py @@ -30,10 +30,10 @@ {M, N, K}, split_k, {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, - (void*) a_ptr, - (void*) b_ptr, - (void*) bias_ptr, - (void*) (c_ptr + output_offset), + 
({{elem_input_type}}*)(a_ptr), + ({{elem_input_type}}*)(b_ptr), + ({{elem_input_type}}*)(bias_ptr), + ({{elem_output_type}}*)(c_ptr) + output_offset, M * K, N * K, N, @@ -52,10 +52,11 @@ def gemm_rcr_config(func_attrs, dtype="float16"): @registry.reg("cuda.gemm_rcr_bias_gelu.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): return common_bias_activation.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, PROBLEM_ARGS_TEMPLATE, ) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_hardswish.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_hardswish.py index 6c22e1e3a..a0952d345 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_hardswish.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_hardswish.py @@ -30,10 +30,10 @@ {M, N, K}, split_k, {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, - (void*) a_ptr, - (void*) b_ptr, - (void*) bias_ptr, - (void*) (c_ptr + output_offset), + ({{elem_input_type}}*)(a_ptr), + ({{elem_input_type}}*)(b_ptr), + ({{elem_input_type}}*)(bias_ptr), + ({{elem_output_type}}*)(c_ptr) + output_offset, M * K, N * K, N, @@ -52,10 +52,11 @@ def gemm_rcr_config(func_attrs, dtype="float16"): @registry.reg("cuda.gemm_rcr_bias_hardswish.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): return common_bias_activation.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, PROBLEM_ARGS_TEMPLATE, ) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul.py index f2049abef..1b2dea303 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul.py @@ -36,10 +36,11 @@ def 
gemm_rcr_config(func_attrs, dtype="float16"): @registry.reg("cuda.gemm_rcr_bias_mul.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): - common_bias_broadcast.gen_profiler( +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): + return common_bias_broadcast.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, RCR, UNARY_OP1, diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul_add.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul_add.py index 55400a029..12bce07ae 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul_add.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul_add.py @@ -36,10 +36,11 @@ def gemm_rcr_config(func_attrs, dtype="float16"): @registry.reg("cuda.gemm_rcr_bias_mul_add.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): - common_bias_broadcast.gen_profiler( +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): + return common_bias_broadcast.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, RCR, UNARY_OP1, diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul_tanh.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul_tanh.py index 3d5abf306..c8be43f28 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul_tanh.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_mul_tanh.py @@ -36,10 +36,11 @@ def gemm_rcr_config(func_attrs, dtype="float16"): @registry.reg("cuda.gemm_rcr_bias_mul_tanh.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): - common_bias_broadcast.gen_profiler( +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): + return common_bias_broadcast.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, RCR, UNARY_OP1, diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_permute.py 
b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_permute.py index 2a4c75cbe..6abdcc977 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_permute.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_permute.py @@ -17,6 +17,7 @@ """ from ... import registry +from ...backend_spec import CUDASpec from ..gemm_universal import common from . import common_bias, common_permute, gemm_rcr_bias, gemm_rcr_permute @@ -31,14 +32,15 @@ def gemm_rcr_bias_permute_config(func_attrs, dtype="float16"): @registry.reg("cuda.gemm_rcr_bias_permute.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): return gemm_rcr_permute.common_gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, common_bias.SRC_TEMPLATE, PROBLEM_ARGS_TEMPLATE, - bias_ptr_arg="memory_pool->RequestHalfTensorByIdx(3)", + bias_ptr_arg="memory_pool->RequestTensorByIdx(3)", extra_code=common_permute.EXTRA_CODE.render(), ) @@ -50,10 +52,23 @@ def gen_function( dim_info_dict, problem_args_template=None, ): + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) if problem_args_template is None: - problem_args = PROBLEM_ARGS_TEMPLATE.render() + problem_args = PROBLEM_ARGS_TEMPLATE.render( + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, + ) else: - problem_args = problem_args_template.render() + problem_args = problem_args_template.render( + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, + ) input_ndims = len(func_attrs["input_accessors"][0].original_shapes) weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) output_ndims = len(func_attrs["output_accessors"][0].original_shapes) diff --git 
a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_relu.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_relu.py index 3a5940e7a..eae96241c 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_relu.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_relu.py @@ -31,10 +31,10 @@ {M, N, K}, split_k, {ElementComputeEpilogue(1), ElementComputeEpilogue(1), ElementComputeEpilogue(0)}, - (void*) a_ptr, - (void*) b_ptr, - (void*) bias_ptr, - (void*) (c_ptr + output_offset), + ({{elem_input_type}}*)(a_ptr), + ({{elem_input_type}}*)(b_ptr), + ({{elem_input_type}}*)(bias_ptr), + ({{elem_output_type}}*)(c_ptr) + output_offset, M * K, N * K, N, @@ -53,10 +53,11 @@ def gemm_rcr_config(func_attrs, dtype="float16"): @registry.reg("cuda.gemm_rcr_bias_relu.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): return common_bias_activation.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, PROBLEM_ARGS_TEMPLATE, ) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid.py index 719efbfa2..e8ea6a976 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid.py @@ -31,10 +31,10 @@ {M, N, K}, split_k, {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, - (void*) a_ptr, - (void*) b_ptr, - (void*) bias_ptr, - (void*) (c_ptr + output_offset), + ({{elem_input_type}}*)(a_ptr), + ({{elem_input_type}}*)(b_ptr), + ({{elem_input_type}}*)(bias_ptr), + ({{elem_output_type}}*)(c_ptr) + output_offset, M * K, N * K, N, @@ -53,10 +53,11 @@ def gemm_rcr_config(func_attrs, dtype="float16"): @registry.reg("cuda.gemm_rcr_bias_sigmoid.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, 
workdir, profiler_filename, dim_info_dict): return common_bias_activation.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, PROBLEM_ARGS_TEMPLATE, ) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid_mul.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid_mul.py index b3b306f38..2828d379d 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid_mul.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid_mul.py @@ -36,10 +36,11 @@ def gemm_rcr_config(func_attrs, dtype="float16"): @registry.reg("cuda.gemm_rcr_bias_sigmoid_mul.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): - common_bias_broadcast.gen_profiler( +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): + return common_bias_broadcast.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, RCR, UNARY_OP1, diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid_mul_tanh.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid_mul_tanh.py index 66cad13c4..b3d721d6c 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid_mul_tanh.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid_mul_tanh.py @@ -36,10 +36,11 @@ def gemm_rcr_config(func_attrs, dtype="float16"): @registry.reg("cuda.gemm_rcr_bias_sigmoid_mul_tanh.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): - common_bias_broadcast.gen_profiler( +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): + return common_bias_broadcast.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, RCR, UNARY_OP1, diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_swish.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_swish.py index 688c9daf3..e4c082580 100644 --- 
a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_swish.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_swish.py @@ -31,10 +31,10 @@ {M, N, K}, split_k, {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, - (void*) a_ptr, - (void*) b_ptr, - (void*) bias_ptr, - (void*) (c_ptr + output_offset), + ({{elem_input_type}}*)(a_ptr), + ({{elem_input_type}}*)(b_ptr), + ({{elem_input_type}}*)(bias_ptr), + ({{elem_output_type}}*)(c_ptr) + output_offset, M * K, N * K, N, @@ -53,10 +53,11 @@ def gemm_rcr_config(func_attrs, dtype="float16"): @registry.reg("cuda.gemm_rcr_bias_swish.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): return common_bias_activation.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, PROBLEM_ARGS_TEMPLATE, ) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_tanh.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_tanh.py index 8a11c966f..934c9a1c0 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_tanh.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_tanh.py @@ -66,10 +66,10 @@ {M, N, K}, split_k, {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, - (void*) a_ptr, - (void*) b_ptr, - (void*) bias_ptr, - (void*) (c_ptr + output_offset), + ({{elem_input_type}}*)(a_ptr), + ({{elem_input_type}}*)(b_ptr), + ({{elem_input_type}}*)(bias_ptr), + ({{elem_output_type}}*)(c_ptr) + output_offset, M * K, N * K, N, @@ -88,10 +88,11 @@ def gemm_rcr_config(func_attrs, dtype="float16"): @registry.reg("cuda.gemm_rcr_bias_tanh.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): return common_bias_activation.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, PROBLEM_ARGS_TEMPLATE, extra_code=EXTRA_CODE.render(), diff --git 
a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_fast_gelu.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_fast_gelu.py new file mode 100644 index 000000000..791f3e300 --- /dev/null +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_fast_gelu.py @@ -0,0 +1,170 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +GEMM Specialization for C = fast_gelu(GeMM(A, B) + bias) +where A[RowMajor][M, K], B[ColMajor][N, K], bias[RowMajor][K], C[RowMajor][M, N] +""" +import jinja2 + +from ... import registry + +from ...backend_spec import CUDASpec +from . 
import common, common_bias_activation, common_no_bias + +# pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 + +EXTRA_CODE = jinja2.Template( + """ +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/constants.h" +#include "cutlass/complex.h" +#include "cutlass/array.h" +#include "cutlass/half.h" +#include "cutlass/functional.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" + +namespace cutlass { +namespace epilogue { +namespace thread { + +template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation + ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data to store + typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type + typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + ScaleType::Kind Scale = ScaleType::Default, ///< Control Alpha and Beta scaling + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest +> +using LinearCombinationFastGELU = LinearCombinationGeneric; + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +""" +) + + +PROBLEM_ARGS_TEMPLATE = jinja2.Template( + """ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + split_k, + {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}, + ({{elem_input_type}}*) a_ptr, + ({{elem_input_type}}*) b_ptr, + nullptr, + ({{elem_output_type}}*) (c_ptr) + output_offset, + M * K, + N * K, + N, + M * N, + K, + K, + 0, + output_stride +""" +) + + +@registry.reg("cuda.gemm_rcr_fast_gelu.config") +def gemm_rcr_config(func_attrs, dtype="float16"): + return common_bias_activation.gemm_rcr_config(func_attrs, dtype) + + +@registry.reg("cuda.gemm_rcr_fast_gelu.gen_profiler") +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): + return common_bias_activation.gen_profiler( + 
func_attrs, + workdir, + profiler_filename, + dim_info_dict, + PROBLEM_ARGS_TEMPLATE, + extra_code=EXTRA_CODE.render(), + ) + + +@registry.reg("cuda.gemm_rcr_fast_gelu.gen_function") +def gen_function( + func_attrs, + exec_cond_template, + dim_info_dict, +): + input_ndims = len(func_attrs["input_accessors"][0].original_shapes) + weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) + output_ndims = len(func_attrs["output_accessors"][0].original_shapes) + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) + problem_args = PROBLEM_ARGS_TEMPLATE.render( + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, + ) + return common.gen_function( + func_attrs, + common_no_bias.SRC_TEMPLATE, + exec_cond_template, + problem_args, + input_ndims, + weight_ndims, + output_ndims, + dim_info_dict, + support_split_k=True, + output_addr_calculator=common.OUTPUT_ADDR_CALCULATOR.render( + stride_dim="N", + output_accessor=func_attrs["output_accessors"][0], + ), + extra_code=EXTRA_CODE.render(), + ) + + +@registry.reg("cuda.gemm_rcr_fast_gelu.func_decl") +def gen_function_decl(func_attrs): + return common_bias_activation.gen_function_decl(func_attrs) + + +@registry.reg("cuda.gemm_rcr_fast_gelu.func_call") +def gen_function_call(func_attrs, indent=" "): + return common.gen_function_call(func_attrs, indent, bias_ptr_arg="nullptr") + + +@registry.reg("cuda.gemm_rcr_fast_gelu.filter") +def function_filter(cfg, func_attrs, ab_alignment): + """Generates function filter. + + Parameters + ---------- + cfg: str + The filename generated for profiler. + func_attrs : Dict + Stores the operation attributes. + ab_alignment: + Input alignments. + + Returns + ------- + bool + If input cfg should be filtered. 
+ """ + return common.function_filter(cfg, func_attrs, ab_alignment) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_permute.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_permute.py index f2851db12..b5f1cc9da 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_permute.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_permute.py @@ -20,8 +20,8 @@ import jinja2 from ... import registry -from ..gemm_universal import common -from . import common_permute +from ...backend_spec import CUDASpec +from . import common, common_permute # pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 @@ -49,10 +49,10 @@ {M, N, K}, split_k, {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}, - (void*) a_ptr, - (void*) b_ptr, - (void*) c_ptr, - (void*) (c_ptr + output_offset), + ({{elem_input_type}}*)(a_ptr), + ({{elem_input_type}}*)(b_ptr), + ({{elem_output_type}}*)(c_ptr), + ({{elem_output_type}}*)(c_ptr) + output_offset, M * K, N * K, M * N, @@ -67,24 +67,33 @@ @registry.reg("cuda.gemm_rcr_permute.config") def gemm_rcr_permute_config(func_attrs, dtype="float16"): - def fproc_f16(op): + def fproc(op): import cutlass_lib - return common_permute.default_fproc_f16( + from ...backend_spec import CUDASpec + + backend_spec = CUDASpec() + elem_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + + return common.default_fproc( op=op, a_layout=cutlass_lib.library.LayoutType.RowMajor, b_layout=cutlass_lib.library.LayoutType.ColumnMajor, c_layout=cutlass_lib.library.LayoutType.RowMajor, + elem_type=elem_type, epiligue_name=func_attrs["epilogue"], permute_layout=func_attrs["layout"], ) - func_attrs["op_instance"] = common_permute.extract_config(fproc_f16, func_attrs) + func_attrs["op_instance"] = common_permute.extract_config(fproc, func_attrs) def common_gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, src_template, problem_args_template, @@ -94,9 +103,10 @@ def 
common_gen_profiler( output_addr_calculator = common.DEFAULT_OUTPUT_ADDR_CALCULATOR.render( stride_dim="*b_dim0" ) - common_permute.gen_profiler( + return common_permute.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, src_template, problem_args_template, @@ -110,10 +120,11 @@ def common_gen_profiler( @registry.reg("cuda.gemm_rcr_permute.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): return common_gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, common.SRC_TEMPLATE, PROBLEM_ARGS_TEMPLATE, @@ -128,10 +139,24 @@ def gen_function( dim_info_dict, problem_args_template=None, ): + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) + if problem_args_template is None: - problem_args = PROBLEM_ARGS_TEMPLATE.render() + problem_args = PROBLEM_ARGS_TEMPLATE.render( + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, + ) else: - problem_args = problem_args_template.render() + problem_args = problem_args_template.render( + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, + ) input_ndims = len(func_attrs["input_accessors"][0].original_shapes) weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) output_ndims = len(func_attrs["output_accessors"][0].original_shapes) diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr.py index 0a3d109d6..90654c06f 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr.py @@ -20,6 +20,8 @@ import jinja2 from ... import registry + +from ...backend_spec import CUDASpec from . 
import common # pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 @@ -47,10 +49,10 @@ {M, N, K}, split_k, {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}, - (void*) a_ptr, - (void*) b_ptr, - (void*) c_ptr, - (void*) (c_ptr + output_offset), + ({{elem_input_type}}*)(a_ptr), + ({{elem_input_type}}*)(b_ptr), + ({{elem_output_type}}*)(c_ptr), + ({{elem_output_type}}*)(c_ptr) + output_offset, M * K, N * K, M * N, @@ -65,28 +67,37 @@ @registry.reg("cuda.gemm_rrr.config") def gemm_rrr_config(func_attrs, dtype="float16"): - def fproc_f16(op): + def fproc(op): import cutlass_lib - return common.default_fproc_f16( + from ...backend_spec import CUDASpec + + backend_spec = CUDASpec() + elem_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + + return common.default_fproc( op=op, a_layout=cutlass_lib.library.LayoutType.RowMajor, b_layout=cutlass_lib.library.LayoutType.RowMajor, c_layout=cutlass_lib.library.LayoutType.RowMajor, + elem_type=elem_type, epiligue_name=func_attrs["epilogue"], ) - func_attrs["op_instance"] = common.extract_config(fproc_f16) + func_attrs["op_instance"] = common.extract_config(fproc) @registry.reg("cuda.gemm_rrr.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): output_addr_calculator = common.DEFAULT_OUTPUT_ADDR_CALCULATOR.render( stride_dim="N" ) - common.gen_profiler( + return common.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, common.SRC_TEMPLATE, PROBLEM_ARGS_TEMPLATE, @@ -105,7 +116,17 @@ def gen_function( input_ndims = len(func_attrs["input_accessors"][0].original_shapes) weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) output_ndims = len(func_attrs["output_accessors"][0].original_shapes) - problem_args = PROBLEM_ARGS_TEMPLATE.render() + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + 
elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) + problem_args = PROBLEM_ARGS_TEMPLATE.render( + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, + ) return common.gen_function( func_attrs, common.SRC_TEMPLATE, diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr_permute.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr_permute.py index 8653efab1..4b7ced1ea 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr_permute.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr_permute.py @@ -20,8 +20,9 @@ import jinja2 from ... import registry -from ..gemm_universal import common -from . import common_permute + +from ...backend_spec import CUDASpec +from . import common, common_permute # pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 @@ -49,10 +50,10 @@ {M, N, K}, split_k, {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}, - (void*) a_ptr, - (void*) b_ptr, - (void*) c_ptr, - (void*) (c_ptr + output_offset), + ({{elem_input_type}}*)(a_ptr), + ({{elem_input_type}}*)(b_ptr), + ({{elem_output_type}}*)(c_ptr), + ({{elem_output_type}}*)(c_ptr) + output_offset, M * K, N * K, M * N, @@ -67,24 +68,33 @@ @registry.reg("cuda.gemm_rrr_permute.config") def gemm_rrr_permute_config(func_attrs, dtype="float16"): - def fproc_f16(op): + def fproc(op): import cutlass_lib - return common_permute.default_fproc_f16( + from ...backend_spec import CUDASpec + + backend_spec = CUDASpec() + elem_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + + return common.default_fproc( op=op, a_layout=cutlass_lib.library.LayoutType.RowMajor, b_layout=cutlass_lib.library.LayoutType.RowMajor, c_layout=cutlass_lib.library.LayoutType.RowMajor, + elem_type=elem_type, epiligue_name=func_attrs["epilogue"], permute_layout=func_attrs["layout"], ) - func_attrs["op_instance"] = common_permute.extract_config(fproc_f16, func_attrs) + 
func_attrs["op_instance"] = common_permute.extract_config(fproc, func_attrs) def common_gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, src_template, problem_args_template, @@ -94,9 +104,10 @@ def common_gen_profiler( output_addr_calculator = common.DEFAULT_OUTPUT_ADDR_CALCULATOR.render( stride_dim="N" ) - common_permute.gen_profiler( + return common_permute.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, src_template, problem_args_template, @@ -110,10 +121,11 @@ def common_gen_profiler( @registry.reg("cuda.gemm_rrr_permute.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): return common_gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, common.SRC_TEMPLATE, PROBLEM_ARGS_TEMPLATE, @@ -128,10 +140,23 @@ def gen_function( dim_info_dict, problem_args_template=None, ): + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) if problem_args_template is None: - problem_args = PROBLEM_ARGS_TEMPLATE.render() + problem_args = PROBLEM_ARGS_TEMPLATE.render( + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, + ) else: - problem_args = problem_args_template.render() + problem_args = problem_args_template.render( + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, + ) input_ndims = len(func_attrs["input_accessors"][0].original_shapes) weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) diff --git a/python/aitemplate/backend/cuda/gemm_universal/group_common.py b/python/aitemplate/backend/cuda/gemm_universal/group_common.py index 6568b3c4f..1185ab1ab 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/group_common.py +++ b/python/aitemplate/backend/cuda/gemm_universal/group_common.py @@ -21,6 
+21,7 @@ import jinja2 +from ...backend_spec import CUDASpec from ...common import tensor_accessor_codegen from . import common @@ -77,13 +78,13 @@ {{indent}} int, {{indent}} int64_t*, {{indent}} int, -{{indent}} cutlass::half_t*, +{{indent}} void*, {% for i in range(groups) %} -{{indent}} cutlass::half_t*, -{{indent}} cutlass::half_t*, -{{indent}} cutlass::half_t*, +{{indent}} void*, +{{indent}} void*, +{{indent}} void*, {% if has_bias %} -{{indent}} cutlass::half_t*, +{{indent}} void*, {% endif %} {% endfor %} {{indent}} uint8_t*, @@ -104,8 +105,11 @@ FUNC_CALL_TEMPLATE = jinja2.Template( """ {{indent}}{{func_name}}( -{{indent}} device_properties.sharedMemPerMultiprocessor, -{{indent}} device_properties.multiProcessorCount, +{% if is_profiler %} +{{indent}} gemm_op, +{% endif %} +{{indent}} device_properties_.sharedMemPerMultiprocessor, +{{indent}} device_properties_.multiProcessorCount, {{indent}} &{{func_name}}_state, {{indent}} {{problem_count}}, {{indent}} {{device_args}}, @@ -117,7 +121,7 @@ {{indent}} {{operand[3]}}, {% endif %} {% endfor %} -{{indent}} global_workspace, +{{indent}} global_workspace_, {% for operand_dim in group_operand_dims %} {{indent}} {{operand_dim[0]}}, {{indent}} {{operand_dim[1]}}, @@ -160,22 +164,25 @@ } \\ } -{{instance}} +{{instances}} {% endif %} -{{indent}}template +{{indent}}template {{indent}}void {{func_name}}_adapter( +{%if is_profiler %} + GemmInstance& gemm_op, +{% endif %} int sharedMemPerMultiprocessor, int multiProcessorCount, uint8_t* workspace, int problem_count, cutlass::gemm::GemmCoord* problem_sizes_device, - cutlass::half_t **ptr_A, - cutlass::half_t **ptr_B, - cutlass::half_t **ptr_C, + void **ptr_A, + void **ptr_B, + void **ptr_C, {% if has_bias %} - cutlass::half_t **ptr_bias, + void **ptr_bias, {% endif %} int64_t* lda, int64_t* ldb, @@ -199,6 +206,9 @@ ADAPTER_CALL_TEMPLATE = jinja2.Template( """ {{indent}}{{func_name}}_adapter<{{instance}}>( +{% if is_profiler %} + gemm_op, +{% endif %} 
{{sharedMemPerMultiprocessor}}, {{multiProcessorCount}}, {{workspace}}, @@ -225,11 +235,50 @@ ) +BENCHMARK_INSTANCE_TEMPLATE = jinja2.Template( + """ +{{indent}}{ +{{indent}} +{{indent}}{{instance_name}} {{gemm_op}}; +{{indent}}const char *gemm_op_name = "{{gemm_op_name}}"; +{{indent}}int ret = {{func_name}}_adapter( +{{indent}} {{gemm_op}}, +{{indent}} gemm_op_name, +{{indent}} {{sharedMemPerMultiprocessor}}, +{{indent}} {{multiProcessorCount}}, +{{indent}} {{workspace}}, +{{indent}} {{problem_count}}, +{{indent}} {{problem_sizes_device}}, +{{indent}} (void**)({{ptr_A}}), +{{indent}} (void**)({{ptr_B}}), +{{indent}} (void**)({{ptr_C}}), +{% if has_bias %} +{{indent}} (void**)({{ptr_bias}}), +{% endif %} +{{indent}} {{lda}}, +{{indent}} {{ldb}}, +{{indent}} {{ldc}}, +{% if has_bias %} +{{indent}} {{ldd}}, +{% endif %} +{{indent}} {{instance_name}}::maximum_active_blocks(), +{{indent}} stream +{{indent}} ); +{{indent}}if (ret != 0) +{{indent}} return ret; +{{indent}} +{{indent}}} +""", + trim_blocks=True, + lstrip_blocks=True, +) + + EXEC_TEMPLATE = jinja2.Template( """ // TODO: cast to right dtype -{{indent}}using ElementComputeEpilogue = typename GEMMKind::ElementAccumulator; -{{indent}}// int smem_size = int(sizeof(typename GEMMKind::GemmKernel::SharedStorage)); +{{indent}}using ElementComputeEpilogue = typename {{instance}}::ElementAccumulator; +{{indent}}// int smem_size = int(sizeof(typename {{instance}}::GemmKernel::SharedStorage)); {{indent}}// int occupancy = std::min(2, int(sharedMemPerMultiprocessor / smem_size)); {{indent}}int threadblock_count = multiProcessorCount * occupancy; {{indent}}// Early exit @@ -240,18 +289,19 @@ {{indent}}} -{{indent}}typename GEMMKind::Arguments arguments{ +{{indent}}typename {{instance}}::Arguments arguments{ {{problem_args}} {{indent}}}; -{{indent}}GEMMKind gemm_op; {% if is_profiler %} {{indent}}// Debug BGM: https://www.youtube.com/watch?v=rRwxfYlgG-M {{indent}}size_t workspace_size = 
gemm_op.get_workspace_size(arguments); {{indent}}cutlass::device_memory::allocation local_workspace(workspace_size); {{indent}}workspace = local_workspace.get(); {{indent}}GLOBAL_WORKSPACE_SIZE = workspace_size; +{% else %} +{{indent}}{{instance}} gemm_op; {% endif %} {{indent}}// TODO: cutlass bug here {{indent}}// auto status = gemm_op.can_implement(arguments); @@ -310,21 +360,27 @@ {{func_adapter}} +{% if is_profiler %} +template +void {{function_name}} ( + GemmInstance& gemm_op, +{% else %} void {{function_name}} ( +{% endif %} int sharedMemPerMultiprocessor, int multiProcessorCount, int64_t* func_state, int problem_count, - cutlass::half_t* device_args, + void* device_args, {% for operand in group_operands %} - cutlass::half_t* {{operand[0]}}, - cutlass::half_t* {{operand[1]}}, - cutlass::half_t* {{operand[2]}}, + void* {{operand[0]}}, + void* {{operand[1]}}, + void* {{operand[2]}}, {% if has_bias %} - cutlass::half_t* {{operand[3]}}, + void* {{operand[3]}}, {% endif %} {% endfor %} - uint8_t* global_workspace, + uint8_t* global_workspace_, {% for operand_dim in group_operand_dims %} int64_t* {{operand_dim[0]}}, int64_t* {{operand_dim[1]}}, @@ -369,7 +425,7 @@ {% endfor %} - uint8_t* arg_ptr = (uint8_t*) device_args; + void* arg_ptr = device_args; // problem_sizes_device: N * GemmCoord -> N * 3 * sizeof(int64_t) -> 32 * N // ptrA/B/C/D: N * 8 for each // lda/b/c/d: N * 8 for each @@ -380,14 +436,14 @@ (cutlass::gemm::GemmCoord*)(arg_ptr + offset); offset += 32 * problem_count; - auto ptr_A = (cutlass::half_t**)(arg_ptr + offset); + auto ptr_A = (void**)(arg_ptr + offset); offset += 8 * problem_count; - auto ptr_B = (cutlass::half_t**)(arg_ptr + offset); + auto ptr_B = (void**)(arg_ptr + offset); offset += 8 * problem_count; - auto ptr_C = (cutlass::half_t**)(arg_ptr + offset); + auto ptr_C = (void**)(arg_ptr + offset); offset += 8 * problem_count; {% if has_bias %} - auto ptr_bias = (cutlass::half_t**)(arg_ptr + offset); + auto ptr_bias = (void**)(arg_ptr + 
offset); offset += 8 * problem_count; {% endif %} @@ -405,11 +461,11 @@ if (*func_state != GROUP_0_AM) { // need update std::vector problem_sizes; - std::vector ptr_A_host; - std::vector ptr_B_host; - std::vector ptr_C_host; + std::vector ptr_A_host; + std::vector ptr_B_host; + std::vector ptr_C_host; {% if has_bias %} - std::vector ptr_bias_host; + std::vector ptr_bias_host; {% endif %} std::vector lda_host; std::vector ldb_host; @@ -419,11 +475,11 @@ {% endif %} {% for operand in group_operands %} - ptr_A_host.push_back({{operand[0]}} + input_a_offset_{{loop.index0}}); - ptr_B_host.push_back({{operand[1]}}); - ptr_C_host.push_back({{operand[2]}} + output_offset_{{loop.index0}}); + ptr_A_host.push_back(({{elem_input_type}}*)({{operand[0]}}) + input_a_offset_{{loop.index0}}); + ptr_B_host.push_back(({{elem_input_type}}*)({{operand[1]}})); + ptr_C_host.push_back(({{elem_output_type}}*)({{operand[2]}}) + output_offset_{{loop.index0}}); {% if has_bias %} - ptr_bias_host.push_back({{operand[3]}}); + ptr_bias_host.push_back(({{elem_input_type}}*)({{operand[3]}})); {% endif %} {% endfor %} @@ -514,6 +570,9 @@ TENSOR_DECL_TEMPLATE = jinja2.Template( """ + using ElementOutput = {{elem_output_type}}; + using ElementInputA = {{elem_input_type}}; + using ElementInputB = {{elem_input_type}}; cutlass::DeviceAllocation blob_A; cutlass::DeviceAllocation blob_B; cutlass::DeviceAllocation blob_C; @@ -733,6 +792,7 @@ def group_gemm_instance(op_def: str, func_attrs: Dict[str, Any], for_profiler: b def gen_profiler( func_attrs, workdir, + profiler_filename, shape_template, problem_args_template, has_bias=False, @@ -740,9 +800,31 @@ def gen_profiler( ): op_type = func_attrs["op"] op_instance = func_attrs["op_instance"] + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) + elem_type = 
backend_spec.dtype_to_backend_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) - file_pairs = [] - for op_name, op in op_instance.items(): + instance_name_base = "GemmInstance" + exec_program = EXEC_TEMPLATE.render( + indent=" ", + instance=instance_name_base, + is_profiler=True, + problem_args=problem_args_template.render( + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, + ), + ) + + instances = [] + benchmark_instances = [] + for instance_idx, (op_name, op) in enumerate(op_instance.items()): config = common.emit_instance( op, for_profiler=True, @@ -750,29 +832,20 @@ def gen_profiler( emit_kernel=True, ) config_name = common.extract_config_name(config) - name = "GemmInstance" + instance_name = f"{instance_name_base}_{instance_idx}" + gemm_op = f"gemm_op_{instance_idx}" instance = INSTANCE_TEMPLATE.render( - config_name=config_name, name=name, config=config + config_name=config_name, name=instance_name, config=config ) - - # instance = instance - exec_program = EXEC_TEMPLATE.render( - indent=" ", is_profiler=True, problem_args=problem_args_template.render() - ) - op_func = ADAPTOR_FUNCTION_TEMPLATE.render( - instance=instance, - is_profiler=True, - func_name=name, - indent=" ", - exec_program=exec_program, - has_bias=has_bias, - ) - func_call = ADAPTER_CALL_TEMPLATE.render( - func_name=name, - instance=name, + benchmark_instance = BENCHMARK_INSTANCE_TEMPLATE.render( + indent=" ", + instance_name=instance_name, + gemm_op=gemm_op, + gemm_op_name=op_name, + func_name=f"benchmark_{instance_name_base}", sharedMemPerMultiprocessor="device_properties.sharedMemPerMultiprocessor", multiProcessorCount="device_properties.multiProcessorCount", - workspace="global_workspace", + workspace="global_workspace_", problem_count="problem_count", problem_sizes_device="problem_sizes_device.get()", ptr_A="ptr_A.get()", @@ -785,16 +858,58 @@ def gen_profiler( ldc="ldc.get()", ldd="ldd.get()", ) - code = common.PROFILER_TEMPLATE.render( - op_func=op_func, - 
args_parse=ARGS_PARSER_TEMPLATE.render(), - func_call=func_call, - name=name, - tensor_decl=TENSOR_DECL_TEMPLATE.render(name=name, has_bias=has_bias), - ) - common.add_profiler(file_pairs, workdir, op_type, op_name, code) + instances.append(instance) + benchmark_instances.append(benchmark_instance) + op_func = ADAPTOR_FUNCTION_TEMPLATE.render( + instances="\n".join(instances), + is_profiler=True, + func_name=instance_name_base, + indent=" ", + exec_program=exec_program, + has_bias=has_bias, + ) + func_call = ADAPTER_CALL_TEMPLATE.render( + is_profiler=True, + func_name=instance_name_base, + instance=instance_name_base, + sharedMemPerMultiprocessor="sharedMemPerMultiprocessor", + multiProcessorCount="multiProcessorCount", + workspace="global_workspace_", + problem_count="problem_count", + problem_sizes_device="problem_sizes_device", + ptr_A="ptr_A", + ptr_B="ptr_B", + ptr_C="ptr_C", + has_bias=has_bias, + ptr_bias="ptr_bias", + lda="lda", + ldb="ldb", + ldc="ldc", + ldd="ldd", + ) + tensor_decl = TENSOR_DECL_TEMPLATE.render( + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, + has_bias=has_bias, + ) + code = common.PROFILER_TEMPLATE.render( + is_group_gemm=True, + op_func=op_func, + has_bias=has_bias, + args_parse=ARGS_PARSER_TEMPLATE.render(), + function_name=f"{instance_name_base}_adapter", + func_call=func_call, + name=instance_name_base, + tensor_decl=tensor_decl, + benchmark_instances="\n".join(benchmark_instances), + elem_type=elem_type, + ) + # FIXME: remove file_pairs once we have make -j ready for building + # an entire graph + file_pairs = [] + common.add_profiler(file_pairs, workdir, op_type, profiler_filename, code) # build - common.build_profiler(file_pairs) + return common.build_profiler(file_pairs) def gen_function( @@ -804,7 +919,17 @@ def gen_function( problem_args_template, has_bias=False, ): - problem_args = problem_args_template.render() + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_lib_type( + 
func_attrs["inputs"][0]._attrs["dtype"] + ) + elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) + problem_args = problem_args_template.render( + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, + ) func_name = func_attrs["name"] exec_path = func_attrs["exec_path"] op_instance = func_attrs["op_instance"] @@ -873,7 +998,7 @@ def gen_function( instance=fname, sharedMemPerMultiprocessor="sharedMemPerMultiprocessor", multiProcessorCount="multiProcessorCount", - workspace="global_workspace", + workspace="global_workspace_", problem_count=func_attrs["groups"], problem_sizes_device="problem_sizes_device", ptr_A="ptr_A", @@ -890,7 +1015,10 @@ def gen_function( exec_paths += exec_inst exec_program = EXEC_TEMPLATE.render( - indent=" ", is_profiler=False, problem_args=problem_args + indent=" ", + instance="GemmInstance", + is_profiler=False, + problem_args=problem_args, ) adapter_func = ADAPTOR_FUNCTION_TEMPLATE.render( func_name=func_name, exec_program=exec_program, has_bias=has_bias @@ -914,6 +1042,8 @@ def gen_function( instances=instance_decl, func_adapter=adapter_func, function_name=func_name, + elem_input_type=elem_input_type, + elem_output_type=elem_output_type, shape_function=shape_func, group_operands=group_operands, group_operand_dims=group_operand_dims, @@ -962,7 +1092,7 @@ def gen_function_call(func_attrs, ndims, has_bias=False, indent=" "): operand_dims.append("&" + cshape[1]._attrs["name"]) group_operands.append(operands) group_operand_dims.append(operand_dims) - device_args = f'reinterpret_cast(unique_workspace + {func_attrs["unique_workspace_offset"]})' + device_args = f'unique_workspace_ + {func_attrs["unique_workspace_offset"]}' return FUNC_CALL_TEMPLATE.render( func_name=func_attrs["name"], problem_count=func_attrs["groups"], diff --git a/python/aitemplate/backend/cuda/gemm_universal/group_common_bias.py b/python/aitemplate/backend/cuda/gemm_universal/group_common_bias.py index 
2b556fc83..c18ef3e5f 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/group_common_bias.py +++ b/python/aitemplate/backend/cuda/gemm_universal/group_common_bias.py @@ -28,10 +28,10 @@ problem_count, threadblock_count, {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, - ptr_A, - ptr_B, - ptr_bias, - ptr_C, + ({{elem_input_type}}**)(ptr_A), + ({{elem_input_type}}**)(ptr_B), + ({{elem_input_type}}**)(ptr_bias), + ({{elem_output_type}}**)ptr_C, lda, ldb, ldc, @@ -43,10 +43,16 @@ def gen_profiler( func_attrs, workdir, + profiler_filename, shape_template, ): - group_common.gen_profiler( - func_attrs, workdir, shape_template, PROBLEM_ARGS_TEMPLATE, has_bias=True + return group_common.gen_profiler( + func_attrs, + workdir, + profiler_filename, + shape_template, + PROBLEM_ARGS_TEMPLATE, + has_bias=True, ) diff --git a/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr.py b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr.py index 354039b40..83f0e2aa0 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr.py +++ b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr.py @@ -29,10 +29,10 @@ problem_count, threadblock_count, {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}, - ptr_A, - ptr_B, - ptr_C, - ptr_C, + ({{elem_input_type}}**)(ptr_A), + ({{elem_input_type}}**)(ptr_B), + ({{elem_output_type}}**)(ptr_C), + ({{elem_output_type}}**)(ptr_C), lda, ldb, ldc, @@ -43,13 +43,13 @@ @registry.reg("cuda.group_gemm_rcr.config") def group_rcr_config(func_attrs, dtype="float16"): - common.make_fproc_f16(func_attrs, RCR) + common.make_fproc(func_attrs, RCR) @registry.reg("cuda.group_gemm_rcr.gen_profiler") -def gen_profiler(func_attrs, workdir, shape_template): - group_common.gen_profiler( - func_attrs, workdir, shape_template, PROBLEM_ARGS_TEMPLATE +def gen_profiler(func_attrs, workdir, profiler_filename, shape_template): + return group_common.gen_profiler( + func_attrs, workdir, profiler_filename, shape_template, 
PROBLEM_ARGS_TEMPLATE ) diff --git a/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias.py b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias.py index c292c3e1d..88c348d2e 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias.py +++ b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias.py @@ -27,8 +27,10 @@ def group_rcr_config(func_attrs, dtype="float16"): @registry.reg("cuda.group_gemm_rcr_bias.gen_profiler") -def gen_profiler(func_attrs, workdir, shape_template): - group_common_bias.gen_profiler(func_attrs, workdir, shape_template) +def gen_profiler(func_attrs, workdir, profiler_filename, shape_template): + return group_common_bias.gen_profiler( + func_attrs, workdir, profiler_filename, shape_template + ) @registry.reg("cuda.group_gemm_rcr_bias.gen_function") diff --git a/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_relu.py b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_relu.py index 9345c26e4..fc43233da 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_relu.py +++ b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_relu.py @@ -27,8 +27,10 @@ def group_rcr_config(func_attrs, dtype="float16"): @registry.reg("cuda.group_gemm_rcr_bias_relu.gen_profiler") -def gen_profiler(func_attrs, workdir, shape_template): - group_common_bias.gen_profiler(func_attrs, workdir, shape_template) +def gen_profiler(func_attrs, workdir, profiler_filename, shape_template): + return group_common_bias.gen_profiler( + func_attrs, workdir, profiler_filename, shape_template + ) @registry.reg("cuda.group_gemm_rcr_bias_relu.gen_function") diff --git a/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_sigmoid.py b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_sigmoid.py index e247bbe2a..bce93b575 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_sigmoid.py +++ 
b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_sigmoid.py @@ -27,8 +27,10 @@ def group_rcr_config(func_attrs, dtype="float16"): @registry.reg("cuda.group_gemm_rcr_bias_sigmoid.gen_profiler") -def gen_profiler(func_attrs, workdir, shape_template): - group_common_bias.gen_profiler(func_attrs, workdir, shape_template) +def gen_profiler(func_attrs, workdir, profiler_filename, shape_template): + return group_common_bias.gen_profiler( + func_attrs, workdir, profiler_filename, shape_template + ) @registry.reg("cuda.group_gemm_rcr_bias_sigmoid.gen_function") diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr.py b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr.py index 580a3b005..7d1741c52 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr.py +++ b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr.py @@ -43,32 +43,45 @@ def _get_problem_info(**kwargs): @registry.reg("cuda.perm021fc_ccr.config") def gemm_ccr_config(func_attrs, dtype="float16"): - def fproc_f16(op): + def fproc(op): import cutlass_lib - return common.default_fproc_f16( + from ...backend_spec import CUDASpec + + backend_spec = CUDASpec() + elem_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + + return common.default_fproc( op=op, a_layout=cutlass_lib.library.LayoutType.ColumnMajor, b_layout=cutlass_lib.library.LayoutType.ColumnMajor, c_layout=cutlass_lib.library.LayoutType.RowMajor, + elem_type=elem_type, epiligue_name=func_attrs["epilogue"], ) - func_attrs["op_instance"] = common.extract_config(fproc_f16) + func_attrs["op_instance"] = common.extract_config(fproc) @registry.reg("cuda.perm021fc_ccr.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( a_dims=["B", "K", "M"], b_dims=["1", "N", "K"], c_dims=["B", "M", "N"] ) - mm_info = 
_get_problem_info(alpha_value=func_attrs.get("alpha", 1)) - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + mm_info = _get_problem_info( + alpha_value=func_attrs.get("alpha", 1), + ) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=mm_info, + ) - bmm_common.gen_profiler( + return bmm_common.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, common.SRC_TEMPLATE, problem_args, @@ -82,8 +95,12 @@ def gen_function( exec_cond_template, dim_info_dict, ): - mm_info = _get_problem_info(alpha_value=func_attrs.get("alpha", 1)) - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + mm_info = _get_problem_info( + alpha_value=func_attrs.get("alpha", 1), + ) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=mm_info, + ) return bmm_common.gen_function( func_attrs, diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias.py b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias.py index b4f320de9..69712f30f 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias.py +++ b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias.py @@ -48,22 +48,27 @@ def gemm_rcr_config(func_attrs, dtype="float16"): @registry.reg("cuda.perm021fc_ccr_bias.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( a_dims=["B", "K", "M"], b_dims=["1", "N", "K"], c_dims=["B", "M", "N"] ) - mm_info = _get_problem_info(alpha_value=func_attrs.get("alpha", 1)) - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + mm_info = _get_problem_info( + alpha_value=func_attrs.get("alpha", 1), + ) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=mm_info, + ) - bmm_common.gen_profiler( + return bmm_common.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, 
common_bias.SRC_TEMPLATE, problem_args, args_parser, - bias_ptr_arg="memory_pool->RequestHalfTensorByIdx(3)", + bias_ptr_arg="memory_pool->RequestTensorByIdx(3)", ) @@ -73,8 +78,12 @@ def gen_function( exec_cond_template, dim_info_dict, ): - mm_info = _get_problem_info(alpha_value=func_attrs.get("alpha", 1)) - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + mm_info = _get_problem_info( + alpha_value=func_attrs.get("alpha", 1), + ) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=mm_info, + ) input_ndims = len(func_attrs["input_accessors"][0].original_shapes) weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) output_ndims = len(func_attrs["output_accessors"][0].original_shapes) diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias_permute.py b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias_permute.py index 5631bf3ca..76ac6533b 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias_permute.py +++ b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias_permute.py @@ -18,11 +18,10 @@ """ from ... import registry -from ..gemm_universal import common - from . 
import ( bmm_common, bmm_permute_common, + common, common_bias, common_permute, perm021fc_ccr_bias, @@ -85,24 +84,34 @@ class Tensor3DPermute021BMM { @registry.reg("cuda.perm021fc_ccr_bias_permute.config") def config(func_attrs, dtype="float16"): - def fproc_f16(op): + def fproc(op): import cutlass_lib - return common_permute.default_fproc_f16( + from ...backend_spec import CUDASpec + + backend_spec = CUDASpec() + elem_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + + return common.default_fproc( op=op, a_layout=cutlass_lib.library.LayoutType.ColumnMajor, b_layout=cutlass_lib.library.LayoutType.ColumnMajor, c_layout=cutlass_lib.library.LayoutType.RowMajor, + elem_type=elem_type, epiligue_name=func_attrs["epilogue"], permute_layout=func_attrs["layout"], ) - func_attrs["op_instance"] = common_permute.extract_config(fproc_f16, func_attrs) + func_attrs["op_instance"] = common_permute.extract_config(fproc, func_attrs) @registry.reg("cuda.perm021fc_ccr_bias_permute.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): - return perm021fc_ccr_bias.gen_profiler(func_attrs, workdir, dim_info_dict) +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): + return perm021fc_ccr_bias.gen_profiler( + func_attrs, workdir, profiler_filename, dim_info_dict + ) @registry.reg("cuda.perm021fc_ccr_bias_permute.gen_function") @@ -112,9 +121,11 @@ def gen_function( dim_info_dict, ): mm_info = perm021fc_ccr_bias._get_problem_info( - alpha_value=func_attrs.get("alpha", 1) + alpha_value=func_attrs.get("alpha", 1), + ) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=mm_info, ) - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) return bmm_permute_common.gen_function( func_attrs, diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc.py b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc.py index 35a9ef77d..3d08f0291 100644 --- 
a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc.py +++ b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc.py @@ -44,33 +44,45 @@ def _get_problem_info(**kwargs): @registry.reg("cuda.perm021fc_crc.config") def gemm_crc_config(func_attrs, dtype="float16"): - def fproc_f16(op): + def fproc(op): import cutlass_lib - return common.default_fproc_f16( + from ...backend_spec import CUDASpec + + backend_spec = CUDASpec() + elem_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + + return common.default_fproc( op=op, a_layout=cutlass_lib.library.LayoutType.ColumnMajor, b_layout=cutlass_lib.library.LayoutType.RowMajor, c_layout=cutlass_lib.library.LayoutType.ColumnMajor, + elem_type=elem_type, epiligue_name=func_attrs["epilogue"], ) - func_attrs["op_instance"] = common.extract_config(fproc_f16) + func_attrs["op_instance"] = common.extract_config(fproc) @registry.reg("cuda.perm021fc_crc.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( a_dims=["1", "K", "N"], b_dims=["B", "K", "M"], c_dims=["B", "M", "N"] ) problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( - mm_info=_get_problem_info(alpha_value=func_attrs.get("alpha", 1), beta_value=0) + mm_info=_get_problem_info( + alpha_value=func_attrs.get("alpha", 1), + beta_value=0, + ), ) - bmm_common.gen_profiler( + return bmm_common.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, common.SRC_TEMPLATE, problem_args, @@ -85,7 +97,10 @@ def gen_function( dim_info_dict, ): problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( - mm_info=_get_problem_info(alpha_value=func_attrs.get("alpha", 1), beta_value=0) + mm_info=_get_problem_info( + alpha_value=func_attrs.get("alpha", 1), + beta_value=0, + ), ) return bmm_common.gen_function( diff --git 
a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc_bias.py b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc_bias.py index 187a0c6c1..3e6497c76 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc_bias.py +++ b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc_bias.py @@ -49,23 +49,26 @@ def gemm_rcr_config(func_attrs, dtype="float16"): @registry.reg("cuda.perm021fc_crc_bias.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( a_dims=["1", "K", "N"], b_dims=["B", "K", "M"], c_dims=["B", "M", "N"] ) problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( - mm_info=_get_problem_info(alpha_value=func_attrs.get("alpha", 1)) + mm_info=_get_problem_info( + alpha_value=func_attrs.get("alpha", 1), + ), ) - bmm_common.gen_profiler( + return bmm_common.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, common_bias.SRC_TEMPLATE, problem_args, args_parser, - bias_ptr_arg="memory_pool->RequestHalfTensorByIdx(3)", + bias_ptr_arg="memory_pool->RequestTensorByIdx(3)", ) @@ -76,7 +79,9 @@ def gen_function( dim_info_dict, ): problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( - mm_info=_get_problem_info(alpha_value=func_attrs.get("alpha", 1)) + mm_info=_get_problem_info( + alpha_value=func_attrs.get("alpha", 1), + ), ) input_ndims = len(func_attrs["input_accessors"][0].original_shapes) weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr.py b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr.py index fe0ffe9cd..c414816d8 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr.py +++ b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr.py @@ -17,6 +17,7 @@ C[m, b, n](row) = bmm(A[m, b, k](row), B[b, n, k](col)) """ from ... 
import registry +from ...backend_spec import CUDASpec from . import bmm_common, common # pylint: disable=C0103,C0415,W0613,C0301,R1705,R1703 @@ -43,12 +44,17 @@ def _get_default_problem_info(**kwargs): # Currently only has output Tensor Accessor support. def _get_strided_problem_info(func_attrs): + backend_spec = CUDASpec() + elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) + return bmm_common.Bmm_problem_info( alpha_value=func_attrs.get("alpha", 1), a_ptr="a_ptr", b_ptr="b_ptr", - bias_ptr="(c_ptr + output_offset)", - c_ptr="(c_ptr + output_offset)", + bias_ptr="(" + elem_output_type + "*)(c_ptr) + output_offset", + c_ptr="(" + elem_output_type + "*)(c_ptr) + output_offset", a_batch_stride="K", b_batch_stride="N * K", bias_batch_stride="output_batch_stride", @@ -94,32 +100,45 @@ def get_output_addr_calculator(func_attrs): @registry.reg("cuda.perm102_bmm_rcr.config") def gemm_rcr_config(func_attrs, dtype="float16"): - def fproc_f16(op): + def fproc(op): import cutlass_lib - return common.default_fproc_f16( + from ...backend_spec import CUDASpec + + backend_spec = CUDASpec() + elem_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + + return common.default_fproc( op=op, a_layout=cutlass_lib.library.LayoutType.RowMajor, b_layout=cutlass_lib.library.LayoutType.ColumnMajor, c_layout=cutlass_lib.library.LayoutType.RowMajor, + elem_type=elem_type, epiligue_name=func_attrs["epilogue"], ) - func_attrs["op_instance"] = common.extract_config(fproc_f16) + func_attrs["op_instance"] = common.extract_config(fproc) @registry.reg("cuda.perm102_bmm_rcr.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( a_dims=["M", "B", "K"], b_dims=["B", "N", "K"], c_dims=["M", "B", "N"] ) - mm_info = _get_default_problem_info(alpha_value=func_attrs.get("alpha", 1)) - 
problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + mm_info = _get_default_problem_info( + alpha_value=func_attrs.get("alpha", 1), + ) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=mm_info, + ) - bmm_common.gen_profiler( + return bmm_common.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, common.SRC_TEMPLATE, problem_args, @@ -136,7 +155,9 @@ def gen_function( bmm_problem_info = _get_strided_problem_info(func_attrs) # broadcasting is not supported - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=bmm_problem_info) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=bmm_problem_info, + ) return bmm_common.gen_function( func_attrs, diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr_bias.py b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr_bias.py index 8c34ecd48..92afe0ca5 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr_bias.py +++ b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr_bias.py @@ -17,6 +17,7 @@ C[m, b, n](row) = bmm(A[m, b, k](row), B[b, n, k](col)) + bias[n]. """ from ... import registry +from ...backend_spec import CUDASpec from . import bmm_common, common, common_bias, perm102_bmm_rcr from .perm102_bmm_rcr import get_output_addr_calculator @@ -45,13 +46,18 @@ def _get_default_problem_info(**kwargs): # Currently only has output Tensor Accessor support. 
def _get_strided_problem_info(func_attrs): + backend_spec = CUDASpec() + elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) + return bmm_common.Bmm_problem_info( alpha_value=func_attrs.get("alpha", 1), beta_value=1, a_ptr="(a_ptr)", b_ptr="(b_ptr)", bias_ptr="(bias_ptr)", - c_ptr="(c_ptr + output_offset)", + c_ptr="(" + elem_output_type + "*)(c_ptr) + output_offset", a_batch_stride="K", b_batch_stride="N * K", bias_batch_stride="N", @@ -69,22 +75,27 @@ def gemm_rcr_config(func_attrs, dtype="float16"): @registry.reg("cuda.perm102_bmm_rcr_bias.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( a_dims=["M", "B", "K"], b_dims=["B", "N", "K"], c_dims=["M", "B", "N"] ) - mm_info = _get_default_problem_info(alpha_value=func_attrs.get("alpha", 1)) - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + mm_info = _get_default_problem_info( + alpha_value=func_attrs.get("alpha", 1), + ) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=mm_info, + ) - bmm_common.gen_profiler( + return bmm_common.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, common_bias.SRC_TEMPLATE, problem_args, args_parser, - bias_ptr_arg="memory_pool->RequestHalfTensorByIdx(3)", + bias_ptr_arg="memory_pool->RequestTensorByIdx(3)", ) @@ -97,7 +108,9 @@ def gen_function( bmm_problem_info = _get_strided_problem_info(func_attrs) # broadcasting is not supported - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=bmm_problem_info) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=bmm_problem_info, + ) input_ndims = len(func_attrs["input_accessors"][0].original_shapes) weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr.py 
b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr.py index e4a3d7d1b..2f8d35522 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr.py +++ b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr.py @@ -17,6 +17,7 @@ C[m, b, n](row) = bmm(A[m, b, k](row), B[b, k, n](row)) """ from ... import registry +from ...backend_spec import CUDASpec from . import bmm_common, common from .perm102_bmm_rcr import get_output_addr_calculator @@ -44,12 +45,17 @@ def _get_default_problem_info(**kwargs): # Currently only has output Tensor Accessor support. def _get_strided_problem_info(func_attrs): + backend_spec = CUDASpec() + elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) + return bmm_common.Bmm_problem_info( - alpha_value=func_attrs.get("alpha", 1), a_ptr="(a_ptr)", b_ptr="(b_ptr)", - bias_ptr="(c_ptr + output_offset)", - c_ptr="(c_ptr + output_offset)", + bias_ptr="(" + elem_output_type + "*)(c_ptr) + output_offset", + c_ptr="(" + elem_output_type + "*)(c_ptr) + output_offset", + alpha_value=func_attrs.get("alpha", 1), a_batch_stride="K", b_batch_stride="N * K", bias_batch_stride="output_batch_stride", @@ -63,32 +69,45 @@ def _get_strided_problem_info(func_attrs): @registry.reg("cuda.perm102_bmm_rrr.config") def gemm_rrr_config(func_attrs, dtype="float16"): - def fproc_f16(op): + def fproc(op): import cutlass_lib - return common.default_fproc_f16( + from ...backend_spec import CUDASpec + + backend_spec = CUDASpec() + elem_type = backend_spec.dtype_to_lib_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) + + return common.default_fproc( op=op, a_layout=cutlass_lib.library.LayoutType.RowMajor, b_layout=cutlass_lib.library.LayoutType.RowMajor, c_layout=cutlass_lib.library.LayoutType.RowMajor, + elem_type=elem_type, epiligue_name=func_attrs["epilogue"], ) - func_attrs["op_instance"] = common.extract_config(fproc_f16) + func_attrs["op_instance"] = common.extract_config(fproc) 
@registry.reg("cuda.perm102_bmm_rrr.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( a_dims=["M", "B", "K"], b_dims=["B", "K", "N"], c_dims=["M", "B", "N"] ) - mm_info = _get_default_problem_info(alpha_value=func_attrs.get("alpha", 1)) - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + mm_info = _get_default_problem_info( + alpha_value=func_attrs.get("alpha", 1), + ) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=mm_info, + ) - bmm_common.gen_profiler( + return bmm_common.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, common.SRC_TEMPLATE, problem_args, @@ -105,7 +124,9 @@ def gen_function( bmm_problem_info = _get_strided_problem_info(func_attrs) # broadcasting is not supported - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=bmm_problem_info) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=bmm_problem_info, + ) return bmm_common.gen_function( func_attrs, diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr_bias.py b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr_bias.py index f7435c071..e065d70c1 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr_bias.py +++ b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr_bias.py @@ -17,6 +17,7 @@ C[m, b, n](row) = bmm(A[m, b, k](row), B[b, k, n](row)) + bias[n] """ from ... import registry +from ...backend_spec import CUDASpec from . import bmm_common, common, common_bias, perm102_bmm_rrr from .perm102_bmm_rcr import get_output_addr_calculator @@ -45,13 +46,17 @@ def _get_default_problem_info(**kwargs): # Currently only has output Tensor Accessor support. 
def _get_strided_problem_info(func_attrs): + backend_spec = CUDASpec() + elem_output_type = backend_spec.dtype_to_lib_type( + func_attrs["outputs"][0]._attrs["dtype"] + ) return bmm_common.Bmm_problem_info( alpha_value=func_attrs.get("alpha", 1), beta_value=1, a_ptr="(a_ptr)", b_ptr="(b_ptr)", bias_ptr="(bias_ptr)", - c_ptr="(c_ptr + output_offset)", + c_ptr="(" + elem_output_type + "*)(c_ptr) + output_offset", a_batch_stride="K", b_batch_stride="N * K", bias_batch_stride="N", @@ -69,22 +74,27 @@ def gemm_rrr_config(func_attrs, dtype="float16"): @registry.reg("cuda.perm102_bmm_rrr_bias.gen_profiler") -def gen_profiler(func_attrs, workdir, dim_info_dict): +def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict): args_parser = bmm_common.ARGS_PARSER_TEMPLATE.render( a_dims=["M", "B", "K"], b_dims=["B", "K", "N"], c_dims=["M", "B", "N"] ) - mm_info = _get_default_problem_info(alpha_value=func_attrs.get("alpha", 1)) - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=mm_info) + mm_info = _get_default_problem_info( + alpha_value=func_attrs.get("alpha", 1), + ) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=mm_info, + ) - bmm_common.gen_profiler( + return bmm_common.gen_profiler( func_attrs, workdir, + profiler_filename, dim_info_dict, common_bias.SRC_TEMPLATE, problem_args, args_parser, - bias_ptr_arg="memory_pool->RequestHalfTensorByIdx(3)", + bias_ptr_arg="memory_pool->RequestTensorByIdx(3)", ) @@ -97,7 +107,9 @@ def gen_function( bmm_problem_info = _get_strided_problem_info(func_attrs) # broadcasting is not supported - problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(mm_info=bmm_problem_info) + problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render( + mm_info=bmm_problem_info, + ) input_ndims = len(func_attrs["input_accessors"][0].original_shapes) weight_ndims = len(func_attrs["input_accessors"][1].original_shapes) diff --git a/python/aitemplate/backend/cuda/groupnorm/groupnorm_common.py 
b/python/aitemplate/backend/cuda/groupnorm/groupnorm_common.py index 5b075783c..d1a48f28b 100644 --- a/python/aitemplate/backend/cuda/groupnorm/groupnorm_common.py +++ b/python/aitemplate/backend/cuda/groupnorm/groupnorm_common.py @@ -21,21 +21,19 @@ import jinja2 +from ...backend_spec import CUDASpec from ...target import Target -FUNC_CALL_FP16_PARAM_TEMPLATE = jinja2.Template( - "reinterpret_cast(&({{name}}->raw()))" -) - FUNC_SIGNATURE = jinja2.Template( """ -cudaError_t {{func_name}}(half* output, - half* input, - half* gamma, - half* beta, +cudaError_t {{func_name}}(void* output, + void* input, + void* gamma, + void* beta, int N, const float eps, const int max_smem_size, + void* workspace, cudaStream_t stream) """ ) @@ -51,7 +49,8 @@ {{indent}}{ {{indent}} {{func_name}}( {{indent}} {{output}}, {{input}}, {{gamma}}, {{beta}}, {{N}}, -{{indent}} {{eps}}, max_smem_size, stream /* default stream */ +{{indent}} {{eps}}, max_smem_size_, global_workspace_, +{{indent}} stream /* default stream */ {{indent}} ); {{indent}}} """ @@ -69,26 +68,32 @@ #include "cutlass/cutlass.h" #include "cutlass/fast_math.h" #include "logging.h" +#include +#include {{gamma_beta_const_defs}} namespace { +{{helper_libs}} + {{custom_libs}} } // namespace {{func_signature}} { - return invokeGroupNorm<{{FuseSwish}}, {{H}}, {{W}}, {{C}}, {{G}}>( - output, - input, - gamma, - beta, + + return invokeGroupNorm_{{elem_input_type}}<{{FuseSwish}}, {{H}}, {{W}}, {{C}}, {{G}}>( + static_cast<{{elem_input_type}}*>(output), + static_cast<{{elem_input_type}}*>(input), + static_cast<{{elem_input_type}}*>(gamma), + static_cast<{{elem_input_type}}*>(beta), N, eps, max_smem_size, + workspace, stream); } """ @@ -113,15 +118,15 @@ def get_input_names(func_attrs: Dict[str, Any]) -> List[str]: beta = inputs[idx] idx += 1 - input_name = FUNC_CALL_FP16_PARAM_TEMPLATE.render(name=x._attrs["name"]) + input_name = x._attrs["name"] if gamma is None: gamma_name = "nullptr" else: - gamma_name = 
FUNC_CALL_FP16_PARAM_TEMPLATE.render(name=gamma._attrs["name"]) + gamma_name = gamma._attrs["name"] if beta is None: beta_name = "nullptr" else: - beta_name = FUNC_CALL_FP16_PARAM_TEMPLATE.render(name=beta._attrs["name"]) + beta_name = beta._attrs["name"] return (input_name, gamma_name, beta_name) @@ -135,11 +140,19 @@ def groupnorm_gen_function(func_attrs: Dict[str, Any]) -> str: C = input_shape[3].value() G = func_attrs["num_groups"] + backend_spec = CUDASpec() + elem_input_type = backend_spec.dtype_to_backend_type( + func_attrs["inputs"][0]._attrs["dtype"] + ) return FUNC_TEMPLATE.render( + helper_libs=Target.current().get_custom_libs( + os.path.dirname(__file__), "layer_norm.cuh" + ), custom_libs=Target.current().get_custom_libs( os.path.dirname(__file__), "groupnorm_kernel.cuh" ), func_signature=FUNC_SIGNATURE.render(func_name=func_attrs["name"]), + elem_input_type=elem_input_type, FuseSwish="true" if use_swish else "false", H=H, W=W, @@ -161,9 +174,7 @@ def groupnorm_gen_func_call(func_attrs: Dict[str, Any], indent=" ") -> str: func_attrs["inputs"] ), "expected at least 1 inputs but got {}".format(len(func_attrs["inputs"])) - output_name = FUNC_CALL_FP16_PARAM_TEMPLATE.render( - name=func_attrs["outputs"][0]._attrs["name"] - ) + output_name = func_attrs["outputs"][0]._attrs["name"] (input_name, gamma_name, beta_name) = get_input_names(func_attrs) input_shape = func_attrs["inputs"][0]._attrs["shape"] eps = func_attrs["eps"] diff --git a/python/aitemplate/backend/cuda/groupnorm/groupnorm_kernel.cuh b/python/aitemplate/backend/cuda/groupnorm/groupnorm_kernel.cuh index 6a235589c..2a22ed903 100644 --- a/python/aitemplate/backend/cuda/groupnorm/groupnorm_kernel.cuh +++ b/python/aitemplate/backend/cuda/groupnorm/groupnorm_kernel.cuh @@ -33,10 +33,30 @@ #define GROUP_NORM_CUDA_CHECK_LAUNCH() GROUP_NORM_CUDA_CHECK(cudaGetLastError()) #endif +__device__ half fast_tanh(half x) { + return half(cutlass::fast_tanh(float(x))); +} + __inline__ __device__ float sigmoid(float 
val) { return (cutlass::fast_tanh(val * 0.5f) + 1.0f) * 0.5f; } +__device__ half constant_half() { + uint16_t bits = 0x3800u; + return reinterpret_cast(bits); +} + +__device__ half one() { + uint16_t bits = 0x3c00u; + return reinterpret_cast(bits); +} + +__inline__ __device__ half hsigmoid(half a) { + half half_val = constant_half(); + half one_val = one(); + return __hmul((__hadd(fast_tanh(__hmul(a, half_val)), one_val)), half_val); +} + //////////////////////////////////////////////////////////////////////////////// // The Groupnorm implementation below is based on OneFlow's Layernorm // implementation at: @@ -447,7 +467,7 @@ __global__ __launch_bounds__(NUM_THREADS) void group_norm_smem( } template -cudaError_t invokeWelfordGroupNorm( +cudaError_t invokeWelfordGroupNorm_half( half* output, half* input, half* gamma, @@ -512,8 +532,322 @@ cudaError_t invokeWelfordGroupNorm( return cudaSuccess; } +template +struct AffineStore { + AffineStore( + DST* y, + int64_t row_size, + int64_t channel_size, + int64_t spatial_size, + const DST* gamma, + const DST* beta) + : y(y), + row_size(row_size), + channel_size(channel_size), + spatial_size(spatial_size), + gamma(gamma), + beta(beta) {} + + template + __device__ void store(const SRC* src, int64_t row, int64_t col) { + layer_norm::Pack y_pack; + const int64_t offset = row * row_size + col; + const int64_t packed_offset = offset / PackSize; + const int64_t gamma_beta_offset = (offset / spatial_size) % channel_size; + DST gamma_val = 1.0; + DST beta_val = 0.0; + if (affine) { + gamma_val = gamma[gamma_beta_offset]; + beta_val = beta[gamma_beta_offset]; + } + +#pragma unroll + for (int i = 0; i < PackSize; ++i) { + DST normalized_i = static_cast(src[i]); + if (affine) { + y_pack.elem[i] = normalized_i * gamma_val + beta_val; + } else { + // Direct Store. 
+ y_pack.elem[i] = normalized_i; + } + if (FuseSwish) { + y_pack.elem[i] = y_pack.elem[i] * hsigmoid(y_pack.elem[i]); + } + } + *(reinterpret_cast*>(y) + + packed_offset) = y_pack.storage; + } + bool CanPackAs(size_t pack_size) { + return (spatial_size % pack_size) == 0; + } + DST* y; + int64_t row_size; + int64_t channel_size; + int64_t spatial_size; + const DST* gamma; + const DST* beta; +}; + +template +struct ScaleLoad { + ScaleLoad( + const SRC* src, + const SRC* gamma, + int64_t row_size, + int64_t channel_size, + int64_t spatial_size) + : src(src), + gamma(gamma), + row_size(row_size), + channel_size(channel_size), + spatial_size(spatial_size) {} + template + __device__ void load(DST* dst, int64_t row, int64_t col) const { + layer_norm::Pack src_pack; + layer_norm::Pack gamma_pack; + + const int64_t offset = row * row_size + col; + const int64_t packed_offset = offset / PackSize; + const int64_t gamma_offset = (offset / spatial_size) % channel_size; + + src_pack.storage = + *(reinterpret_cast*>(src) + + packed_offset); + SRC gamma_val = static_cast(1.0); + if (affine) { + gamma_val = gamma[gamma_offset]; + } +#pragma unroll + for (int i = 0; i < PackSize; ++i) { + dst[i] = static_cast(src_pack.elem[i] * gamma_val); + } + } + bool CanPackAs(size_t pack_size) { + return (spatial_size % pack_size) == 0; + } + const SRC* src; + const SRC* gamma; + int64_t row_size; + int64_t channel_size; + int64_t spatial_size; +}; + +template +struct ChannelsLastStore { + ChannelsLastStore( + DST* y, + const DST* gamma, + const DST* beta, + int64_t spatial_size, + int64_t channel_size, + int64_t num_groups) + : y(y), + gamma(gamma), + beta(beta), + spatial_size(spatial_size), + c0(num_groups), + c1(channel_size / num_groups) {} + + template + __device__ void store(const SRC* src, int32_t row, int32_t col) { + layer_norm::Pack y_pack; + layer_norm::Pack gamma_pack; + layer_norm::Pack beta_pack; + int32_t spatial_idx; + int32_t c1_idx; + c1(spatial_idx, c1_idx, col); + int32_t 
batch_idx; + int32_t c0_idx; + c0(batch_idx, c0_idx, row); + const int32_t y_offset = + (batch_idx * c0.divisor * c1.divisor * spatial_size + + spatial_idx * c0.divisor * c1.divisor + c0_idx * c1.divisor + c1_idx) / + PackSize; + const int32_t gamma_beta_offset = (c0_idx * c1.divisor + c1_idx) / PackSize; + if (affine) { + gamma_pack.storage = *( + reinterpret_cast*>(gamma) + + gamma_beta_offset); + beta_pack.storage = + *(reinterpret_cast*>(beta) + + gamma_beta_offset); + } + +#pragma unroll + for (int i = 0; i < PackSize; ++i) { + DST normalized_i = static_cast(src[i]); + if (affine) { + y_pack.elem[i] = normalized_i * gamma_pack.elem[i] + beta_pack.elem[i]; + } else { + // Direct Store. + y_pack.elem[i] = normalized_i; + } + if (FuseSwish) { + y_pack.elem[i] = y_pack.elem[i] * hsigmoid(y_pack.elem[i]); + } + } + *(reinterpret_cast*>(y) + y_offset) = + y_pack.storage; + } + bool CanPackAs(size_t pack_size) { + return (c1.divisor % pack_size) == 0; + } + DST* y; + const DST* gamma; + const DST* beta; + int32_t spatial_size; + cutlass::FastDivmod c0; + cutlass::FastDivmod c1; +}; + +template +struct ChannelsLastLoad { + ChannelsLastLoad( + const SRC* src, + int64_t spatial_size, + int64_t channel_size, + int64_t num_groups) + : src(src), + spatial_size(spatial_size), + c0(num_groups), + c1(channel_size / num_groups) {} + template + __device__ void load(DST* dst, int32_t row, int32_t col) const { + int32_t spatial_idx; + int32_t c1_idx; + c1(spatial_idx, c1_idx, col); + int32_t batch_idx; + int32_t c0_idx; + c0(batch_idx, c0_idx, row); + layer_norm::Pack pack; + const int32_t offset = + (batch_idx * c0.divisor * c1.divisor * spatial_size + + spatial_idx * c0.divisor * c1.divisor + c0_idx * c1.divisor + c1_idx) / + N; + + pack.storage = + *(reinterpret_cast*>(src) + offset); +#pragma unroll + for (int i = 0; i < N; ++i) { + dst[i] = static_cast(pack.elem[i]); + } + } + bool CanPackAs(size_t pack_size) { + return (c1.divisor % pack_size) == 0; + } + const SRC* src; + 
int32_t spatial_size; + cutlass::FastDivmod c0; + cutlass::FastDivmod c1; +}; + +template +void GroupNormForwardGpu( + cudaStream_t stream, + const int64_t num_instances, + const int64_t norm_size, + const int64_t channel_size, + const int64_t spatial_size, + const double epsilon, + const T* x_ptr, + const T* gamma_ptr, + const T* beta_ptr, + T* y_ptr, + ComputeType* mean, + ComputeType* inv_variance, + bool channels_first) { + // using ComputeType = typename layer_norm::DefaultComputeType::type; + if (channels_first) { + layer_norm::DirectLoad load(x_ptr, norm_size); + AffineStore store( + y_ptr, norm_size, channel_size, spatial_size, gamma_ptr, beta_ptr); + + layer_norm::DispatchLayerNorm( + stream, + load, + store, + num_instances, + norm_size, + epsilon, + mean, + inv_variance); + } else { + ChannelsLastLoad load( + x_ptr, + spatial_size, + channel_size, + channel_size / (norm_size / spatial_size)); + ChannelsLastStore store( + y_ptr, + gamma_ptr, + beta_ptr, + spatial_size, + channel_size, + channel_size / (norm_size / spatial_size)); + + layer_norm::DispatchLayerNorm( + stream, + load, + store, + num_instances, + norm_size, + epsilon, + mean, + inv_variance); + } +} + +template +void DispatchGroupNormForwardGpu( + cudaStream_t stream, + const int64_t num_instances, + const int64_t norm_size, + const int64_t channel_size, + const int64_t spatial_size, + const double epsilon, + const T* x_ptr, + const T* gamma_ptr, + const T* beta_ptr, + T* y_ptr, + T2* mean, + T2* inv_variance, + bool channels_first) { + using ComputeType = typename layer_norm::DefaultComputeType::type; + if (gamma_ptr != nullptr && beta_ptr != nullptr) { + GroupNormForwardGpu( + stream, + num_instances, + norm_size, + channel_size, + spatial_size, + epsilon, + x_ptr, + gamma_ptr, + beta_ptr, + y_ptr, + mean, + inv_variance, + channels_first); + } else { + GroupNormForwardGpu( + stream, + num_instances, + norm_size, + channel_size, + spatial_size, + epsilon, + x_ptr, + gamma_ptr, + beta_ptr, + 
y_ptr, + mean, + inv_variance, + channels_first); + } +} + template -cudaError_t invokeGroupNorm( +cudaError_t invokeGroupNorm_half( half* output, half* input, half* gamma, @@ -521,11 +855,19 @@ cudaError_t invokeGroupNorm( int N, const float eps, const int max_smem_size, + void* workspace, cudaStream_t stream) { constexpr auto C_G = C / G; constexpr auto C_G_2 = C_G / 2; constexpr int ILP = 8; + const int64_t num_instances = N * G; + const int64_t norm_size = H * W * C / G; + const int64_t spatial_size = H * W; + const int64_t channel_size = C; + const double epsilon = eps; + bool channels_first = false; + // Use a little big more shared_memory to reduce occupancy and boost perf. constexpr int MEM_BANK_CONFLICT = 1; @@ -543,14 +885,42 @@ cudaError_t invokeGroupNorm( smem)); constexpr int num_threads = std::min(1024, H / ILP * W * C_G_2); - - dim3 block(num_threads); - group_norm_smem - <<>>( - input, output, gamma, beta, N, eps); + if constexpr (num_threads > 0) { + dim3 block(num_threads); + group_norm_smem + <<>>( + input, output, gamma, beta, N, eps); + } else { + DispatchGroupNormForwardGpu( + stream, + num_instances, + norm_size, + channel_size, + spatial_size, + epsilon, + static_cast(input), + static_cast(gamma), + static_cast(beta), + static_cast(output), + reinterpret_cast(workspace), + reinterpret_cast(workspace + sizeof(float) * num_instances), + channels_first); + } } else { - return invokeWelfordGroupNorm( - output, input, gamma, beta, N, eps, stream); + DispatchGroupNormForwardGpu( + stream, + num_instances, + norm_size, + channel_size, + spatial_size, + epsilon, + static_cast(input), + static_cast(gamma), + static_cast(beta), + static_cast(output), + reinterpret_cast(workspace), + reinterpret_cast(workspace + sizeof(float) * num_instances), + channels_first); } // GROUP_NORM_CUDA_CHECK_LAUNCH(); diff --git a/python/aitemplate/backend/cuda/groupnorm/layer_norm.cuh b/python/aitemplate/backend/cuda/groupnorm/layer_norm.cuh new file mode 100644 index 
000000000..baa1981b3 --- /dev/null +++ b/python/aitemplate/backend/cuda/groupnorm/layer_norm.cuh @@ -0,0 +1,2404 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +namespace layer_norm { + +constexpr int kWarpSize = 32; + +template +struct SumOp { + __device__ __forceinline__ T operator()(const T& a, const T& b) const { + return a + b; + } +}; + +template +struct MaxOp { + __device__ __forceinline__ T operator()(const T& a, const T& b) const { + return max(a, b); + } +}; + +template < + template + class ReductionOp, + typename T, + int thread_group_width = kWarpSize> +__inline__ __device__ T WarpAllReduce(T val) { + for (int mask = thread_group_width / 2; mask > 0; mask /= 2) { + val = ReductionOp()( + val, __shfl_xor_sync(0xffffffff, val, mask, thread_group_width)); + } + return val; +} + +template