Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 76 additions & 6 deletions mlir/lib/Dialect/Rock/Transforms/RockPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Rock/IR/GetRockInfo.h"
#include "mlir/Dialect/Rock/IR/Rock.h"
#include "mlir/Dialect/Rock/Passes.h"
#include "mlir/Dialect/Rock/Transforms/RockMultibuffer.h"
Expand Down Expand Up @@ -471,6 +472,74 @@ DagType pruneGraph(const DagType &dag) {
return prunedGraph;
}

// Determine if the backward barrier can be skipped for single-wave kernels.
//
// For scheduleVersion 1 (Default), the loop structure is:
//   GlobalLoad -> DSWrite -> (fwd barrier) -> DSRead + MFMA
//
// For scheduleVersion 3 (DirectToLDSDefault), GlobalLoad writes directly to
// LDS, so the loop structure is logically:
//   GlobalLoad (to LDS) -> (fwd barrier) -> DSRead + MFMA
//
// In both cases, the forward barrier ensures LDS writes (explicit DSWrite or
// DirectToLDS GlobalLoad) complete before DSReads start. For the
// loop-carried dependency (backward barrier), we need to ensure DSReads from
// iteration i finish before LDS writes from iteration i+1.
//
// When blockSize <= waveSize (single wave), this is guaranteed because
// GPU issues instructions in order within a wave - once DSReads have been
// issued, they have read the data from the buffers, so LDS writes can proceed
// without an explicit barrier.
//
// Returns true only if every check below passes. Any failed check returns
// false, which makes the caller keep the backward barrier (the conservative,
// always-correct choice):
//   - `func` carries a block size and an `arch` or `mhal.arch` attribute;
//   - blockSize <= waveSize of that architecture;
//   - `func` contains exactly one scf.for;
//   - `forOp` contains at least one ThreadwiseGemmAccelOp, and all of them
//     report the same scheduleVersion, which must be 1 or 3.
bool canSkipBackwardBarrierForOneWave(func::FuncOp func, scf::ForOp forOp) {
  // Check if this is a single-wave kernel.
  auto maybeBlockSize = rock::getBlockSize(func);
  if (failed(maybeBlockSize))
    return false;
  const int64_t blockSize = maybeBlockSize->getInt();

  // Check if arch attribute exists before calling getArchValue which
  // triggers llvm_unreachable if arch is missing.
  if (!func->hasAttr("arch") && !func->hasAttr("mhal.arch"))
    return false;
  StringAttr arch = rock::getArchValue(func);

  const int64_t waveSize = rock::lookupArchInfo(arch).waveSize;
  const bool isOneWave = (blockSize <= waveSize);
  if (!isOneWave)
    return false;

  // For nested loops, it may require more analysis. For now, only support
  // single loop.
  int forOpCount = 0;
  func.walk([&](scf::ForOp) { ++forOpCount; });
  if (forOpCount != 1)
    return false;

  // Find the scheduleVersion from ThreadwiseGemmAccelOp within the loop.
  // The scheduleVersion is stored in the params attribute of the op.
  // Every accel GEMM in the loop must agree on the version: if they differ we
  // cannot tell which loop structure applies, so we must not let whichever op
  // happens to be visited last decide that the barrier is removable.
  std::optional<int64_t> scheduleVersion;
  bool versionsAgree = true;
  forOp.walk([&](rock::ThreadwiseGemmAccelOp gemmOp) {
    rock::RockAccelTuningParamAttrInterface params = gemmOp.getParams();
    const int64_t version = params.getScheduleVersion();
    if (scheduleVersion.has_value() && *scheduleVersion != version)
      versionsAgree = false;
    scheduleVersion = version;
  });

  // No accel GEMM in the loop, or conflicting schedules: keep the barrier.
  if (!scheduleVersion.has_value() || !versionsAgree)
    return false;

  // Check if the schedule version supports skipping the backward barrier.
  // Only scheduleVersion 1 (Default) and 3 (DirectToLDSDefault)
  // have the loop structure that allows skipping the backward barrier.
  const bool canSkip = (*scheduleVersion == 1 || *scheduleVersion == 3);

  LLVM_DEBUG(DBGS() << "canSkipBackwardBarrierForOneWave: isOneWave="
                    << isOneWave << ", scheduleVersion=" << *scheduleVersion
                    << ", canSkip=" << canSkip << "\n");
  return canSkip;
}

// Utility function to place an empty stage before or after another `stage`. The
// empty stage will contain an `lds_barrier` if `isBarrier` is set to true
rock::StageOp placeEmptyStage(IRRewriter &rewriter, Location loc,
Expand All @@ -493,8 +562,8 @@ rock::StageOp placeEmptyStage(IRRewriter &rewriter, Location loc,
// initiation interval twice as big and pipeline as usual. This function
// also takes care of updating the initiation interval, so that the caller
// does not have to know how `placeBarrier` internally works.
void placeBarriers(IRRewriter &rewriter, Location loc, scf::ForOp forOp,
ArrayRef<rock::StageOp> stages,
void placeBarriers(IRRewriter &rewriter, Location loc, func::FuncOp func,
scf::ForOp forOp, ArrayRef<rock::StageOp> stages,
SetVector<rock::GpuAllocOp> &allocs,
SmallVector<rock::StageOp> &extendedStages,
int64_t &initiationInterval, int64_t numIterations) {
Expand All @@ -503,8 +572,9 @@ void placeBarriers(IRRewriter &rewriter, Location loc, scf::ForOp forOp,
dag = pruneGraph(dag);

// If there is a loop, we probably need a backward barrier, i.e.,
// an LDS barrier that takes the loop dependency into account
const bool addBackwardBarrier = numIterations > 1;
// an LDS barrier that takes the loop dependency into account.
bool canSkipBackwardBarrier = canSkipBackwardBarrierForOneWave(func, forOp);
const bool addBackwardBarrier = numIterations > 1 && !canSkipBackwardBarrier;

DenseMap<rock::StageOp, int> timeSlotMap;
int timeSlot = 0;
Expand Down Expand Up @@ -768,8 +838,8 @@ void RockPipeline::runOnOperation() {
SmallVector<rock::StageOp> extendedStages;
// use "multiAllocs" to place LDS barriers, no need to explicitly place
// barriers for registers or globals
placeBarriers(rewriter, loc, forOp, stages, multiAllocs, extendedStages,
ii, numIterations);
placeBarriers(rewriter, loc, func, forOp, stages, multiAllocs,
extendedStages, ii, numIterations);
ScheduleType schedule;
// use all "resources" to generate dependency graph and generate schedule
createSchedule(extendedStages, resources, ii, schedule,
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Rock/rock-pipeline-early-exit.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// COUNT-COUNT-1: rock.lds_barrier

module {
func.func @pipeline_loop_in_scf_if(%arg0: memref<128xf16>, %arg1: memref<128xf16>, %arg2: memref<128xf16>, %arg3: i32) attributes {block_size = 64 : i32, grid_size = 1 : i32, kernel} {
func.func @pipeline_loop_in_scf_if(%arg0: memref<128xf16>, %arg1: memref<128xf16>, %arg2: memref<128xf16>, %arg3: i32) attributes {arch = "amdgcn-amd-amdhsa:gfx90a", block_size = 64 : i32, grid_size = 1 : i32, kernel} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
Expand Down
Loading