diff --git a/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc b/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc index 50df2c06671fb..72b31206a1ad8 100644 --- a/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc +++ b/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc @@ -150,16 +150,19 @@ static void MakeTTGIR(mlir::OpPassManager* pm, static void MakeLLIR(mlir::OpPassManager* pm, const stream_executor::RocmComputeCapability& rocm_cc, int num_stages) { - pm->addPass(mlir::createTritonAMDGPUUpdateAsyncWaitCount()); + pm->addPass(mlir::createTritonAMDGPUUpdateAsyncWaitCount({rocm_cc.gfx_version()})); pm->addPass( mlir::triton::AMD::createConvertWarpPipelinePass(rocm_cc.gfx_version())); pm->addPass(mlir::createSCFToControlFlowPass()); + pm->addPass(mlir::createInlinerPass()); pm->addPass(mlir::createConvertIndexToLLVMPass()); pm->addPass(mt::gpu::createAllocateSharedMemory()); pm->addPass(mt::gpu::createTritonGPUGlobalScratchAllocationPass()); pm->addPass(mt::gpu::createTritonGPUGlobalScratchAllocationPass()); pm->addPass( mt::createConvertTritonAMDGPUToLLVMPass(rocm_cc.gfx_version(), true)); + pm->addPass( + mlir::triton::AMD::createTritonAMDGPUConvertWarpSpecializeToLLVMPass(rocm_cc.gfx_version())); pm->addPass(mlir::createCanonicalizerPass()); pm->addPass(mlir::createCSEPass()); // Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass.