Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 43 additions & 20 deletions water/lib/Dialect/Wave/IR/WaveInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,34 +99,57 @@ LogicalResult wave::verifyWaveIndexMappings(Operation *op) {
}
}

// For ops with the index attribute, verify that each index expression has at
// most one dimension whose step evaluates to a static value different from 1
// (with hyperparameters substituted). Structural checks stay in op verifiers.
// For ops with the index attribute, verify that (1) each index expression has
// at most one dimension whose step evaluates to a static value different from
// 1 (with hyperparameters substituted), and (2) when step or stride can be
// evaluated to a concrete value, that value is strictly positive. Be
// defensive because we may not have verified anything but the basic
// well-formedness yet, e.g., the op verifier checking for single-result
// affine expressions in mappings did not run yet.
wave::WaveHyperparameterAttr hyperparams = wave::getHyperparameters(op);
for (DictionaryAttr dictAttr : dicts) {
int nonUnitCount = 0;
for (const NamedAttribute &named : dictAttr) {
auto mapping = dyn_cast<wave::WaveIndexMappingAttr>(named.getValue());
if (!mapping || !mapping.getStep())
if (!mapping)
continue;

std::optional<SmallVector<int64_t>> stepValues =
wave::evaluateMapWithHyperparams(mapping.getStep(),
mapping.getSymbols(), hyperparams);
if (!stepValues || stepValues->size() != 1)
continue;

int64_t step = (*stepValues)[0];
if (step == 1 || step == ShapedType::kDynamic)
continue;
if (AffineMap stepMap = mapping.getStep()) {
std::optional<SmallVector<int64_t>> stepValues =
wave::evaluateMapWithHyperparams(stepMap, mapping.getSymbols(),
hyperparams);
if (stepValues && stepValues->size() == 1) {
int64_t step = (*stepValues)[0];
if (step != ShapedType::kDynamic && step <= 0) {
return op->emitOpError()
<< "step in index expression must be strictly positive, got "
<< step << " for dimension " << named.getName();
}
if (step != 1 && step != ShapedType::kDynamic && ++nonUnitCount > 1) {
InFlightDiagnostic diag =
op->emitOpError()
<< "'" << WaveDialect::kIndexWaveExprListAttrName
<< "' has more than one entry with non-unit step";
diag.attachNote()
<< "second non-unit step dimension: " << named.getName();
return failure();
}
}
}

if (++nonUnitCount > 1) {
InFlightDiagnostic diag =
op->emitOpError() << "'" << WaveDialect::kIndexWaveExprListAttrName
<< "' has more than one entry with non-unit step";
diag.attachNote() << "second non-unit step dimension: "
<< named.getName();
return failure();
if (AffineMap strideMap = mapping.getStride()) {
std::optional<SmallVector<int64_t>> strideValues =
wave::evaluateMapWithHyperparams(strideMap, mapping.getSymbols(),
hyperparams);
if (strideValues && strideValues->size() == 1) {
int64_t stride = (*strideValues)[0];
if (stride != ShapedType::kDynamic && stride <= 0) {
return op->emitOpError()
<< "stride in index expression must be strictly positive, "
"got "
<< stride << " for dimension " << named.getName();
}
}
}
}
}
Expand Down
10 changes: 3 additions & 7 deletions water/lib/Dialect/Wave/Transforms/InferTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -818,14 +818,10 @@ class ElementsPerThreadForwardAnalysis
wave::evaluateMapWithHyperparams(step, symbols, init.hyperparams);
if (!stepValues)
continue;
// TODO(#1012): turn this into an assertion when the verifier is
// implemented.

int64_t stepValue = (*stepValues)[0];
if (stepValue <= 0) {
op->emitError() << "expected positive step in index expressions "
"(missing verifier)";
return WalkResult::interrupt();
}
assert(stepValue > 0 &&
"expected positive step in index expressions");

// Elements per thread may be 1 if _all_ dimensions have a unit step,
// otherwise it should be the one non-unit step.
Expand Down
40 changes: 40 additions & 0 deletions water/test/Dialect/Wave/ops-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,46 @@ func.func @write_index_multi_step_eval(%val: !wave.tensor<[@M, @N] of f32, <regi

// -----

func.func @read_index_zero_step(%mem: !wave.tensor<[@M] of f32>) attributes {
wave.hyperparameters = #wave.hyperparameters<{M = 128}>
} {
// expected-error @below {{step in index expression must be strictly positive, got 0 for dimension "M"}}
wave.read %mem index [{M : <[] -> (0, 0, 1)>}] : (!wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32, <register>>
return
}

// -----

func.func @read_index_negative_step(%mem: !wave.tensor<[@M] of f32>) attributes {
wave.hyperparameters = #wave.hyperparameters<{M = 128}>
} {
// expected-error @below {{step in index expression must be strictly positive, got -1 for dimension "M"}}
wave.read %mem index [{M : <[] -> (0, -1, 1)>}] : (!wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32, <register>>
return
}

// -----

func.func @read_index_zero_stride(%mem: !wave.tensor<[@M] of f32>) attributes {
wave.hyperparameters = #wave.hyperparameters<{M = 128}>
} {
// expected-error @below {{stride in index expression must be strictly positive, got 0 for dimension "M"}}
wave.read %mem index [{M : <[] -> (0, 1, 0)>}] : (!wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32, <register>>
return
}

// -----

func.func @read_index_negative_stride(%mem: !wave.tensor<[@M] of f32>) attributes {
wave.hyperparameters = #wave.hyperparameters<{M = 128}>
} {
// expected-error @below {{stride in index expression must be strictly positive, got -1 for dimension "M"}}
wave.read %mem index [{M : <[] -> (0, 1, -1)>}] : (!wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32, <register>>
return
}

// -----

func.func @extract_invalid_position_rank(%src: !wave.tensor<[@M, @N] of f32>) {
// expected-error @below {{position must contain exactly one expression, but got 2}}
wave.extract %src[#wave.expr_list<[] -> (0, 1)>] : (!wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M] of f32>
Expand Down
14 changes: 0 additions & 14 deletions water/test/Dialect/Wave/propagate-elements-per-thread.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -134,20 +134,6 @@ normalform.module [#wave.normal_form<full_types>] {

// -----

// Step is zero; pass must report "expected positive step".
normalform.module [#wave.normal_form<full_types>] {
func.func @index_step_zero(%mem: !wave.tensor<[@M] of f16, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>, wave.constraints = []} {
%cst = arith.constant 0.0 : f16
// expected-error @below {{expected positive step in index expressions}}
%reg = wave.register %cst index [{M : <[] -> (<NULL>, 0, <NULL>)>}] : !wave.tensor<[@M] of f16, <register>>
wave.write %reg, %mem index [{M : <[] -> (<NULL>, 0, <NULL>)>}]
: !wave.tensor<[@M] of f16, <register>>, !wave.tensor<[@M] of f16, <global>>
return
}
}

// -----

// Index missing dimension N for result type [M, N]; pass must report missing dimensions.
normalform.module [#wave.normal_form<full_types>] {
func.func @index_missing_dimension(%mem: !wave.tensor<[@M, @N] of f16, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 64}>, wave.constraints = []} {
Expand Down