
Fail to run gemm_bench with Problem size (2,2,2) (4,4,4) (8,8,8) #4

@LeiWang1999

Description

Hi there, I'm currently benchmarking GEMM performance of AMOS on Tensor Cores. I modified mapping_gemm_tensorcore.py as below:

import tvm
import os
from tvm import auto_tensorize as at
import argparse


def gemm(M, N, K, in_dtype, out_dtype):
    A = tvm.te.placeholder([M, K], dtype=in_dtype, name="A")
    B = tvm.te.placeholder([K, N], dtype=in_dtype, name="B")

    rk = tvm.te.reduce_axis([0, K], name="k")
    C = tvm.te.compute(
        [M, N], lambda i, j: tvm.te.sum((A[i, rk] * B[rk, j]).astype(out_dtype), axis=rk), name="C"
    )
    return [A, B, C]


def mapping_tensorcore(
    M,
    N,
    K,
    layer,
    in_dtype,
    out_dtype,
    simple_mode=True,
    trials=-1,
    verbose=False,
    use_perf_model=False,
    perf_model_ratio=0.6,
):
    A, B, Gemm = gemm(M, N, K, in_dtype, out_dtype)
    target_dag = at.compute_dag_from_tensors([Gemm])
    target = "cuda"

    log_dir = "gemm-%s-%s-layer-%s" % (in_dtype, out_dtype, layer)
    log_file = "gemm-%s-%s-layer-%s.log" % (in_dtype, out_dtype, layer)

    measure_opt = at.MeasureOptions(target=target, timeout=100, number=200, min_repeat_ms=500)

    if simple_mode:
        trials = 1000 if trials < 0 else trials
        result = at.auto_tensorize(
            target_dag, target, log_file, measure_opt, trials=trials, verbose=verbose
        )
        if not result.defined():
            print("Can't do tensorize.")
            return
        schedule_gen = result.sch_gen
        schedule_app = result.sch_app

        # load from file
        schedule_gen.load_from_file(log_file, clear=True)
        entry = schedule_gen.get_best_entry()
        # we store 1/time_cost in file
        params, value = entry.record, 1 / entry.value
        print(value)
        print(params.to_json())
    else:
        trials = 4000 if trials < 0 else trials
        result = at.auto_tensorize_v4(
            target_dag,
            target,
            log_file,
            measure_opt,
            schedule_log_dir=log_dir,
            trials=trials,
            search_group_size=5,
            transform_dump=verbose,
            enable_perf_model=use_perf_model,
            perf_percentage=perf_model_ratio,
        )
        if not result.defined():
            print("Can't do tensorize.")
            return
        schedule_gen = result.sch_gen
        schedule_app = result.sch_app

        # we store 1/time_cost in file
        params, value = result.params, result.perf
        print(value)
        print(params.to_json())

    cost = at.evaluate_params(schedule_app, params, measure_opt, dump=verbose)
    print("Cost of %s is %f ms" % (log_dir, cost))
    return cost


# customized problem sizes; the original list was
# [(16, 512, 128), (1024, 16, 256), (256, 1024, 256), (512, 256, 16), (1024, 1024, 1024)]
shapes = [(2, 2, 2), (4, 4, 4), (8, 8, 8), (16, 16, 16)]

supported_dtypes = set(
    [
        ("float16", "float16"),
        ("float16", "float32"),
        ("bfloat16", "float32"),
        ("float32", "float32"),
        ("float64", "float64"),
        ("int4", "int32"),
        ("int8", "int32"),
    ]
)

example_text = """
 example:
    python mapping_gemm_tensorcore.py --in_dtype float16 --out_dtype float16 --begin 0 --num 1 --trials 20
    python mapping_gemm_tensorcore.py --in_dtype float16 --out_dtype float32 --begin 0 --num 1 --trials 20
    python mapping_gemm_tensorcore.py --in_dtype float32 --out_dtype float32 --begin 0 --num 1 --trials 20
    python mapping_gemm_tensorcore.py --in_dtype float16 --out_dtype float16 --begin 0 --num 1 --trials 400 --simple_mode 0
"""

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="base_maker",
        description="template maker",
        epilog=example_text,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--in_dtype",
        type=str,
        choices=["float16", "float32", "float64", "bfloat16", "int4", "int8"],
        default="float16",
    )
    parser.add_argument(
        "--out_dtype",
        type=str,
        choices=["float16", "float32", "float64", "int32"],
        default="float16",
    )
    parser.add_argument("--begin", type=int, choices=list(range(len(shapes))), default=0)
    parser.add_argument(
        "--num", type=int, choices=list(range(1, len(shapes) + 1)), default=len(shapes)
    )
    parser.add_argument("--simple_mode", type=int, default=1, choices=[0, 1])
    parser.add_argument("--trials", type=int, default=-1)
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--use_perf_model", action="store_true")
    parser.add_argument("--perf_model_ratio", type=float, default=0.6)

    args = parser.parse_args()
    assert 0 < args.perf_model_ratio <= 1.0
    if args.use_perf_model:
        assert args.simple_mode == 0, "Performance model is only supported without simple_mode"
    beg = args.begin
    num = args.num
    print(args.simple_mode)
    assert (
        args.in_dtype,
        args.out_dtype,
    ) in supported_dtypes, (
        f"The desired dtype pair {(args.in_dtype, args.out_dtype)} is not supported by Tensor Core."
    )
    costs = []
    for i, shape in enumerate(shapes[beg : beg + num]):
        (M, N, K) = shape
        print("\n\nProblem size:")
        print(M, N, K)
        layer_name = f"({M}, {N}, {K})"
        try:
            cost = mapping_tensorcore(
                M,
                N,
                K,
                layer_name,
                args.in_dtype,
                args.out_dtype,
                simple_mode=args.simple_mode,
                trials=args.trials,
                verbose=args.verbose,
                use_perf_model=args.use_perf_model,
                perf_model_ratio=args.perf_model_ratio,
            )
            costs.append(cost)
        except Exception as e:
            print("Fail to run\n", str(e))
            costs.append(float("inf"))
    for cost in costs:
        print(cost)

The only change is the customized shapes list, but the benchmark throws a "Fail to run" error for every problem size smaller than 16:
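The invocation was roughly the following (reconstructed, not the literal command: the log below prints simple_mode as 0 and matches the wmma_fp16_fp16 recipe, and the flags come from the argparse options in the script above):

    python mapping_gemm_tensorcore.py --in_dtype float16 --out_dtype float16 --simple_mode 0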

0


Problem size:
2 2 2
Possible matchings:
0 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:nnn, shape:16x16x16)
1 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:nnn, shape:32x8x16)
2 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:nnn, shape:8x32x16)
3 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ntn, shape:16x16x16)
4 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ntn, shape:32x8x16)
5 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ntn, shape:8x32x16)
6 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:tnn, shape:16x16x16)
7 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:tnn, shape:32x8x16)
8 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:tnn, shape:8x32x16)
9 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ttn, shape:16x16x16)
10 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ttn, shape:32x8x16)
11 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ttn, shape:8x32x16)
Logging to devnull...
Totally 1 different mappings for this matching
Logging to devnull...
Totally 1 different mappings for this matching
Catch an infeasible mapping:
{"vmap": [[1], -1]}
Fail to run



Problem size:
4 4 4
Possible matchings:
0 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:nnn, shape:16x16x16)
1 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:nnn, shape:32x8x16)
2 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:nnn, shape:8x32x16)
3 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ntn, shape:16x16x16)
4 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ntn, shape:32x8x16)
5 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ntn, shape:8x32x16)
6 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:tnn, shape:16x16x16)
7 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:tnn, shape:32x8x16)
8 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:tnn, shape:8x32x16)
9 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ttn, shape:16x16x16)
10 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ttn, shape:32x8x16)
11 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ttn, shape:8x32x16)
Logging to devnull...
Totally 1 different mappings for this matching
Logging to devnull...
Totally 1 different mappings for this matching
Catch an infeasible mapping:
{"vmap": [[1], -1]}
Fail to run



Problem size:
8 8 8
Possible matchings:
0 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:nnn, shape:16x16x16)
1 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:nnn, shape:32x8x16)
2 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:nnn, shape:8x32x16)
3 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ntn, shape:16x16x16)
4 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ntn, shape:32x8x16)
5 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ntn, shape:8x32x16)
6 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:tnn, shape:16x16x16)
7 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:tnn, shape:32x8x16)
8 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:tnn, shape:8x32x16)
9 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ttn, shape:16x16x16)
10 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ttn, shape:32x8x16)
11 : MatchResult(hw_abs_dag:wmma_fp16_fp16, compute:ttn, shape:8x32x16)
Logging to devnull...
Totally 1 different mappings for this matching
Logging to devnull...
Totally 1 different mappings for this matching
Catch an infeasible mapping:
{"vmap": [[1], -1]}
Fail to run

However, (16, 16, 16) runs fine. Any suggestions?
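In case it helps narrow this down: my guess is that a 2x2x2 (or 4x4x4 / 8x8x8) GEMM cannot be tiled by any of the WMMA fragment shapes listed in the matcher output above (16x16x16, 32x8x16, 8x32x16), so every candidate mapping is rejected as infeasible. A workaround I am considering, shown as an unverified sketch below (gemm_padded and the tile argument are my own names, not AMOS API), is to pad the problem up to the 16x16x16 tile before building the compute DAG and slice the result back afterwards:

import tvm


def gemm_padded(M, N, K, in_dtype, out_dtype, tile=16):
    # Round each dimension up to a multiple of the smallest WMMA tile.
    def up(x):
        return ((x + tile - 1) // tile) * tile

    Mp, Np, Kp = up(M), up(N), up(K)
    A = tvm.te.placeholder([Mp, Kp], dtype=in_dtype, name="A")
    B = tvm.te.placeholder([Kp, Np], dtype=in_dtype, name="B")
    rk = tvm.te.reduce_axis([0, Kp], name="k")
    C = tvm.te.compute(
        [Mp, Np],
        lambda i, j: tvm.te.sum((A[i, rk] * B[rk, j]).astype(out_dtype), axis=rk),
        name="C",
    )
    # The caller would zero-pad the real A/B buffers to (Mp, Kp) / (Kp, Np)
    # and read only C[:M, :N] afterwards.
    return [A, B, C]

Is padding the intended way to handle shapes this small, or is there a supported path for problems below the fragment size?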
