diff --git a/_codeql_detected_source_root b/_codeql_detected_source_root new file mode 120000 index 000000000..945c9b46d --- /dev/null +++ b/_codeql_detected_source_root @@ -0,0 +1 @@ +. \ No newline at end of file diff --git a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp index e89a06e65..fa0585fb0 100644 --- a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp +++ b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp @@ -27,6 +27,39 @@ using namespace mlir; // Index attribute verification //----------------------------------------------------------------------------- +// Helper function to collect all symbols used in the operation's operands and +// results. +static void collectSymbolsFromTypes(Operation *op, + llvm::SmallVectorImpl &symbols) { + llvm::StringSet<> seenSymbols; + + // Collect symbols from operand types. + for (Value operand : op->getOperands()) { + auto tensorType = llvm::dyn_cast(operand.getType()); + if (!tensorType || !tensorType.getFullySpecified()) + continue; + + for (wave::WaveSymbolAttr symbol : tensorType.getShape()) { + StringRef name = symbol.getName(); + if (seenSymbols.insert(name).second) + symbols.push_back(name); + } + } + + // Collect symbols from result types. + for (Value result : op->getResults()) { + auto tensorType = llvm::dyn_cast(result.getType()); + if (!tensorType || !tensorType.getFullySpecified()) + continue; + + for (wave::WaveSymbolAttr symbol : tensorType.getShape()) { + StringRef name = symbol.getName(); + if (seenSymbols.insert(name).second) + symbols.push_back(name); + } + } +} + LogicalResult wave::verifyWaveIndexMappings(Operation *op) { // The attribute is optional. Attribute attribute = @@ -46,6 +79,15 @@ LogicalResult wave::verifyWaveIndexMappings(Operation *op) { dicts.push_back(dict); } + // Collect all symbols from the index attribute. + llvm::StringSet<> indexSymbols; + for (DictionaryAttr dictAttr : dicts) { + for (auto named : dictAttr) { + indexSymbols.insert(named.getName().strref()); + } + } + + // Validate the index attribute structure and iterator symbols. for (DictionaryAttr dictAttr : dicts) { for (auto named : dictAttr) { auto val = named.getValue(); @@ -84,6 +126,20 @@ LogicalResult wave::verifyWaveIndexMappings(Operation *op) { } } } + + // Collect all symbols used in operands and results. + SmallVector requiredSymbols; + collectSymbolsFromTypes(op, requiredSymbols); + + // Check that all required symbols have an entry in the index attribute. + for (StringRef symbolName : requiredSymbols) { + if (!indexSymbols.contains(symbolName)) { + return op->emitError("'index' attribute does not provide a mapping for " + "symbol '@") + << symbolName << "' used in operand or result types"; + } + } + return success(); } diff --git a/water/test/Dialect/Wave/ops-invalid.mlir b/water/test/Dialect/Wave/ops-invalid.mlir index 9d7d8f0ee..7f4b1d50c 100644 --- a/water/test/Dialect/Wave/ops-invalid.mlir +++ b/water/test/Dialect/Wave/ops-invalid.mlir @@ -902,3 +902,30 @@ func.func @permute_result_not_permutation(%arg0: !wave.tensor<[@M, @N] of f32, < wave.permute %arg0 : !wave.tensor<[@M, @N] of f32, > to !wave.tensor<[@N, @K] of f32, > return } + +// ----- + +// Test that index attribute must provide mappings for all symbols used in operands +func.func @index_attr_missing_operand_symbol(%arg0: !wave.tensor<[@M, @N] of f32, >) { + // expected-error @below {{'index' attribute does not provide a mapping for symbol '@N' used in operand or result types}} + %0 = wave.add %arg0, %arg0 index [{M : <[] -> (0, 1, 1)>}] : (!wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > + return +} + +// ----- + +// Test that index attribute must provide mappings for all symbols used in results +func.func @index_attr_missing_result_symbol(%arg0: f32) { + // expected-error @below {{'index' attribute does not provide a mapping for symbol '@B' used in operand or result types}} + %0 = wave.register %arg0 index [{A : <[] -> (0, 1, 1)>}] : !wave.tensor<[@A, @B] of f32, > + return +} + +// ----- + +// Test that index attribute must provide mappings for all symbols from multiple dimensions +func.func @index_attr_missing_multiple_symbols(%arg0: !wave.tensor<[@M, @N, @K] of f32, >) { + // expected-error @below {{'index' attribute does not provide a mapping for symbol '@N' used in operand or result types}} + %0 = wave.cast %arg0 index [{M : <[] -> (0, 1, 1)>, K : <[] -> (0, 1, 1)>}] : !wave.tensor<[@M, @N, @K] of f32, > to !wave.tensor<[@M, @N, @K] of bf16, > + return +}