From 49017b76c5484e4c1f4b3815a423cf81ba4739b9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 12 Feb 2026 14:04:59 +0000 Subject: [PATCH 1/4] Initial plan From 14d6b1d5de6d5ddc509703cceadce60fa7658fc3 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 12 Feb 2026 14:11:06 +0000 Subject: [PATCH 2/4] Add index attribute verifier for symbolic dimensions Implement verification that ensures the index attribute provides mappings for all symbols used in operand and result types. - Add collectSymbolsFromTypes helper to extract symbols from types - Extend verifyWaveIndexMappings to check all required symbols - Add lit tests in ops-invalid.mlir for missing symbol cases Co-authored-by: ftynse <1512299+ftynse@users.noreply.github.com> --- water/lib/Dialect/Wave/IR/WaveInterfaces.cpp | 51 ++++++++++++++++++++ water/test/Dialect/Wave/ops-invalid.mlir | 27 +++++++++++ 2 files changed, 78 insertions(+) diff --git a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp index e89a06e65..137028225 100644 --- a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp +++ b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp @@ -27,6 +27,39 @@ using namespace mlir; // Index attribute verification //----------------------------------------------------------------------------- +// Helper function to collect all symbols used in the operation's operands and +// results. +static void collectSymbolsFromTypes(Operation *op, + llvm::SmallVectorImpl &symbols) { + llvm::StringSet<> seenSymbols; + + // Collect symbols from operand types. + for (Value operand : op->getOperands()) { + auto tensorType = llvm::dyn_cast(operand.getType()); + if (!tensorType || !tensorType.getFullySpecified()) + continue; + + for (wave::WaveSymbolAttr symbol : tensorType.getShape()) { + StringRef name = symbol.getName(); + if (seenSymbols.insert(name).second) + symbols.push_back(name); + } + } + + // Collect symbols from result types. + for (Value result : op->getResults()) { + auto tensorType = llvm::dyn_cast(result.getType()); + if (!tensorType || !tensorType.getFullySpecified()) + continue; + + for (wave::WaveSymbolAttr symbol : tensorType.getShape()) { + StringRef name = symbol.getName(); + if (seenSymbols.insert(name).second) + symbols.push_back(name); + } + } +} + LogicalResult wave::verifyWaveIndexMappings(Operation *op) { // The attribute is optional. Attribute attribute = @@ -46,8 +79,12 @@ LogicalResult wave::verifyWaveIndexMappings(Operation *op) { dicts.push_back(dict); } + // Collect all symbols from the index attribute. + llvm::StringSet<> indexSymbols; for (DictionaryAttr dictAttr : dicts) { for (auto named : dictAttr) { + indexSymbols.insert(named.getName().strref()); + auto val = named.getValue(); if (!isa(val)) return op->emitError("'index' attribute value for key ") @@ -84,6 +121,20 @@ LogicalResult wave::verifyWaveIndexMappings(Operation *op) { } } } + + // Collect all symbols used in operands and results. + SmallVector requiredSymbols; + collectSymbolsFromTypes(op, requiredSymbols); + + // Check that all required symbols have an entry in the index attribute. + for (StringRef symbolName : requiredSymbols) { + if (!indexSymbols.contains(symbolName)) { + return op->emitError("'index' attribute does not provide a mapping for " + "symbol '@") + << symbolName << "' used in operand or result types"; + } + } + return success(); } diff --git a/water/test/Dialect/Wave/ops-invalid.mlir b/water/test/Dialect/Wave/ops-invalid.mlir index 9d7d8f0ee..7f4b1d50c 100644 --- a/water/test/Dialect/Wave/ops-invalid.mlir +++ b/water/test/Dialect/Wave/ops-invalid.mlir @@ -902,3 +902,30 @@ func.func @permute_result_not_permutation(%arg0: !wave.tensor<[@M, @N] of f32, < wave.permute %arg0 : !wave.tensor<[@M, @N] of f32, > to !wave.tensor<[@N, @K] of f32, > return } + +// ----- + +// Test that index attribute must provide mappings for all symbols used in operands +func.func @index_attr_missing_operand_symbol(%arg0: !wave.tensor<[@M, @N] of f32, >) { + // expected-error @below {{'index' attribute does not provide a mapping for symbol '@N' used in operand or result types}} + %0 = wave.add %arg0, %arg0 index [{M : <[] -> (0, 1, 1)>}] : (!wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > + return +} + +// ----- + +// Test that index attribute must provide mappings for all symbols used in results +func.func @index_attr_missing_result_symbol(%arg0: f32) { + // expected-error @below {{'index' attribute does not provide a mapping for symbol '@B' used in operand or result types}} + %0 = wave.register %arg0 index [{A : <[] -> (0, 1, 1)>}] : !wave.tensor<[@A, @B] of f32, > + return +} + +// ----- + +// Test that index attribute must provide mappings for all symbols from multiple dimensions +func.func @index_attr_missing_multiple_symbols(%arg0: !wave.tensor<[@M, @N, @K] of f32, >) { + // expected-error @below {{'index' attribute does not provide a mapping for symbol '@N' used in operand or result types}} + %0 = wave.cast %arg0 index [{M : <[] -> (0, 1, 1)>, K : <[] -> (0, 1, 1)>}] : !wave.tensor<[@M, @N, @K] of f32, > to !wave.tensor<[@M, @N, @K] of bf16, > + return +} From 819c0a5036c4341a4d33d55f7b5bf09c3dd2c71c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 12 Feb 2026 14:12:05 +0000 Subject: [PATCH 3/4] Improve code organization in verifyWaveIndexMappings Separate symbol collection from index attribute from validation logic for better readability and separation of concerns. Co-authored-by: ftynse <1512299+ftynse@users.noreply.github.com> --- water/lib/Dialect/Wave/IR/WaveInterfaces.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp index 137028225..fa0585fb0 100644 --- a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp +++ b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp @@ -84,7 +84,12 @@ LogicalResult wave::verifyWaveIndexMappings(Operation *op) { for (DictionaryAttr dictAttr : dicts) { for (auto named : dictAttr) { indexSymbols.insert(named.getName().strref()); + } + } + // Validate the index attribute structure and iterator symbols. + for (DictionaryAttr dictAttr : dicts) { + for (auto named : dictAttr) { auto val = named.getValue(); if (!isa(val)) return op->emitError("'index' attribute value for key ") From 17a14667eb1891c03bfe4305c94edacd1ae1b7e6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 12 Feb 2026 14:13:00 +0000 Subject: [PATCH 4/4] Final implementation complete Co-authored-by: ftynse <1512299+ftynse@users.noreply.github.com> --- _codeql_detected_source_root | 1 + 1 file changed, 1 insertion(+) create mode 120000 _codeql_detected_source_root diff --git a/_codeql_detected_source_root b/_codeql_detected_source_root new file mode 120000 index 000000000..945c9b46d --- /dev/null +++ b/_codeql_detected_source_root @@ -0,0 +1 @@ +. \ No newline at end of file