Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Rock/Tuning/RockTuning.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ struct ParamEntry {
// A generated space of tuning parameters for one kernel, plus bookkeeping
// about how the space was produced.
struct TuningParamSet {
  // Candidate parameter configurations to iterate over during tuning
  // (SetVector: deduplicated, insertion order preserved).
  llvm::SetVector<RockTuningParamAttrInterface> tuningRange;
  // Kernel/op type the range was built for — presumably the primary op found
  // while walking the module; confirm against createTunableParamSpace.
  KernelType primaryOpType;
  // The tuning kind that was actually used (may differ from requested kind,
  // e.g. Greedy falls back to Exhaustive for non-accel).
  TuningParamSetKind effectiveKind;
};

struct TuningParamSpaceSettings {
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,7 @@ createTunableParamSpace(ModuleOp mod, TuningParamSetKind kind,
rock::TuningParamSpaceSettings &settings) {
struct TuningParamSet *newSpace;
newSpace = new TuningParamSet();
newSpace->effectiveKind = kind;

// create range and heuristic
WalkResult findPrimary =
Expand All @@ -974,6 +975,7 @@ createTunableParamSpace(ModuleOp mod, TuningParamSetKind kind,
// greedy is not implemented for non-accel
if (!archInfo.isAccel(op) && kind == TuningParamSetKind::Greedy) {
kind = TuningParamSetKind::Exhaustive;
newSpace->effectiveKind = kind;
llvm::errs() << "Greedy tuning not implemented for non-accel, using "
"Exhaustive instead\n";
}
Expand Down
62 changes: 54 additions & 8 deletions mlir/tools/rocmlir-tuning-driver/ConcurrentQueue.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- ConcurrentQueue.h - Simple MPMC queue --------------------*- C++ -*-===//
//===- ConcurrentQueue.h - Rate-adaptive MPMC queue -------------*- C++ -*-===//
//
// Part of the rocMLIR Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -11,7 +11,9 @@

#include "llvm/Support/Compiler.h"

#include <algorithm>
#include <atomic>
#include <cassert>
#include <condition_variable>
#include <mutex>
#include <queue>
Expand All @@ -21,26 +23,39 @@ namespace rocmlir::tuningdriver {
/// A mutex/condition-variable based multi-producer multi-consumer queue whose
/// effective capacity adapts to consumer demand:
///  - a pop() that finds the queue empty ("starved") grows the capacity by one
///    (up to maxCapacity) so producers can run further ahead;
///  - after fedShrinkThreshold consecutive well-fed pops the capacity is
///    halved (never below 1) to bound producer lead and memory use.
/// All capacity bookkeeping is guarded by `mtx`.
template <typename T>
class ConcurrentQueue {
public:
  /// If maxCapacity is 0, the queue is unbounded and push() never blocks.
  explicit ConcurrentQueue(size_t maxCapacity = 0) : maxCapacity(maxCapacity) {}

  /// Enqueue `item`, blocking while the bounded queue is at its current
  /// adaptive capacity. Returns false iff the queue was terminated before the
  /// item could be enqueued.
  template <typename U>
  bool push(U &&item) {
    if (LLVM_UNLIKELY(done.load(std::memory_order_relaxed)))
      return false; // Early exit if terminated

    {
      std::unique_lock<std::mutex> lock(mtx);

      if (maxCapacity > 0) {
        // Wait for room at the current adaptive capacity (or termination).
        // currentCapacity is only touched under `mtx`, so the predicate is
        // race-free.
        cvNotFull.wait(lock, [this] {
          return queue.size() < currentCapacity ||
                 done.load(std::memory_order_relaxed);
        });
      }

      // Double-check under the lock: terminate() may have run while we were
      // blocked in wait().
      if (LLVM_UNLIKELY(done.load(std::memory_order_relaxed)))
        return false;

      queue.emplace(std::forward<U>(item));
    }

    cvNotEmpty.notify_one();
    return true;
  }

  /// Dequeue into `item`, blocking until an element is available or the queue
  /// is terminated. Returns false in the terminated-and-drained case.
  bool pop(T &item) {
    std::unique_lock<std::mutex> lock(mtx);

    // Sampled *before* waiting: an empty queue here means the consumer outran
    // the producers — the signal to enlarge the producer window below.
    bool starved = queue.empty();
    cvNotEmpty.wait(lock, [this] {
      return !queue.empty() || done.load(std::memory_order_relaxed);
    });

    // NOTE(review): these two lines were collapsed in the diff view; this is
    // the standard terminated-and-drained exit implied by the surrounding
    // context — confirm against the full file.
    if (queue.empty())
      return false;

    item = std::move(queue.front());
    queue.pop();

    if (maxCapacity > 0) {
      if (starved) {
        // If the queue was empty, increase the capacity
        currentCapacity = std::min(currentCapacity + 1, maxCapacity);
        consecutiveFed = 0;
      } else {
        ++consecutiveFed;
        if (consecutiveFed >= fedShrinkThreshold) {
          // Decrease the capacity if the queue has been fed for a while
          currentCapacity =
              std::max(currentCapacity / 2, static_cast<size_t>(1));
          consecutiveFed = 0;
        }
      }
    }

    // Release before notifying so the woken producer can take the lock
    // immediately.
    lock.unlock();

    cvNotFull.notify_one();
    return true;
  }

  /// Fail all subsequent push()es and wake every blocked producer/consumer.
  /// `done` is set while holding `mtx` so a waiter cannot miss the wake-up
  /// between testing its predicate and blocking.
  void terminate() {
    {
      std::lock_guard<std::mutex> lock(mtx);
      done.store(true, std::memory_order_relaxed);
    }
    cvNotEmpty.notify_all();
    cvNotFull.notify_all();
  }

  bool isTerminated() const { return done.load(std::memory_order_relaxed); }

private:
  // Consecutive well-fed pops required before the capacity is halved.
  static constexpr size_t fedShrinkThreshold = 4;

  const size_t maxCapacity;            // 0 => unbounded
  size_t currentCapacity{maxCapacity}; // adaptive bound, guarded by mtx
  size_t consecutiveFed{0};            // guarded by mtx

  std::queue<T> queue;
  std::mutex mtx;
  std::condition_variable cvNotEmpty; // signaled by push()/terminate()
  std::condition_variable cvNotFull;  // signaled by pop()/terminate()
  std::atomic<bool> done{false};
};

Expand Down
Loading