Skip to content

Commit be2f159

Browse files
committed
Add tuned FP8 GEMM kernels and configurations for Grok-2 on MI355X
1 parent dc9aa00 commit be2f159

272 files changed

Lines changed: 41960 additions & 0 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv.backup

Lines changed: 1587 additions & 0 deletions
Large diffs are not rendered by default.

aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv.backup_20251117_171633

Lines changed: 14087 additions & 0 deletions
Large diffs are not rendered by default.

aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv.backup_20251117_171638

Lines changed: 14087 additions & 0 deletions
Large diffs are not rendered by default.

aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv.pre_tp4

Lines changed: 675 additions & 0 deletions
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "gemm_a8w8_bpreshuffle_common.cuh"

namespace {
// Instantiates and launches this file's tuned FP8 GEMM configuration
// (128-thread block, 16x32x128 per-block tile, Intrawave scheduler,
// pipeline v2 — see the public function name for the full encoding).
// GemmSpec selects the padded vs. unpadded kernel variant; every other
// tuning parameter is fixed here once so the two call sites below
// cannot drift apart.
template <typename DDataType, typename EDataType,
          ck::tensor_operation::device::GemmSpecialization GemmSpec>
torch::Tensor run_128x16x32x128_v2(
    torch::Tensor &XQ,
    torch::Tensor &WQ,
    torch::Tensor &x_scale,
    torch::Tensor &w_scale,
    torch::Tensor &Y)
{
    using DeviceGemmInstance = DeviceGemmHelperF8Flatmm<
        DDataType, EDataType,
        128,
        16, 32, 128,
        16, 16,
        16, 16,
        1, 1,
        S<8, 16, 1>,
        S<8, 16, 1>,
        1,
        1,
        S<1, 16, 1, 8>,
        S<4, 4, 1>,
        ck::BlockGemmPipelineScheduler::Intrawave,
        ck::BlockGemmPipelineVersion::v2,
        GemmSpec>;
    return gemm_a8w8_bpreshuffle_impl<DDataType, EDataType, DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
}
} // namespace

/// Tuned a8w8 B-preshuffled GEMM instance: forwards (XQ, WQ, x_scale,
/// w_scale, Y) to gemm_a8w8_bpreshuffle_impl with this file's fixed
/// device-GEMM configuration, returning the result tensor.
/// Smallest tile configuration in this kernel family; per the original
/// tuning notes it works well for memory-bound shapes.
template <typename DDataType, typename EDataType>
torch::Tensor
a8w8_bpreshuffle_128x16x32x128_16x16_16x16_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2(
    torch::Tensor &XQ,
    torch::Tensor &WQ,
    torch::Tensor &x_scale,
    torch::Tensor &w_scale,
    torch::Tensor &Y)
{
    // Use the padded kernel variant whenever the problem shape is not an
    // exact multiple of the 16x32x128 per-block tile.
    const int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
    const int N = WQ.size(0);
    const int K = WQ.size(1);
    const bool pad = (M % 16 != 0) || (N % 32 != 0) || (K % 128 != 0);

    using ck::tensor_operation::device::GemmSpecialization;
    if (pad)
    {
        return run_128x16x32x128_v2<DDataType, EDataType, GemmSpecialization::MNKPadding>(
            XQ, WQ, x_scale, w_scale, Y);
    }
    return run_128x16x32x128_v2<DDataType, EDataType, GemmSpecialization::Default>(
        XQ, WQ, x_scale, w_scale, Y);
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "gemm_a8w8_bpreshuffle_common.cuh"

namespace {
// Instantiates and launches this file's tuned FP8 GEMM configuration
// (128-thread block, 16x32x512 per-block tile, Intrawave scheduler,
// pipeline v2 — see the public function name for the full encoding).
// GemmSpec selects the padded vs. unpadded kernel variant; every other
// tuning parameter is fixed here once so the two call sites below
// cannot drift apart.
template <typename DDataType, typename EDataType,
          ck::tensor_operation::device::GemmSpecialization GemmSpec>
torch::Tensor run_128x16x32x512_v2(
    torch::Tensor &XQ,
    torch::Tensor &WQ,
    torch::Tensor &x_scale,
    torch::Tensor &w_scale,
    torch::Tensor &Y)
{
    using DeviceGemmInstance = DeviceGemmHelperF8Flatmm<
        DDataType, EDataType,
        128,
        16, 32, 512,
        16, 16,
        16, 16,
        1, 1,
        S<32, 4, 1>,
        S<32, 4, 1>,
        1,
        1,
        S<1, 16, 1, 8>,
        S<4, 4, 1>,
        ck::BlockGemmPipelineScheduler::Intrawave,
        ck::BlockGemmPipelineVersion::v2,
        GemmSpec>;
    return gemm_a8w8_bpreshuffle_impl<DDataType, EDataType, DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
}
} // namespace

/// Tuned a8w8 B-preshuffled GEMM instance: forwards (XQ, WQ, x_scale,
/// w_scale, Y) to gemm_a8w8_bpreshuffle_impl with this file's fixed
/// device-GEMM configuration, returning the result tensor.
/// NOTE(review): the original "smallest kernel" comment was copy-pasted
/// from another instance and did not describe this 16x32x512 tile; removed.
template <typename DDataType, typename EDataType>
torch::Tensor
a8w8_bpreshuffle_128x16x32x512_16x16_16x16_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2(
    torch::Tensor &XQ,
    torch::Tensor &WQ,
    torch::Tensor &x_scale,
    torch::Tensor &w_scale,
    torch::Tensor &Y)
{
    // Use the padded kernel variant whenever the problem shape is not an
    // exact multiple of the 16x32x512 per-block tile.
    const int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
    const int N = WQ.size(0);
    const int K = WQ.size(1);
    const bool pad = (M % 16 != 0) || (N % 32 != 0) || (K % 512 != 0);

    using ck::tensor_operation::device::GemmSpecialization;
    if (pad)
    {
        return run_128x16x32x512_v2<DDataType, EDataType, GemmSpecialization::MNKPadding>(
            XQ, WQ, x_scale, w_scale, Y);
    }
    return run_128x16x32x512_v2<DDataType, EDataType, GemmSpecialization::Default>(
        XQ, WQ, x_scale, w_scale, Y);
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "gemm_a8w8_bpreshuffle_common.cuh"

namespace {
// Instantiates and launches this file's tuned FP8 GEMM configuration
// (128-thread block, 32x16x512 per-block tile, Intrawave scheduler,
// pipeline v1 — see the public function name for the full encoding).
// GemmSpec selects the padded vs. unpadded kernel variant; every other
// tuning parameter is fixed here once so the two call sites below
// cannot drift apart.
template <typename DDataType, typename EDataType,
          ck::tensor_operation::device::GemmSpecialization GemmSpec>
torch::Tensor run_128x32x16x512_v1(
    torch::Tensor &XQ,
    torch::Tensor &WQ,
    torch::Tensor &x_scale,
    torch::Tensor &w_scale,
    torch::Tensor &Y)
{
    using DeviceGemmInstance = DeviceGemmHelperF8Flatmm<
        DDataType, EDataType,
        128,
        32, 16, 512,
        16, 16,
        16, 16,
        1, 1,
        S<32, 4, 1>,
        S<32, 4, 1>,
        1,
        1,
        S<1, 32, 1, 4>,
        S<4, 4, 1>,
        ck::BlockGemmPipelineScheduler::Intrawave,
        ck::BlockGemmPipelineVersion::v1,
        GemmSpec>;
    return gemm_a8w8_bpreshuffle_impl<DDataType, EDataType, DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
}
} // namespace

/// Tuned a8w8 B-preshuffled GEMM instance: forwards (XQ, WQ, x_scale,
/// w_scale, Y) to gemm_a8w8_bpreshuffle_impl with this file's fixed
/// device-GEMM configuration, returning the result tensor.
/// NOTE(review): the original "smallest kernel" comment was copy-pasted
/// from another instance and did not describe this 32x16x512 tile; removed.
template <typename DDataType, typename EDataType>
torch::Tensor
a8w8_bpreshuffle_128x32x16x512_16x16_16x16_32x4x1_32x4x1_1x32x1x4_4x4x1_1x1_intrawave_v1(
    torch::Tensor &XQ,
    torch::Tensor &WQ,
    torch::Tensor &x_scale,
    torch::Tensor &w_scale,
    torch::Tensor &Y)
{
    // Use the padded kernel variant whenever the problem shape is not an
    // exact multiple of the 32x16x512 per-block tile.
    const int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
    const int N = WQ.size(0);
    const int K = WQ.size(1);
    const bool pad = (M % 32 != 0) || (N % 16 != 0) || (K % 512 != 0);

    using ck::tensor_operation::device::GemmSpecialization;
    if (pad)
    {
        return run_128x32x16x512_v1<DDataType, EDataType, GemmSpecialization::MNKPadding>(
            XQ, WQ, x_scale, w_scale, Y);
    }
    return run_128x32x16x512_v1<DDataType, EDataType, GemmSpecialization::Default>(
        XQ, WQ, x_scale, w_scale, Y);
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "gemm_a8w8_bpreshuffle_common.cuh"

namespace {
// Instantiates and launches this file's tuned FP8 GEMM configuration
// (128-thread block, 32x16x512 per-block tile, Intrawave scheduler,
// pipeline v2 — see the public function name for the full encoding).
// GemmSpec selects the padded vs. unpadded kernel variant; every other
// tuning parameter is fixed here once so the two call sites below
// cannot drift apart.
template <typename DDataType, typename EDataType,
          ck::tensor_operation::device::GemmSpecialization GemmSpec>
torch::Tensor run_128x32x16x512_v2(
    torch::Tensor &XQ,
    torch::Tensor &WQ,
    torch::Tensor &x_scale,
    torch::Tensor &w_scale,
    torch::Tensor &Y)
{
    using DeviceGemmInstance = DeviceGemmHelperF8Flatmm<
        DDataType, EDataType,
        128,
        32, 16, 512,
        16, 16,
        16, 16,
        1, 1,
        S<32, 4, 1>,
        S<32, 4, 1>,
        1,
        1,
        S<1, 32, 1, 4>,
        S<4, 4, 1>,
        ck::BlockGemmPipelineScheduler::Intrawave,
        ck::BlockGemmPipelineVersion::v2,
        GemmSpec>;
    return gemm_a8w8_bpreshuffle_impl<DDataType, EDataType, DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
}
} // namespace

/// Tuned a8w8 B-preshuffled GEMM instance: forwards (XQ, WQ, x_scale,
/// w_scale, Y) to gemm_a8w8_bpreshuffle_impl with this file's fixed
/// device-GEMM configuration, returning the result tensor.
/// NOTE(review): the original "smallest kernel" comment was copy-pasted
/// from another instance and did not describe this 32x16x512 tile; removed.
template <typename DDataType, typename EDataType>
torch::Tensor
a8w8_bpreshuffle_128x32x16x512_16x16_16x16_32x4x1_32x4x1_1x32x1x4_4x4x1_1x1_intrawave_v2(
    torch::Tensor &XQ,
    torch::Tensor &WQ,
    torch::Tensor &x_scale,
    torch::Tensor &w_scale,
    torch::Tensor &Y)
{
    // Use the padded kernel variant whenever the problem shape is not an
    // exact multiple of the 32x16x512 per-block tile.
    const int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
    const int N = WQ.size(0);
    const int K = WQ.size(1);
    const bool pad = (M % 32 != 0) || (N % 16 != 0) || (K % 512 != 0);

    using ck::tensor_operation::device::GemmSpecialization;
    if (pad)
    {
        return run_128x32x16x512_v2<DDataType, EDataType, GemmSpecialization::MNKPadding>(
            XQ, WQ, x_scale, w_scale, Y);
    }
    return run_128x32x16x512_v2<DDataType, EDataType, GemmSpecialization::Default>(
        XQ, WQ, x_scale, w_scale, Y);
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "gemm_a8w8_bpreshuffle_common.cuh"

namespace {
// Instantiates and launches this file's tuned FP8 GEMM configuration
// (256-thread block, 112x128x256 per-block tile, Intrawave scheduler,
// pipeline v3 — see the public function name for the full encoding).
// GemmSpec selects the padded vs. unpadded kernel variant; every other
// tuning parameter is fixed here once so the two call sites below
// cannot drift apart.
template <typename DDataType, typename EDataType,
          ck::tensor_operation::device::GemmSpecialization GemmSpec>
torch::Tensor run_256x112x128x256_v3(
    torch::Tensor &XQ,
    torch::Tensor &WQ,
    torch::Tensor &x_scale,
    torch::Tensor &w_scale,
    torch::Tensor &Y)
{
    using DeviceGemmInstance = DeviceGemmHelperF8Flatmm<
        DDataType, EDataType,
        256,
        112, 128, 256,
        16, 16,
        16, 16,
        7, 2,
        S<16, 16, 1>,
        S<16, 16, 1>,
        1,
        2,
        S<1, 16, 1, 16>,
        S<8, 8, 1>,
        ck::BlockGemmPipelineScheduler::Intrawave,
        ck::BlockGemmPipelineVersion::v3,
        GemmSpec>;
    return gemm_a8w8_bpreshuffle_impl<DDataType, EDataType, DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
}
} // namespace

/// Tuned a8w8 B-preshuffled GEMM instance: forwards (XQ, WQ, x_scale,
/// w_scale, Y) to gemm_a8w8_bpreshuffle_impl with this file's fixed
/// device-GEMM configuration, returning the result tensor.
/// NOTE(review): the original comment claimed this was "the smallest
/// kernel we have available" — that was copy-pasted boilerplate and is
/// incorrect for this large 112x128x256 tile; removed.
template <typename DDataType, typename EDataType>
torch::Tensor
a8w8_bpreshuffle_256x112x128x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v3(
    torch::Tensor &XQ,
    torch::Tensor &WQ,
    torch::Tensor &x_scale,
    torch::Tensor &w_scale,
    torch::Tensor &Y)
{
    // Use the padded kernel variant whenever the problem shape is not an
    // exact multiple of the 112x128x256 per-block tile.
    const int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
    const int N = WQ.size(0);
    const int K = WQ.size(1);
    const bool pad = (M % 112 != 0) || (N % 128 != 0) || (K % 256 != 0);

    using ck::tensor_operation::device::GemmSpecialization;
    if (pad)
    {
        return run_256x112x128x256_v3<DDataType, EDataType, GemmSpecialization::MNKPadding>(
            XQ, WQ, x_scale, w_scale, Y);
    }
    return run_256x112x128x256_v3<DDataType, EDataType, GemmSpecialization::Default>(
        XQ, WQ, x_scale, w_scale, Y);
}

0 commit comments

Comments
 (0)