Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions water/include/water/Dialect/Wave/IR/WaveInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ using EmitErrorFn = llvm::function_ref<mlir::InFlightDiagnostic()>;

class WaveTensorType;

/// Get the hyperparameters from an ancestor operation.
/// Returns nullptr if no hyperparameters are found.
WaveHyperparameterAttr getHyperparameters(mlir::Operation *op);

//-----------------------------------------------------------------------------
// HasWaveIndexMapping trait
//-----------------------------------------------------------------------------
Expand Down
4 changes: 0 additions & 4 deletions water/include/water/Dialect/Wave/Transforms/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@

namespace wave {

/// Get the hyperparameters from an ancestor operation.
/// Returns nullptr if no hyperparameters are found.
WaveHyperparameterAttr getHyperparameters(mlir::Operation *op);

// Populates `constraints` with a mapping from an operation with a Wave
// constraints attribute attached to that attribute.
llvm::LogicalResult collectWaveConstraints(
Expand Down
3 changes: 1 addition & 2 deletions water/lib/Dialect/Wave/IR/WaveDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,13 @@
#include "mlir/IR/Dialect.h"

#include "water/Dialect/Wave/IR/WaveDialect.cpp.inc"
#include "water/Dialect/Wave/IR/WaveInterfaces.h"
#include "water/Dialect/Wave/IR/WaveUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/LogicalResult.h"
#include <algorithm>
#include <optional>

using namespace mlir;
Expand Down
45 changes: 45 additions & 0 deletions water/lib/Dialect/Wave/IR/WaveInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@

using namespace mlir;

//-----------------------------------------------------------------------------
// getHyperparameters
//-----------------------------------------------------------------------------

wave::WaveHyperparameterAttr wave::getHyperparameters(Operation *op) {
for (Operation *current = op; current; current = current->getParentOp()) {
if (auto hyperparams = current->getAttrOfType<WaveHyperparameterAttr>(
WaveDialect::kHyperparameterAttrName))
return hyperparams;
}
return nullptr;
}

//-----------------------------------------------------------------------------
// Index attribute verification
//-----------------------------------------------------------------------------
Expand Down Expand Up @@ -86,6 +99,38 @@ LogicalResult wave::verifyWaveIndexMappings(Operation *op) {
}
}

// For ops with the index attribute, verify that each index expression has at
// most one dimension whose step evaluates to a static value different from 1
// (with hyperparameters substituted). Structural checks stay in op verifiers.
wave::WaveHyperparameterAttr hyperparams = wave::getHyperparameters(op);
for (DictionaryAttr dictAttr : dicts) {
int nonUnitCount = 0;
for (const NamedAttribute &named : dictAttr) {
auto mapping = dyn_cast<wave::WaveIndexMappingAttr>(named.getValue());
if (!mapping || !mapping.getStep())
continue;

std::optional<SmallVector<int64_t>> stepValues =
wave::evaluateMapWithHyperparams(mapping.getStep(),
mapping.getSymbols(), hyperparams);
if (!stepValues || stepValues->size() != 1)
continue;

int64_t step = (*stepValues)[0];
if (step == 1 || step == ShapedType::kDynamic)
continue;

if (++nonUnitCount > 1) {
InFlightDiagnostic diag =
op->emitOpError() << "'" << WaveDialect::kIndexWaveExprListAttrName
<< "' has more than one entry with non-unit step";
diag.attachNote() << "second non-unit step dimension: "
<< named.getName();
return failure();
}
}
}

// When the operation implements WaveInferIndexExprsOpInterface, the index
// attribute length must match the number of values from
// getIndexExprValuesAndDescriptions. Otherwise, default to the number of op
Expand Down
15 changes: 4 additions & 11 deletions water/lib/Dialect/Wave/IR/WaveOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1572,9 +1572,8 @@ ReadOp::propagateBackward(MutableArrayRef<wave::WaveTensorType> operandTypes,

LogicalResult ReadOp::finalizeTypeInference() { return success(); }

// Check the well-formedness of the index attribute (must have at most one
// non-unit dimension) and its correspondence with the explicit elements per
// thread, if provided, and with the number of elements in the vector type.
// Check the correspondence of the index attribute with the explicit elements
// per thread, if provided, and with the number of elements in the vector type.
static LogicalResult
verifyIndexElementsPerThread(Operation *op, ArrayAttr indexAttr,
std::optional<int64_t> elementsPerThread,
Expand Down Expand Up @@ -1605,7 +1604,7 @@ verifyIndexElementsPerThread(Operation *op, ArrayAttr indexAttr,
if (!indexDict)
return success();

wave::WaveHyperparameterAttr hyper = wave::WaveHyperparameterAttr();
wave::WaveHyperparameterAttr hyper = nullptr;
for (Operation *cur = op; cur != nullptr && !hyper;
cur = cur->getParentOp()) {
hyper = cur->getAttrOfType<wave::WaveHyperparameterAttr>(
Expand All @@ -1621,7 +1620,7 @@ verifyIndexElementsPerThread(Operation *op, ArrayAttr indexAttr,
getUncollapsedVectorShape(tensorType.getShape(), indexDict, hyper);
int64_t nonUnit = 1;
bool hadDynamic = false;
for (auto [i, size] : llvm::enumerate(shape)) {
for (int64_t size : shape) {
if (ShapedType::isDynamic(size)) {
hadDynamic = true;
continue;
Expand All @@ -1632,13 +1631,7 @@ verifyIndexElementsPerThread(Operation *op, ArrayAttr indexAttr,
}
if (nonUnit == 1) {
nonUnit = size;
continue;
}

InFlightDiagnostic diag =
op->emitError() << "'index' has more than one entry with non-unit step";
diag.attachNote() << "second non-unit step dimension: " << i;
return diag;
}

// If there were unevaluated steps, they may end up matching later on.
Expand Down
17 changes: 5 additions & 12 deletions water/lib/Dialect/Wave/Transforms/InferTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -829,18 +829,11 @@ class ElementsPerThreadForwardAnalysis

// Elements per thread may be 1 if _all_ dimensions have a unit step,
// otherwise it should be the one non-unit step.
// TODO(#1013): this logic can be reused in the verifier.
if (!elementsPerThread.has_value()) {
elementsPerThread = (*stepValues)[0];
} else if (*elementsPerThread == 1) {
elementsPerThread = (*stepValues)[0];
} else if (stepValue != 1) {
// TODO(#1013): turn this into an assertion when the verifier is
// implemented.
op->emitError() << "expected only one non-unit index step, found "
<< (*stepValues)[0] << " and " << *elementsPerThread
<< " (missing verifier)";
return WalkResult::interrupt();
assert((!elementsPerThread.has_value() || *elementsPerThread == 1 ||
stepValue == 1) &&
"expected only one non-unit index step");
if (!elementsPerThread.has_value() || *elementsPerThread == 1) {
elementsPerThread = stepValue;
}
}

Expand Down
9 changes: 0 additions & 9 deletions water/lib/Dialect/Wave/Transforms/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,6 @@

using namespace mlir;

wave::WaveHyperparameterAttr wave::getHyperparameters(Operation *op) {
for (Operation *current = op; current; current = current->getParentOp()) {
if (auto hyperparams = current->getAttrOfType<WaveHyperparameterAttr>(
WaveDialect::kHyperparameterAttrName))
return hyperparams;
}
return nullptr;
}

llvm::LogicalResult wave::collectWaveConstraints(
Operation *top, llvm::DenseMap<Operation *, Attribute> &constraints) {
auto *waveDialect = top->getContext()->getLoadedDialect<wave::WaveDialect>();
Expand Down
22 changes: 19 additions & 3 deletions water/test/Dialect/Wave/ops-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -602,9 +602,11 @@ func.func @bounds_wrong_rank(%mem: !wave.tensor<[@N] of f32>) {

// -----

func.func @read_index_multi_step(%mem: !wave.tensor<[@M, @N] of f32>) {
func.func @read_index_multi_step(%mem: !wave.tensor<[@M, @N] of f32>) attributes {
wave.hyperparameters = #wave.hyperparameters<{M = 1, N = 1}>
} {
// expected-error @below {{'index' has more than one entry with non-unit step}}
// expected-note @below {{second non-unit step dimension: 1}}
// expected-note @below {{second non-unit step dimension: "N"}}
wave.read %mem index [{
M : <[#wave.index_symbol<T0>] -> (T0, 2, 1)>,
N : <[#wave.index_symbol<T1>] -> (T1, 2, 1)>
Expand Down Expand Up @@ -643,7 +645,7 @@ func.func @read_index_multi_step_eval(%mem: !wave.tensor<[@M, @N] of f32>) attri
wave.hyperparameters = #wave.hyperparameters<{X = 1, Y = 1, M = 100, N = 200}>
} {
// expected-error @below {{'index' has more than one entry with non-unit step}}
// expected-note @below {{second non-unit step dimension: 1}}
// expected-note @below {{second non-unit step dimension: "N"}}
wave.read %mem index [{
M : <[#wave.index_symbol<T0>, #wave.symbol<"X">] -> (T0, 2 * X, 1)>,
N : <[#wave.index_symbol<T1>, #wave.symbol<"X">, #wave.symbol<"Y">] -> (T1, X + Y, 1)>
Expand All @@ -653,6 +655,20 @@ func.func @read_index_multi_step_eval(%mem: !wave.tensor<[@M, @N] of f32>) attri

// -----

func.func @write_index_multi_step_eval(%val: !wave.tensor<[@M, @N] of f32, <register>>, %mem: !wave.tensor<[@M, @N] of f32, <global>>) attributes {
wave.hyperparameters = #wave.hyperparameters<{X = 1, Y = 1, M = 100, N = 200}>
} {
// expected-error @below {{'index' has more than one entry with non-unit step}}
// expected-note @below {{second non-unit step dimension: "N"}}
wave.write %val, %mem index [{
M : <[#wave.index_symbol<T0>, #wave.symbol<"X">] -> (T0, 2 * X, 1)>,
N : <[#wave.index_symbol<T1>, #wave.symbol<"X">, #wave.symbol<"Y">] -> (T1, X + Y, 1)>
}] : !wave.tensor<[@M, @N] of f32, <register>>, !wave.tensor<[@M, @N] of f32, <global>>
return
}

// -----

func.func @extract_invalid_position_rank(%src: !wave.tensor<[@M, @N] of f32>) {
// expected-error @below {{position must contain exactly one expression, but got 2}}
wave.extract %src[#wave.expr_list<[] -> (0, 1)>] : (!wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M] of f32>
Expand Down
15 changes: 0 additions & 15 deletions water/test/Dialect/Wave/propagate-elements-per-thread.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -148,21 +148,6 @@ normalform.module [#wave.normal_form<full_types>] {

// -----

// Two dimensions with non-unit steps; pass must report "expected only one non-unit".
// Use only register (write has its own verifier for multi-step index).
normalform.module [#wave.normal_form<full_types>] {
func.func @index_multi_non_unit_step(%mem: !wave.tensor<[@M, @N] of f16, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 64}>, wave.constraints = []} {
%cst = arith.constant 0.0 : f16
// expected-error @below {{expected only one non-unit index step}}
%reg = wave.register %cst index [{M : <[] -> (<NULL>, 4, <NULL>)>, N : <[] -> (<NULL>, 8, <NULL>)>}] : !wave.tensor<[@M, @N] of f16, <register>>
wave.write %reg, %mem index [{M : <[] -> (<NULL>, 4, <NULL>)>, N : <[] -> (<NULL>, 1, <NULL>)>}]
: !wave.tensor<[@M, @N] of f16, <register>>, !wave.tensor<[@M, @N] of f16, <global>>
return
}
}

// -----

// Index missing dimension N for result type [M, N]; pass must report missing dimensions.
normalform.module [#wave.normal_form<full_types>] {
func.func @index_missing_dimension(%mem: !wave.tensor<[@M, @N] of f16, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 64}>, wave.constraints = []} {
Expand Down
Loading