Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions _codeql_detected_source_root
56 changes: 56 additions & 0 deletions water/lib/Dialect/Wave/IR/WaveInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,39 @@ using namespace mlir;
// Index attribute verification
//-----------------------------------------------------------------------------

// Helper function to collect all symbols used in the operation's operands and
// results.
static void collectSymbolsFromTypes(Operation *op,
llvm::SmallVectorImpl<StringRef> &symbols) {
llvm::StringSet<> seenSymbols;

// Collect symbols from operand types.
for (Value operand : op->getOperands()) {
auto tensorType = llvm::dyn_cast<wave::WaveTensorType>(operand.getType());
if (!tensorType || !tensorType.getFullySpecified())
continue;

for (wave::WaveSymbolAttr symbol : tensorType.getShape()) {
StringRef name = symbol.getName();
if (seenSymbols.insert(name).second)
symbols.push_back(name);
}
}

// Collect symbols from result types.
for (Value result : op->getResults()) {
auto tensorType = llvm::dyn_cast<wave::WaveTensorType>(result.getType());
if (!tensorType || !tensorType.getFullySpecified())
continue;

for (wave::WaveSymbolAttr symbol : tensorType.getShape()) {
StringRef name = symbol.getName();
if (seenSymbols.insert(name).second)
symbols.push_back(name);
}
}
}

LogicalResult wave::verifyWaveIndexMappings(Operation *op) {
// The attribute is optional.
Attribute attribute =
Expand All @@ -46,6 +79,15 @@ LogicalResult wave::verifyWaveIndexMappings(Operation *op) {
dicts.push_back(dict);
}

// Collect all symbols from the index attribute.
llvm::StringSet<> indexSymbols;
for (DictionaryAttr dictAttr : dicts) {
for (auto named : dictAttr) {
indexSymbols.insert(named.getName().strref());
}
}

// Validate the index attribute structure and iterator symbols.
for (DictionaryAttr dictAttr : dicts) {
for (auto named : dictAttr) {
auto val = named.getValue();
Expand Down Expand Up @@ -84,6 +126,20 @@ LogicalResult wave::verifyWaveIndexMappings(Operation *op) {
}
}
}

// Collect all symbols used in operands and results.
SmallVector<StringRef> requiredSymbols;
collectSymbolsFromTypes(op, requiredSymbols);

// Check that all required symbols have an entry in the index attribute.
for (StringRef symbolName : requiredSymbols) {
if (!indexSymbols.contains(symbolName)) {
return op->emitError("'index' attribute does not provide a mapping for "
"symbol '@")
<< symbolName << "' used in operand or result types";
}
}

return success();
}

Expand Down
27 changes: 27 additions & 0 deletions water/test/Dialect/Wave/ops-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -902,3 +902,30 @@ func.func @permute_result_not_permutation(%arg0: !wave.tensor<[@M, @N] of f32, <
wave.permute %arg0 : !wave.tensor<[@M, @N] of f32, <register>> to !wave.tensor<[@N, @K] of f32, <register>>
return
}

// -----

// Test that index attribute must provide mappings for all symbols used in operands
func.func @index_attr_missing_operand_symbol(%arg0: !wave.tensor<[@M, @N] of f32, <register>>) {
// expected-error @below {{'index' attribute does not provide a mapping for symbol '@N' used in operand or result types}}
%0 = wave.add %arg0, %arg0 index [{M : <[] -> (0, 1, 1)>}] : (!wave.tensor<[@M, @N] of f32, <register>>, !wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
return
}

// -----

// Test that index attribute must provide mappings for all symbols used in results
func.func @index_attr_missing_result_symbol(%arg0: f32) {
// expected-error @below {{'index' attribute does not provide a mapping for symbol '@B' used in operand or result types}}
%0 = wave.register %arg0 index [{A : <[] -> (0, 1, 1)>}] : !wave.tensor<[@A, @B] of f32, <register>>
return
}

// -----

// Test that index attribute must provide mappings for all symbols from multiple dimensions
func.func @index_attr_missing_multiple_symbols(%arg0: !wave.tensor<[@M, @N, @K] of f32, <register>>) {
// expected-error @below {{'index' attribute does not provide a mapping for symbol '@N' used in operand or result types}}
%0 = wave.cast %arg0 index [{M : <[] -> (0, 1, 1)>, K : <[] -> (0, 1, 1)>}] : !wave.tensor<[@M, @N, @K] of f32, <register>> to !wave.tensor<[@M, @N, @K] of bf16, <register>>
return
}
Loading