26 changes: 26 additions & 0 deletions include/tvm/meta_schedule/schedule_rule.h
@@ -210,6 +210,32 @@ class ScheduleRule : public runtime::ObjectRef {
Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
Optional<Map<String, ObjectRef>> reuse_write, bool use_software_pipeline);

/*!
* \brief Extension of MultiLevelTiling for auto-tensorization with multiple groups of candidate
* Hopper tensor core intrinsics
* \param intrin_groups A list of groups of tensor core intrinsics. Each map should contain the
* keys "init", "load_a", "load_b", "compute", "store", which represent the tensor intrinsics for
* initialization, loading operand A, loading operand B, tensor core computation, and storing the
* result, respectively. The values of the map are names of tensor intrinsics, which must be
* registered via TensorIntrin.register(...) beforehand
* \param structure The tiling structure. Recommended:
* - 'SSSRRSRS' on GPU
* \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
* - [blockIdx.y, blockIdx.x, threadIdx.y] on GPU
* \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
* \param vector_load_lens The length of vector lane in vectorized cooperative fetching.
* NullOpt means disable vectorization
* \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
* \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
* \param use_software_pipeline Whether to use the software pipeline.
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule MultiLevelTilingTensorCoreHopper(
Array<Map<String, String>> intrin_groups, String structure,
Optional<Array<String>> tile_binds, Optional<Integer> max_innermost_factor,
Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
Optional<Map<String, ObjectRef>> reuse_write, bool use_software_pipeline);
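
For illustration, each entry of intrin_groups is a map from the five role keys to the names of already-registered tensor intrinsics. A minimal Python-side sketch, using placeholder names rather than intrinsics shipped with this change:

# Hypothetical intrin_groups entry: each value names a tensor intrinsic
# that must have been registered via TensorIntrin.register(...) beforehand.
example_group = {
    "init": "my_wgmma_fill_intrin",      # initialize the accumulator
    "load_a": "my_wgmma_load_a_intrin",  # load operand A
    "load_b": "my_wgmma_load_b_intrin",  # load operand B
    "compute": "my_wgmma_mma_intrin",    # tensor core computation
    "store": "my_wgmma_store_intrin",    # store the result
}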

/*!
* \brief Extension of MultiLevelTiling for backends with wide vectors.
* The loop over the innermost spatial axis of the output buffer is always vectorized with the
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -27,6 +27,7 @@
from .multi_level_tiling import (
MultiLevelTiling,
MultiLevelTilingTensorCore,
MultiLevelTilingTensorCoreHopper,
MultiLevelTilingWideVector,
MultiLevelTilingWithIntrin,
ReuseType,
55 changes: 55 additions & 0 deletions python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
@@ -197,6 +197,61 @@ def __init__(
)


@register_object("meta_schedule.MultiLevelTilingTensorCoreHopper")
class MultiLevelTilingTensorCoreHopper(ScheduleRule):
"""Extension of MultiLevelTiling for auto-tensorizing with multiple groups of candidate tensor
core intrinsics.

Parameters
----------
intrin_groups : List[Mapping[str, str]]
A list of groups of tensor core intrinsics. Each map should contain the keys "init",
"load_a", "load_b", "compute", "store", which represent the tensor intrinsics for
initialization, loading operand A, loading operand B, tensor core computation, and
storing the result, respectively. The values of the map are names of tensor intrinsics,
which must be registered via TensorIntrin.register(...) beforehand
structure : str
The tiling structure. Recommended:
- 'SSSRRSRS' on GPU
tile_binds : Optional[List[str]]
For each level of tiles, which thread axis it is bound to. Recommended:
- [blockIdx.y, blockIdx.x, threadIdx.y] on GPU
max_innermost_factor : Optional[int]
The maximum size of the innermost factor. None means no limit
vector_load_lens : Optional[List[int]]
The length of vector lane in vectorized cooperative fetching.
None means disable vectorization
reuse_read : Optional[ReuseType]
Data reuse configuration for reading. None means no reuse.
reuse_write : Optional[ReuseType]
Data reuse configuration for writing. None means no reuse.
use_software_pipeline : bool
Whether to use the software pipeline.
"""

def __init__(
self,
intrin_groups: List[Mapping[str, str]],
structure: str,
tile_binds: Optional[List[str]] = None,
max_innermost_factor: Optional[int] = None,
vector_load_lens: Optional[List[int]] = None,
reuse_read: Optional[ReuseType] = None,
reuse_write: Optional[ReuseType] = None,
use_software_pipeline: bool = False,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ScheduleRuleMultiLevelTilingTensorCoreHopper, # type: ignore # pylint: disable=no-member
intrin_groups,
structure,
tile_binds,
max_innermost_factor,
vector_load_lens,
reuse_read.as_dict() if reuse_read is not None else None,
reuse_write.as_dict() if reuse_write is not None else None,
use_software_pipeline,
)
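
As a usage sketch (not part of this change), the rule could be constructed roughly as follows. The reuse settings, tile limits, and the assumption that get_wgmma_intrin_group is exposed at module level in tvm.tir.tensor_intrin.cuda are illustrative choices, not values mandated by this PR:

from tvm.meta_schedule.schedule_rule import (
    MultiLevelTilingTensorCoreHopper,
    ReuseType,
)
from tvm.tir.tensor_intrin.cuda import get_wgmma_intrin_group

# Build one candidate group of wgmma intrinsics and hand it to the rule.
hopper_rule = MultiLevelTilingTensorCoreHopper(
    intrin_groups=[
        get_wgmma_intrin_group(
            load_scope="shared.dyn",
            store_scope="shared.dyn",
            in_dtype="float16",
            out_dtype="float32",
            trans_b=True,
        ),
    ],
    structure="SSSRRSRS",
    tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
    max_innermost_factor=4,                # illustrative limit
    vector_load_lens=[1, 2, 3, 4, 8, 16],  # illustrative lane widths
    reuse_read=ReuseType(req="must", levels=[4], scope="shared.dyn"),
    reuse_write=ReuseType(req="must", levels=[2], scope="shared.dyn"),
    use_software_pipeline=True,
)

The resulting rule can then be supplied to a meta schedule space generator in the usual way, e.g. as one element of its sch_rules list.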

@register_object("meta_schedule.MultiLevelTilingWideVector")
class MultiLevelTilingWideVector(ScheduleRule):
"""Extension of MultiLevelTiling for backends with wide vectors. The loop over the innermost
65 changes: 65 additions & 0 deletions python/tvm/tir/tensor_intrin/cuda.py
@@ -1636,6 +1636,71 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None:

return mma_store_desc, mma_store_desc

def get_wgmma_intrin_group(
load_scope: Literal["shared", "shared.dyn"],
store_scope: Literal["global", "shared", "shared.dyn"],
in_dtype: str,
out_dtype: str,
trans_b: bool,
) -> Dict[str, str]:
"""Get a group of intrinsics for wgmma tensor core with the given configurations

Parameters
----------
load_scope : Literal["shared", "shared.dyn"]
The memory scope of the input buffer.

store_scope : Literal["global", "shared", "shared.dyn"]
The memory scope of the result buffer.

in_dtype : str
The input data type.

out_dtype : str
The output data type.

trans_b : bool
Whether the input matrix B is transposed.

Returns
-------
ret : Dict[str, str]
A group of tensor intrinsics.
"""
assert load_scope in ["shared", "shared.dyn"]
assert store_scope in ["global", "shared", "shared.dyn"]
assert in_dtype in ["float16", "int8"]
assert out_dtype in ["float16", "float32", "int32"]

shape = "16x16x16"
in_dtype = "f16" if in_dtype == "float16" else "s8"
out_dtype = "f16" if out_dtype == "float16" else "f32" if out_dtype == "float32" else "s32"
# convert "shared.dyn" to "shared_dyn"
load_scope = load_scope.replace(".", "_")
store_scope = store_scope.replace(".", "_")
trans_a = ""
trans_b = "_trans" if trans_b else ""

# e.g. wgmma_load_16x16x16_f16_a_shared
load_a_intrin = f"wgmma_load_{shape}_{in_dtype}_a{trans_a}_{load_scope}"
# e.g. wgmma_load_16x16x16_f16_b_trans_shared_dyn
load_b_intrin = f"wgmma_load_{shape}_{in_dtype}_b{trans_b}_{load_scope}"
# e.g. wgmma_sync_16x16x16_f16f16f32_trans
compute_intrin = f"wgmma_sync_{shape}_{in_dtype}{in_dtype}{out_dtype}{trans_b}"
# e.g. wgmma_fill_16x16x16_f16
init_intrin = f"wgmma_fill_{shape}_{out_dtype}"
# e.g. wgmma_store_16x16x16_f16_shared_dyn
store_intrin = f"wgmma_store_{shape}_{out_dtype}_{store_scope}"

return {
"init": init_intrin,
"load_a": load_a_intrin,
"load_b": load_b_intrin,
"compute": compute_intrin,
"store": store_intrin,
}
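
As a sketch of the name mangling above, a call with float16 inputs, float32 accumulation, a transposed B matrix, and dynamic shared memory would produce names of the following shape (assuming the corresponding wgmma intrinsics are registered elsewhere):

# Illustrative call; the returned names follow the f-strings above.
group = get_wgmma_intrin_group(
    load_scope="shared.dyn",
    store_scope="shared.dyn",
    in_dtype="float16",
    out_dtype="float32",
    trans_b=True,
)
# group == {
#     "init": "wgmma_fill_16x16x16_f32",
#     "load_a": "wgmma_load_16x16x16_f16_a_shared_dyn",
#     "load_b": "wgmma_load_16x16x16_f16_b_trans_shared_dyn",
#     "compute": "wgmma_sync_16x16x16_f16f16f32_trans",
#     "store": "wgmma_store_16x16x16_f32_shared_dyn",
# }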



TensorIntrin.register("mma_init_m16n8k8_f16", *get_mma_init_intrin(16, 8, 8, "float16"))
TensorIntrin.register("mma_init_m16n8k8_f32", *get_mma_init_intrin(16, 8, 8, "float32"))