55import torch
66
77import flashinfer
8+ from flashinfer .autotuner import autotune
89from flashinfer .fused_moe import (
910 WeightLayout ,
1011 trtllm_fp4_block_scale_moe ,
@@ -186,6 +187,14 @@ def parse_moe_args(line, parser):
186187 choices = ["swiglu" , "geglu" ],
187188 help = "Type of gated activation function: swiglu | geglu." ,
188189 )
190+ parser .add_argument (
191+ "--autotune" ,
192+ action = "store_true" ,
193+ default = False ,
194+ help = (
195+ "Enable autotuner warmup for supported routines (trtllm_fp4_block_scale_moe and cutlass_fused_moe)."
196+ ),
197+ )
189198
190199 # CUTLASS fused MoE specific
191200 parser .add_argument (
@@ -604,10 +613,11 @@ def testTrtllmFp4BlockScaleMoe(args):
604613 hidden_states_fp4 = hidden_states_fp4_bytes .view (torch .uint8 ).reshape (
605614 hidden_states .shape [0 ], hidden_states .shape [1 ] // 2
606615 )
616+ # Hidden-states scale for FP4 must be 2D: [num_tokens, hidden_size // 16]
607617 hidden_states_scale_linear_fp4 = hidden_states_scale_fp4_bytes .view (
608618 torch .float8_e4m3fn
609- ). reshape ( - 1 )
610- # Ensure expected vector size (16 elements per hidden value for NvFP4)
619+ )
620+ # Ensure expected shape (16 elements per hidden value for NvFP4)
611621 expected_scale_elems = (num_tokens * hidden_size ) // 16
612622 if hidden_states_scale_linear_fp4 .numel () != expected_scale_elems :
613623 if args .verbose >= 1 :
@@ -617,6 +627,9 @@ def testTrtllmFp4BlockScaleMoe(args):
617627 hidden_states_scale_linear_fp4 = torch .ones (
618628 expected_scale_elems , device = device , dtype = torch .float8_e4m3fn
619629 )
630+ hidden_states_scale_linear_fp4 = hidden_states_scale_linear_fp4 .reshape (
631+ num_tokens , hidden_size // 16
632+ )
620633
621634 # Prepare weights for kernel
622635 # For FP4 weights, keep them as uint8 (packed format) - don't convert to float8_e4m3fn
@@ -691,6 +704,22 @@ def run_fp4_moe():
691704 do_finalize = True ,
692705 )
693706
707+ backend = "trtllm"
708+
709+ # Optional autotune warmup (supported for FP4 TRTLlm fused MoE)
710+ if getattr (args , "autotune" , False ):
711+ warmup_iters = (
712+ args .dry_run_iters if args .dry_run_iters and args .dry_run_iters > 0 else 10
713+ )
714+ backend = "trtllm_autotune"
715+ if args .verbose >= 1 :
716+ print (
717+ f"[INFO] Autotune warmup for FP4 block scale MoE: { warmup_iters } iters"
718+ )
719+ with autotune (True ):
720+ for _ in range (warmup_iters ):
721+ run_fp4_moe ()
722+
694723 # Benchmark timing
695724 if is_cuda_graph_compatible :
696725 times = bench_gpu_time_with_cudagraph (
@@ -734,7 +763,6 @@ def run_fp4_moe():
734763 routing_logits_dtype = routing_logits .dtype ,
735764 )
736765
737- backend = "trtllm"
738766 print_perf_metrics (backend , median_time , std_time , tflops , tb_per_sec )
739767
740768 res = []
@@ -1011,6 +1039,20 @@ def run_cutlass():
10111039 else :
10121040 raise ValueError (f"Unknown cutlass_variant: { variant } " )
10131041
1042+ backend = "cutlass"
1043+
1044+ # Optional autotune warmup (supported for CUTLASS fused MoE)
1045+ if getattr (args , "autotune" , False ):
1046+ warmup_iters = (
1047+ args .dry_run_iters if args .dry_run_iters and args .dry_run_iters > 0 else 10
1048+ )
1049+ backend = "cutlass_autotune"
1050+ if args .verbose >= 1 :
1051+ print (f"[INFO] Autotune warmup for CUTLASS fused MoE: { warmup_iters } iters" )
1052+ with autotune (True ):
1053+ for _ in range (warmup_iters ):
1054+ run_cutlass ()
1055+
10141056 # Measure
10151057 if is_cuda_graph_compatible :
10161058 times = bench_gpu_time_with_cudagraph (
@@ -1064,7 +1106,6 @@ def run_cutlass():
10641106 active_experts = int (selected_experts .unique ().numel ()),
10651107 )
10661108
1067- backend = "cutlass"
10681109 print_perf_metrics (backend , median_time , std_time , tflops , tb_per_sec )
10691110
10701111 res = []
0 commit comments