From 50f5febb548da09e08bfb26c93d8742246a64aeb Mon Sep 17 00:00:00 2001 From: Sergei Zharinov Date: Tue, 16 Dec 2025 12:48:53 -0300 Subject: [PATCH 1/5] refactor: Use Visitor pattern for compiler passes --- crates/plotnik-lib/src/query/alt_kinds.rs | 48 +++-- crates/plotnik-lib/src/query/expr_arity.rs | 179 ++++++++++++++++++ .../{shapes_tests.rs => expr_arity_tests.rs} | 8 +- crates/plotnik-lib/src/query/mod.rs | 60 ++---- crates/plotnik-lib/src/query/printer.rs | 10 +- crates/plotnik-lib/src/query/shapes.rs | 131 ------------- crates/plotnik-lib/src/query/visitor.rs | 135 +++++++++++++ 7 files changed, 362 insertions(+), 209 deletions(-) create mode 100644 crates/plotnik-lib/src/query/expr_arity.rs rename crates/plotnik-lib/src/query/{shapes_tests.rs => expr_arity_tests.rs} (98%) delete mode 100644 crates/plotnik-lib/src/query/shapes.rs create mode 100644 crates/plotnik-lib/src/query/visitor.rs diff --git a/crates/plotnik-lib/src/query/alt_kinds.rs b/crates/plotnik-lib/src/query/alt_kinds.rs index d305954e..26968472 100644 --- a/crates/plotnik-lib/src/query/alt_kinds.rs +++ b/crates/plotnik-lib/src/query/alt_kinds.rs @@ -7,37 +7,43 @@ use rowan::TextRange; use super::Query; use super::invariants::ensure_both_branch_kinds; -use crate::diagnostics::DiagnosticKind; -use crate::parser::{AltExpr, AltKind, Branch, Expr}; +use super::visitor::{Visitor, walk_alt_expr, walk_root}; +use crate::diagnostics::{DiagnosticKind, Diagnostics}; +use crate::parser::{AltExpr, AltKind, Branch, Root}; impl Query<'_> { pub(super) fn validate_alt_kinds(&mut self) { - let defs: Vec<_> = self.ast.defs().collect(); - for def in defs { - let Some(body) = def.body() else { continue }; - self.validate_alt_expr(&body); - } + let mut visitor = AltKindsValidator { + diagnostics: &mut self.alt_kind_diagnostics, + }; + visitor.visit_root(&self.ast); + } +} +struct AltKindsValidator<'a> { + diagnostics: &'a mut Diagnostics, +} + +impl Visitor for AltKindsValidator<'_> { + fn visit_root(&mut self, 
root: &Root) { assert!( - self.ast.exprs().next().is_none(), + root.exprs().next().is_none(), "alt_kind: unexpected bare Expr in Root (parser should wrap in Def)" ); + walk_root(self, root); } - fn validate_alt_expr(&mut self, expr: &Expr) { - if let Expr::AltExpr(alt) = expr { - self.check_mixed_alternation(alt); - assert!( - alt.exprs().next().is_none(), - "alt_kind: unexpected bare Expr in Alt (parser should wrap in Branch)" - ); - } - - for child in expr.children() { - self.validate_alt_expr(&child); - } + fn visit_alt_expr(&mut self, alt: &AltExpr) { + self.check_mixed_alternation(alt); + assert!( + alt.exprs().next().is_none(), + "alt_kind: unexpected bare Expr in Alt (parser should wrap in Branch)" + ); + walk_alt_expr(self, alt); } +} +impl AltKindsValidator<'_> { fn check_mixed_alternation(&mut self, alt: &AltExpr) { if alt.kind() != AltKind::Mixed { return; @@ -57,7 +63,7 @@ impl Query<'_> { let untagged_range = branch_range(untagged_branch); - self.alt_kind_diagnostics + self.diagnostics .report(DiagnosticKind::MixedAltBranches, untagged_range) .related_to("tagged branch here", tagged_range) .emit(); diff --git a/crates/plotnik-lib/src/query/expr_arity.rs b/crates/plotnik-lib/src/query/expr_arity.rs new file mode 100644 index 00000000..271fd597 --- /dev/null +++ b/crates/plotnik-lib/src/query/expr_arity.rs @@ -0,0 +1,179 @@ +//! Expression arity analysis for query expressions. +//! +//! Determines whether an expression matches a single node position (`One`) +//! or multiple sequential positions (`Many`). Used to validate field constraints: +//! `field: expr` requires `expr` to have `ExprArity::One`. +//! +//! `Invalid` marks nodes where cardinality cannot be determined (error nodes, +//! undefined refs, etc.). 
+ +use super::Query; +use super::visitor::{Visitor, walk_expr, walk_field_expr}; +use crate::diagnostics::DiagnosticKind; +use crate::parser::{Expr, FieldExpr, Ref, SeqExpr, SyntaxKind, SyntaxNode, ast}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ExprArity { + One, + Many, + Invalid, +} + +impl Query<'_> { + pub(super) fn infer_arities(&mut self) { + let root = self.ast.clone(); + + let mut computer = ArityComputer { query: self }; + computer.visit_root(&root); + + let mut validator = ArityValidator { query: self }; + validator.visit_root(&root); + } + + pub(super) fn shape_arity(&self, node: &SyntaxNode) -> ExprArity { + // Error nodes are invalid + if node.kind() == SyntaxKind::Error { + return ExprArity::Invalid; + } + + // Root: cardinality based on definition count + if let Some(root) = ast::Root::cast(node.clone()) { + return if root.defs().count() > 1 { + ExprArity::Many + } else { + ExprArity::One + }; + } + + // Def: delegate to body's cardinality + if let Some(def) = ast::Def::cast(node.clone()) { + return def + .body() + .and_then(|b| self.expr_arity_table.get(&b).copied()) + .unwrap_or(ExprArity::Invalid); + } + + // Branch: delegate to body's cardinality + if let Some(branch) = ast::Branch::cast(node.clone()) { + return branch + .body() + .and_then(|b| self.expr_arity_table.get(&b).copied()) + .unwrap_or(ExprArity::Invalid); + } + + // Expr: direct lookup + ast::Expr::cast(node.clone()) + .and_then(|e| self.expr_arity_table.get(&e).copied()) + .unwrap_or(ExprArity::One) + } +} + +struct ArityComputer<'a, 'q> { + query: &'a mut Query<'q>, +} + +impl Visitor for ArityComputer<'_, '_> { + fn visit_expr(&mut self, expr: &Expr) { + self.query.compute_cardinality(expr); + walk_expr(self, expr); + } +} + +struct ArityValidator<'a, 'q> { + query: &'a mut Query<'q>, +} + +impl Visitor for ArityValidator<'_, '_> { + fn visit_field_expr(&mut self, field: &FieldExpr) { + self.query.validate_field(field); + walk_field_expr(self, field); + } +} 
+ +impl Query<'_> { + fn compute_cardinality(&mut self, expr: &Expr) -> ExprArity { + if let Some(&c) = self.expr_arity_table.get(expr) { + return c; + } + // Insert sentinel to break cycles (e.g., `Foo = (Foo)`) + self.expr_arity_table + .insert(expr.clone(), ExprArity::Invalid); + let c = self.compute_single_cardinality(expr); + self.expr_arity_table.insert(expr.clone(), c); + c + } + + fn compute_single_cardinality(&mut self, expr: &Expr) -> ExprArity { + match expr { + Expr::NamedNode(_) | Expr::AnonymousNode(_) | Expr::FieldExpr(_) | Expr::AltExpr(_) => { + ExprArity::One + } + + Expr::SeqExpr(seq) => self.seq_cardinality(seq), + + Expr::CapturedExpr(cap) => { + let Some(inner) = cap.inner() else { + return ExprArity::Invalid; + }; + self.compute_cardinality(&inner) + } + + Expr::QuantifiedExpr(q) => { + let Some(inner) = q.inner() else { + return ExprArity::Invalid; + }; + self.compute_cardinality(&inner) + } + + Expr::Ref(r) => self.ref_cardinality(r), + } + } + + fn seq_cardinality(&mut self, seq: &SeqExpr) -> ExprArity { + let children: Vec<_> = seq.children().collect(); + + match children.len() { + 0 => ExprArity::One, + 1 => self.compute_cardinality(&children[0]), + _ => ExprArity::Many, + } + } + + fn ref_cardinality(&mut self, r: &Ref) -> ExprArity { + let name_tok = r.name().expect( + "expr_arities: Ref without name token \ + (parser only creates Ref for PascalCase Id)", + ); + let name = name_tok.text(); + + let Some(body) = self.symbol_table.get(name).cloned() else { + return ExprArity::Invalid; + }; + + self.compute_cardinality(&body) + } + + fn validate_field(&mut self, field: &FieldExpr) { + let Some(value) = field.value() else { + return; + }; + + let card = self + .expr_arity_table + .get(&value) + .copied() + .unwrap_or(ExprArity::One); + + if card == ExprArity::Many { + let field_name = field + .name() + .map(|t| t.text().to_string()) + .unwrap_or_else(|| "field".to_string()); + + self.expr_arity_diagnostics + 
.report(DiagnosticKind::FieldSequenceValue, value.text_range()) + .message(field_name) + .emit(); + } + } +} diff --git a/crates/plotnik-lib/src/query/shapes_tests.rs b/crates/plotnik-lib/src/query/expr_arity_tests.rs similarity index 98% rename from crates/plotnik-lib/src/query/shapes_tests.rs rename to crates/plotnik-lib/src/query/expr_arity_tests.rs index 392c99ae..d55cd524 100644 --- a/crates/plotnik-lib/src/query/shapes_tests.rs +++ b/crates/plotnik-lib/src/query/expr_arity_tests.rs @@ -203,7 +203,7 @@ fn field_with_ref_to_seq_error() { } #[test] -fn quantifier_preserves_inner_shape() { +fn quantifier_preserves_inner_arity() { let query = Query::try_from("(identifier)*").unwrap(); assert!(query.is_valid()); insta::assert_snapshot!(query.dump_with_cardinalities(), @r" @@ -215,7 +215,7 @@ fn quantifier_preserves_inner_shape() { } #[test] -fn capture_preserves_inner_shape() { +fn capture_preserves_inner_arity() { let query = Query::try_from("(identifier) @name").unwrap(); assert!(query.is_valid()); insta::assert_snapshot!(query.dump_with_cardinalities(), @r" @@ -241,7 +241,7 @@ fn capture_on_seq() { } #[test] -fn complex_nested_shapes() { +fn complex_nested_arities() { let input = indoc! {r#" Stmt = [(expr_stmt) (return_stmt)] (function_definition @@ -272,7 +272,7 @@ fn complex_nested_shapes() { } #[test] -fn tagged_alt_shapes() { +fn tagged_alt_arities() { let input = indoc! 
{r#" [Ident: (identifier) Num: (number)] "#}; diff --git a/crates/plotnik-lib/src/query/mod.rs b/crates/plotnik-lib/src/query/mod.rs index 9f11e252..d1114a46 100644 --- a/crates/plotnik-lib/src/query/mod.rs +++ b/crates/plotnik-lib/src/query/mod.rs @@ -14,6 +14,7 @@ mod printer; pub use printer::QueryPrinter; pub mod alt_kinds; +pub mod expr_arity; pub mod graph; mod graph_build; mod graph_dump; @@ -23,8 +24,8 @@ mod infer_dump; #[cfg(feature = "plotnik-langs")] pub mod link; pub mod recursion; -pub mod shapes; pub mod symbol_table; +pub mod visitor; pub use graph::{BuildEffect, BuildGraph, BuildMatcher, BuildNode, Fragment, NodeId, RefMarker}; pub use graph_optimize::OptimizeStats; @@ -36,6 +37,8 @@ pub use symbol_table::UNNAMED_DEF; #[cfg(test)] mod alt_kinds_tests; #[cfg(test)] +mod expr_arity_tests; +#[cfg(test)] mod graph_build_tests; #[cfg(test)] mod graph_master_test; @@ -52,8 +55,6 @@ mod printer_tests; #[cfg(test)] mod recursion_tests; #[cfg(test)] -mod shapes_tests; -#[cfg(test)] mod symbol_table_tests; use std::collections::{HashMap, HashSet}; @@ -72,7 +73,7 @@ use crate::parser::{ParseResult, Parser, Root, SyntaxNode, ast}; const DEFAULT_EXEC_FUEL: u32 = 1_000_000; const DEFAULT_RECURSION_FUEL: u32 = 4096; -use shapes::ShapeCardinality; +use expr_arity::ExprArity; use symbol_table::SymbolTable; /// A parsed and analyzed query. 
@@ -99,7 +100,7 @@ pub struct Query<'a> { source: &'a str, ast: Root, symbol_table: SymbolTable<'a>, - shape_cardinality_table: HashMap, + expr_arity_table: HashMap, #[cfg(feature = "plotnik-langs")] node_type_ids: HashMap<&'a str, Option>, #[cfg(feature = "plotnik-langs")] @@ -111,7 +112,7 @@ pub struct Query<'a> { alt_kind_diagnostics: Diagnostics, resolve_diagnostics: Diagnostics, recursion_diagnostics: Diagnostics, - shapes_diagnostics: Diagnostics, + expr_arity_diagnostics: Diagnostics, #[cfg(feature = "plotnik-langs")] link_diagnostics: Diagnostics, // Graph compilation fields @@ -147,7 +148,7 @@ impl<'a> Query<'a> { source, ast: empty_root(), symbol_table: SymbolTable::default(), - shape_cardinality_table: HashMap::new(), + expr_arity_table: HashMap::new(), #[cfg(feature = "plotnik-langs")] node_type_ids: HashMap::new(), #[cfg(feature = "plotnik-langs")] @@ -159,7 +160,7 @@ impl<'a> Query<'a> { alt_kind_diagnostics: Diagnostics::new(), resolve_diagnostics: Diagnostics::new(), recursion_diagnostics: Diagnostics::new(), - shapes_diagnostics: Diagnostics::new(), + expr_arity_diagnostics: Diagnostics::new(), #[cfg(feature = "plotnik-langs")] link_diagnostics: Diagnostics::new(), graph: BuildGraph::default(), @@ -200,7 +201,7 @@ impl<'a> Query<'a> { self.validate_alt_kinds(); self.resolve_names(); self.validate_recursion(); - self.infer_shapes(); + self.infer_arities(); Ok(self) } @@ -293,43 +294,6 @@ impl<'a> Query<'a> { &self.type_info } - pub(crate) fn shape_cardinality(&self, node: &SyntaxNode) -> ShapeCardinality { - // Error nodes are invalid - if node.kind() == SyntaxKind::Error { - return ShapeCardinality::Invalid; - } - - // Root: cardinality based on definition count - if let Some(root) = Root::cast(node.clone()) { - return if root.defs().count() > 1 { - ShapeCardinality::Many - } else { - ShapeCardinality::One - }; - } - - // Def: delegate to body's cardinality - if let Some(def) = ast::Def::cast(node.clone()) { - return def - .body() - .and_then(|b| 
self.shape_cardinality_table.get(&b).copied()) - .unwrap_or(ShapeCardinality::Invalid); - } - - // Branch: delegate to body's cardinality - if let Some(branch) = ast::Branch::cast(node.clone()) { - return branch - .body() - .and_then(|b| self.shape_cardinality_table.get(&b).copied()) - .unwrap_or(ShapeCardinality::Invalid); - } - - // Expr: direct lookup - ast::Expr::cast(node.clone()) - .and_then(|e| self.shape_cardinality_table.get(&e).copied()) - .unwrap_or(ShapeCardinality::One) - } - /// All diagnostics combined from all passes (unfiltered). /// /// Use this for debugging or when you need to see all diagnostics @@ -340,7 +304,7 @@ impl<'a> Query<'a> { all.extend(self.alt_kind_diagnostics.clone()); all.extend(self.resolve_diagnostics.clone()); all.extend(self.recursion_diagnostics.clone()); - all.extend(self.shapes_diagnostics.clone()); + all.extend(self.expr_arity_diagnostics.clone()); #[cfg(feature = "plotnik-langs")] all.extend(self.link_diagnostics.clone()); all.extend(self.type_info.diagnostics.clone()); @@ -362,7 +326,7 @@ impl<'a> Query<'a> { && !self.alt_kind_diagnostics.has_errors() && !self.resolve_diagnostics.has_errors() && !self.recursion_diagnostics.has_errors() - && !self.shapes_diagnostics.has_errors() + && !self.expr_arity_diagnostics.has_errors() && !self.link_diagnostics.has_errors() } diff --git a/crates/plotnik-lib/src/query/printer.rs b/crates/plotnik-lib/src/query/printer.rs index 2f02712e..b829bbd7 100644 --- a/crates/plotnik-lib/src/query/printer.rs +++ b/crates/plotnik-lib/src/query/printer.rs @@ -8,7 +8,7 @@ use rowan::NodeOrToken; use crate::parser::{self as ast, Expr, SyntaxNode}; use super::Query; -use super::shapes::ShapeCardinality; +use super::expr_arity::ExprArity; pub struct QueryPrinter<'q, 'src> { query: &'q Query<'src>, @@ -348,10 +348,10 @@ impl<'q, 'src> QueryPrinter<'q, 'src> { if !self.cardinalities { return ""; } - match self.query.shape_cardinality(node) { - ShapeCardinality::One => "¹", - ShapeCardinality::Many => 
"⁺", - ShapeCardinality::Invalid => "⁻", + match self.query.shape_arity(node) { + ExprArity::One => "¹", + ExprArity::Many => "⁺", + ExprArity::Invalid => "⁻", } } diff --git a/crates/plotnik-lib/src/query/shapes.rs b/crates/plotnik-lib/src/query/shapes.rs deleted file mode 100644 index 31ec8c28..00000000 --- a/crates/plotnik-lib/src/query/shapes.rs +++ /dev/null @@ -1,131 +0,0 @@ -//! Shape cardinality analysis for query expressions. -//! -//! Determines whether an expression matches a single node position (`One`) -//! or multiple sequential positions (`Many`). Used to validate field constraints: -//! `field: expr` requires `expr` to have `ShapeCardinality::One`. -//! -//! `Invalid` marks nodes where cardinality cannot be determined (error nodes, -//! undefined refs, etc.). - -use super::Query; -use crate::diagnostics::DiagnosticKind; -use crate::parser::{Expr, Ref, SeqExpr}; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum ShapeCardinality { - One, - Many, - Invalid, -} - -impl Query<'_> { - pub(super) fn infer_shapes(&mut self) { - let bodies: Vec<_> = self.ast.defs().filter_map(|d| d.body()).collect(); - - for body in &bodies { - self.compute_all_cardinalities(body); - } - - for body in &bodies { - self.validate_shapes(body); - } - } - - fn compute_all_cardinalities(&mut self, expr: &Expr) { - self.get_or_compute(expr); - - for child in expr.children() { - self.compute_all_cardinalities(&child); - } - } - - fn get_or_compute(&mut self, expr: &Expr) -> ShapeCardinality { - if let Some(&c) = self.shape_cardinality_table.get(expr) { - return c; - } - // Insert sentinel to break cycles (e.g., `Foo = (Foo)`) - self.shape_cardinality_table - .insert(expr.clone(), ShapeCardinality::Invalid); - let c = self.compute_single(expr); - self.shape_cardinality_table.insert(expr.clone(), c); - c - } - - fn compute_single(&mut self, expr: &Expr) -> ShapeCardinality { - match expr { - Expr::NamedNode(_) | Expr::AnonymousNode(_) | Expr::FieldExpr(_) | 
Expr::AltExpr(_) => { - ShapeCardinality::One - } - - Expr::SeqExpr(seq) => self.seq_cardinality(seq), - - Expr::CapturedExpr(cap) => { - let Some(inner) = cap.inner() else { - return ShapeCardinality::Invalid; - }; - self.get_or_compute(&inner) - } - - Expr::QuantifiedExpr(q) => { - let Some(inner) = q.inner() else { - return ShapeCardinality::Invalid; - }; - self.get_or_compute(&inner) - } - - Expr::Ref(r) => self.ref_cardinality(r), - } - } - - fn seq_cardinality(&mut self, seq: &SeqExpr) -> ShapeCardinality { - let children: Vec<_> = seq.children().collect(); - - match children.len() { - 0 => ShapeCardinality::One, - 1 => self.get_or_compute(&children[0]), - _ => ShapeCardinality::Many, - } - } - - fn ref_cardinality(&mut self, r: &Ref) -> ShapeCardinality { - let name_tok = r.name().expect( - "shape_cardinalities: Ref without name token \ - (parser only creates Ref for PascalCase Id)", - ); - let name = name_tok.text(); - - let Some(body) = self.symbol_table.get(name).cloned() else { - return ShapeCardinality::Invalid; - }; - - self.get_or_compute(&body) - } - - fn validate_shapes(&mut self, expr: &Expr) { - if let Expr::FieldExpr(field) = expr - && let Some(value) = field.value() - { - let card = self - .shape_cardinality_table - .get(&value) - .copied() - .unwrap_or(ShapeCardinality::One); - - if card == ShapeCardinality::Many { - let field_name = field - .name() - .map(|t| t.text().to_string()) - .unwrap_or_else(|| "field".to_string()); - - self.shapes_diagnostics - .report(DiagnosticKind::FieldSequenceValue, value.text_range()) - .message(field_name) - .emit(); - } - } - - for child in expr.children() { - self.validate_shapes(&child); - } - } -} diff --git a/crates/plotnik-lib/src/query/visitor.rs b/crates/plotnik-lib/src/query/visitor.rs new file mode 100644 index 00000000..927befeb --- /dev/null +++ b/crates/plotnik-lib/src/query/visitor.rs @@ -0,0 +1,135 @@ +//! AST Visitor pattern. +//! +//! # Usage +//! +//! Implement `Visitor` for your struct. 
Override `visit_*` methods to add logic. +//! Call `walk_*` within your override to continue recursion (or omit it to stop). +//! +//! ```ignore +//! impl Visitor for MyPass { +//! fn visit_named_node(&mut self, node: &NamedNode) { +//! // Pre-order logic +//! walk_named_node(self, node); +//! // Post-order logic +//! } +//! } +//! ``` + +use crate::parser::ast::{ + AltExpr, AnonymousNode, CapturedExpr, Def, Expr, FieldExpr, NamedNode, QuantifiedExpr, Ref, + Root, SeqExpr, +}; + +pub trait Visitor: Sized { + fn visit_root(&mut self, root: &Root) { + walk_root(self, root); + } + + fn visit_def(&mut self, def: &Def) { + walk_def(self, def); + } + + fn visit_expr(&mut self, expr: &Expr) { + walk_expr(self, expr); + } + + fn visit_named_node(&mut self, node: &NamedNode) { + walk_named_node(self, node); + } + + fn visit_anonymous_node(&mut self, _node: &AnonymousNode) { + // Leaf node + } + + fn visit_ref(&mut self, _ref: &Ref) { + // Leaf node in AST structure (semantic traversal happens via SymbolTable lookup) + } + + fn visit_alt_expr(&mut self, alt: &AltExpr) { + walk_alt_expr(self, alt); + } + + fn visit_seq_expr(&mut self, seq: &SeqExpr) { + walk_seq_expr(self, seq); + } + + fn visit_captured_expr(&mut self, cap: &CapturedExpr) { + walk_captured_expr(self, cap); + } + + fn visit_quantified_expr(&mut self, quant: &QuantifiedExpr) { + walk_quantified_expr(self, quant); + } + + fn visit_field_expr(&mut self, field: &FieldExpr) { + walk_field_expr(self, field); + } +} + +pub fn walk_root(visitor: &mut V, root: &Root) { + for def in root.defs() { + visitor.visit_def(&def); + } +} + +pub fn walk_def(visitor: &mut V, def: &Def) { + if let Some(body) = def.body() { + visitor.visit_expr(&body); + } +} + +pub fn walk_expr(visitor: &mut V, expr: &Expr) { + match expr { + Expr::NamedNode(n) => visitor.visit_named_node(n), + Expr::AnonymousNode(n) => visitor.visit_anonymous_node(n), + Expr::Ref(r) => visitor.visit_ref(r), + Expr::AltExpr(a) => visitor.visit_alt_expr(a), + 
Expr::SeqExpr(s) => visitor.visit_seq_expr(s), + Expr::CapturedExpr(c) => visitor.visit_captured_expr(c), + Expr::QuantifiedExpr(q) => visitor.visit_quantified_expr(q), + Expr::FieldExpr(f) => visitor.visit_field_expr(f), + } +} + +pub fn walk_named_node(visitor: &mut V, node: &NamedNode) { + // We iterate specific children to avoid Expr::children() Vec allocation + for child in node.children() { + visitor.visit_expr(&child); + } +} + +pub fn walk_alt_expr(visitor: &mut V, alt: &AltExpr) { + for branch in alt.branches() { + if let Some(body) = branch.body() { + visitor.visit_expr(&body); + } + } + // Also visit bare exprs in untagged/mixed alts if any exist unwrapped + for expr in alt.exprs() { + visitor.visit_expr(&expr); + } +} + +pub fn walk_seq_expr(visitor: &mut V, seq: &SeqExpr) { + for child in seq.children() { + visitor.visit_expr(&child); + } +} + +pub fn walk_captured_expr(visitor: &mut V, cap: &CapturedExpr) { + if let Some(inner) = cap.inner() { + visitor.visit_expr(&inner); + } +} + +pub fn walk_quantified_expr(visitor: &mut V, quant: &QuantifiedExpr) { + if let Some(inner) = quant.inner() { + visitor.visit_expr(&inner); + } +} + +pub fn walk_field_expr(visitor: &mut V, field: &FieldExpr) { + if let Some(val) = field.value() { + visitor.visit_expr(&val); + } +} From fbd57c65da89def3c76423150a894fb82608d49f Mon Sep 17 00:00:00 2001 From: Sergei Zharinov Date: Tue, 16 Dec 2025 13:27:33 -0300 Subject: [PATCH 2/5] Refactor expression arities --- AGENTS.md | 2 +- README.md | 2 +- crates/plotnik-cli/src/cli.rs | 4 +- crates/plotnik-cli/src/commands/debug/mod.rs | 6 +- crates/plotnik-cli/src/main.rs | 2 +- crates/plotnik-lib/src/query/dump.rs | 8 +- crates/plotnik-lib/src/query/expr_arity.rs | 98 ++++++++++--------- .../plotnik-lib/src/query/expr_arity_tests.rs | 70 ++++++------- crates/plotnik-lib/src/query/mod.rs | 22 ++--- crates/plotnik-lib/src/query/printer.rs | 33 ++++--- crates/plotnik-lib/src/query/printer_tests.rs | 8 +- 11 files changed, 129 
insertions(+), 126 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index af856525..5005e79b 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -192,7 +192,7 @@ Inputs: `-q/--query `, `--query-file `, `--source `, `-s/--source-file - `--cst` — Show query CST instead of AST - `--raw` — Include trivia tokens (whitespace, comments) - `--spans` — Show source spans -- `--cardinalities` — Show inferred cardinalities +- `--arities` — Show node arities - `--graph` — Show compiled transition graph - `--graph-raw` — Show unoptimized graph (before epsilon elimination) - `--types` — Show inferred types diff --git a/README.md b/README.md index f6a79634..d0a23109 100644 --- a/README.md +++ b/README.md @@ -234,7 +234,7 @@ The schema infrastructure is built. Type inference is next. The CLI foundation exists. The full developer experience is ahead. - [x] CLI framework with `debug`, `docs`, `langs`, `exec`, `types` commands -- [x] Query inspection: AST dump, symbol table, cardinalities, spans, transition graph, inferred types +- [x] Query inspection: AST dump, symbol table, node arities, spans, transition graph, inferred types - [x] Source inspection: Tree-sitter parse tree visualization - [x] Execute queries against source code and output JSON (`exec`) - [x] Generate TypeScript types from queries (`types`) diff --git a/crates/plotnik-cli/src/cli.rs b/crates/plotnik-cli/src/cli.rs index 40663524..9ed47989 100644 --- a/crates/plotnik-cli/src/cli.rs +++ b/crates/plotnik-cli/src/cli.rs @@ -197,9 +197,9 @@ pub struct OutputArgs { #[arg(long)] pub spans: bool, - /// Show inferred cardinalities + /// Show inferred arities #[arg(long)] - pub cardinalities: bool, + pub arities: bool, /// Show compiled graph #[arg(long)] diff --git a/crates/plotnik-cli/src/commands/debug/mod.rs b/crates/plotnik-cli/src/commands/debug/mod.rs index 888b8625..1d1ad6fb 100644 --- a/crates/plotnik-cli/src/commands/debug/mod.rs +++ b/crates/plotnik-cli/src/commands/debug/mod.rs @@ -17,7 +17,7 @@ pub struct DebugArgs { pub 
raw: bool, pub cst: bool, pub spans: bool, - pub cardinalities: bool, + pub arities: bool, pub graph: bool, pub graph_raw: bool, pub types: bool, @@ -65,7 +65,7 @@ pub fn run(args: DebugArgs) { .raw(args.cst || args.raw) .with_trivia(args.raw) .with_spans(args.spans) - .with_cardinalities(args.cardinalities) + .with_arities(args.arities) .dump() ); } @@ -77,7 +77,7 @@ pub fn run(args: DebugArgs) { "{}", q.printer() .only_symbols(true) - .with_cardinalities(args.cardinalities) + .with_arities(args.arities) .dump() ); } diff --git a/crates/plotnik-cli/src/main.rs b/crates/plotnik-cli/src/main.rs index 18ec2e01..63cbedff 100644 --- a/crates/plotnik-cli/src/main.rs +++ b/crates/plotnik-cli/src/main.rs @@ -26,7 +26,7 @@ fn main() { raw: output.raw, cst: output.cst, spans: output.spans, - cardinalities: output.cardinalities, + arities: output.arities, graph: output.graph, graph_raw: output.graph_raw, types: output.types, diff --git a/crates/plotnik-lib/src/query/dump.rs b/crates/plotnik-lib/src/query/dump.rs index 9f2f7219..1b26a568 100644 --- a/crates/plotnik-lib/src/query/dump.rs +++ b/crates/plotnik-lib/src/query/dump.rs @@ -17,12 +17,12 @@ mod test_helpers { self.printer().dump() } - pub fn dump_with_cardinalities(&self) -> String { - self.printer().with_cardinalities(true).dump() + pub fn dump_with_arities(&self) -> String { + self.printer().with_arities(true).dump() } - pub fn dump_cst_with_cardinalities(&self) -> String { - self.printer().raw(true).with_cardinalities(true).dump() + pub fn dump_cst_with_arities(&self) -> String { + self.printer().raw(true).with_arities(true).dump() } pub fn dump_symbols(&self) -> String { diff --git a/crates/plotnik-lib/src/query/expr_arity.rs b/crates/plotnik-lib/src/query/expr_arity.rs index 271fd597..af2a3dd1 100644 --- a/crates/plotnik-lib/src/query/expr_arity.rs +++ b/crates/plotnik-lib/src/query/expr_arity.rs @@ -4,7 +4,7 @@ //! or multiple sequential positions (`Many`). Used to validate field constraints: //! 
`field: expr` requires `expr` to have `ExprArity::One`. //! -//! `Invalid` marks nodes where cardinality cannot be determined (error nodes, +//! `Invalid` marks nodes where arity cannot be determined (error nodes, //! undefined refs, etc.). use super::Query; @@ -30,41 +30,40 @@ impl Query<'_> { validator.visit_root(&root); } - pub(super) fn shape_arity(&self, node: &SyntaxNode) -> ExprArity { - // Error nodes are invalid + pub(super) fn get_arity(&self, node: &SyntaxNode) -> Option { if node.kind() == SyntaxKind::Error { - return ExprArity::Invalid; + return Some(ExprArity::Invalid); } - // Root: cardinality based on definition count + // Try casting to Expr first as it's the most common query + if let Some(expr) = ast::Expr::cast(node.clone()) { + return self.expr_arity_table.get(&expr).copied(); + } + + // Root: arity based on definition count if let Some(root) = ast::Root::cast(node.clone()) { - return if root.defs().count() > 1 { + return Some(if root.defs().nth(1).is_some() { ExprArity::Many } else { ExprArity::One - }; + }); } - // Def: delegate to body's cardinality + // Def: delegate to body's arity if let Some(def) = ast::Def::cast(node.clone()) { return def .body() - .and_then(|b| self.expr_arity_table.get(&b).copied()) - .unwrap_or(ExprArity::Invalid); + .and_then(|b| self.expr_arity_table.get(&b).copied()); } - // Branch: delegate to body's cardinality + // Branch: delegate to body's arity if let Some(branch) = ast::Branch::cast(node.clone()) { return branch .body() - .and_then(|b| self.expr_arity_table.get(&b).copied()) - .unwrap_or(ExprArity::Invalid); + .and_then(|b| self.expr_arity_table.get(&b).copied()); } - // Expr: direct lookup - ast::Expr::cast(node.clone()) - .and_then(|e| self.expr_arity_table.get(&e).copied()) - .unwrap_or(ExprArity::One) + None } } @@ -74,7 +73,7 @@ struct ArityComputer<'a, 'q> { impl Visitor for ArityComputer<'_, '_> { fn visit_expr(&mut self, expr: &Expr) { - self.query.compute_cardinality(expr); + 
self.query.compute_arity(expr); walk_expr(self, expr); } } @@ -91,66 +90,69 @@ impl Visitor for ArityValidator<'_, '_> { } impl Query<'_> { - fn compute_cardinality(&mut self, expr: &Expr) -> ExprArity { + fn compute_arity(&mut self, expr: &Expr) -> ExprArity { if let Some(&c) = self.expr_arity_table.get(expr) { return c; } // Insert sentinel to break cycles (e.g., `Foo = (Foo)`) self.expr_arity_table .insert(expr.clone(), ExprArity::Invalid); - let c = self.compute_single_cardinality(expr); + + let c = self.compute_single_arity(expr); self.expr_arity_table.insert(expr.clone(), c); c } - fn compute_single_cardinality(&mut self, expr: &Expr) -> ExprArity { + fn compute_single_arity(&mut self, expr: &Expr) -> ExprArity { match expr { Expr::NamedNode(_) | Expr::AnonymousNode(_) | Expr::FieldExpr(_) | Expr::AltExpr(_) => { ExprArity::One } - Expr::SeqExpr(seq) => self.seq_cardinality(seq), + Expr::SeqExpr(seq) => self.seq_arity(seq), - Expr::CapturedExpr(cap) => { - let Some(inner) = cap.inner() else { - return ExprArity::Invalid; - }; - self.compute_cardinality(&inner) - } + Expr::CapturedExpr(cap) => cap + .inner() + .map(|inner| self.compute_arity(&inner)) + .unwrap_or(ExprArity::Invalid), - Expr::QuantifiedExpr(q) => { - let Some(inner) = q.inner() else { - return ExprArity::Invalid; - }; - self.compute_cardinality(&inner) - } + Expr::QuantifiedExpr(q) => q + .inner() + .map(|inner| self.compute_arity(&inner)) + .unwrap_or(ExprArity::Invalid), - Expr::Ref(r) => self.ref_cardinality(r), + Expr::Ref(r) => self.ref_arity(r), } } - fn seq_cardinality(&mut self, seq: &SeqExpr) -> ExprArity { - let children: Vec<_> = seq.children().collect(); - - match children.len() { - 0 => ExprArity::One, - 1 => self.compute_cardinality(&children[0]), - _ => ExprArity::Many, + fn seq_arity(&mut self, seq: &SeqExpr) -> ExprArity { + // Avoid collecting into Vec; check if we have 0, 1, or >1 children. 
+ let mut children = seq.children(); + + match children.next() { + None => ExprArity::One, + Some(first) => { + if children.next().is_some() { + ExprArity::Many + } else { + self.compute_arity(&first) + } + } } } - fn ref_cardinality(&mut self, r: &Ref) -> ExprArity { + fn ref_arity(&mut self, r: &Ref) -> ExprArity { let name_tok = r.name().expect( "expr_arities: Ref without name token \ (parser only creates Ref for PascalCase Id)", ); let name = name_tok.text(); - let Some(body) = self.symbol_table.get(name).cloned() else { - return ExprArity::Invalid; - }; - - self.compute_cardinality(&body) + self.symbol_table + .get(name) + .cloned() + .map(|body| self.compute_arity(&body)) + .unwrap_or(ExprArity::Invalid) } fn validate_field(&mut self, field: &FieldExpr) { diff --git a/crates/plotnik-lib/src/query/expr_arity_tests.rs b/crates/plotnik-lib/src/query/expr_arity_tests.rs index d55cd524..6859eba8 100644 --- a/crates/plotnik-lib/src/query/expr_arity_tests.rs +++ b/crates/plotnik-lib/src/query/expr_arity_tests.rs @@ -5,7 +5,7 @@ use indoc::indoc; fn tree_is_one() { let query = Query::try_from("(identifier)").unwrap(); assert!(query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root¹ Def¹ NamedNode¹ identifier @@ -16,7 +16,7 @@ fn tree_is_one() { fn singleton_seq_is_one() { let query = Query::try_from("{(identifier)}").unwrap(); assert!(query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root¹ Def¹ Seq¹ @@ -28,7 +28,7 @@ fn singleton_seq_is_one() { fn nested_singleton_seq_is_one() { let query = Query::try_from("{{{(identifier)}}}").unwrap(); assert!(query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root¹ Def¹ Seq¹ @@ -42,7 +42,7 @@ fn nested_singleton_seq_is_one() { fn multi_seq_is_many() { let query = 
Query::try_from("{(a) (b)}").unwrap(); assert!(query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root¹ Def⁺ Seq⁺ @@ -55,7 +55,7 @@ fn multi_seq_is_many() { fn alt_is_one() { let query = Query::try_from("[(a) (b)]").unwrap(); assert!(query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root¹ Def¹ Alt¹ @@ -73,7 +73,7 @@ fn alt_with_seq_branches() { "#}; let query = Query::try_from(input).unwrap(); assert!(query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root¹ Def¹ Alt¹ @@ -94,7 +94,7 @@ fn ref_to_tree_is_one() { "#}; let query = Query::try_from(input).unwrap(); assert!(query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root⁺ Def¹ X NamedNode¹ identifier @@ -112,7 +112,7 @@ fn ref_to_seq_is_many() { "#}; let query = Query::try_from(input).unwrap(); assert!(query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root⁺ Def⁺ X Seq⁺ @@ -128,7 +128,7 @@ fn ref_to_seq_is_many() { fn field_with_tree() { let query = Query::try_from("(call name: (identifier))").unwrap(); assert!(query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root¹ Def¹ NamedNode¹ call @@ -141,7 +141,7 @@ fn field_with_tree() { fn field_with_alt() { let query = Query::try_from("(call name: [(identifier) (string)])").unwrap(); assert!(query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root¹ Def¹ NamedNode¹ call @@ -158,7 +158,7 @@ fn field_with_alt() { fn field_with_seq_error() { let query = 
Query::try_from("(call name: {(a) (b)})").unwrap(); assert!(!query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root¹ Def¹ NamedNode¹ call @@ -183,7 +183,7 @@ fn field_with_ref_to_seq_error() { "#}; let query = Query::try_from(input).unwrap(); assert!(!query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root⁺ Def⁺ X Seq⁺ @@ -206,7 +206,7 @@ fn field_with_ref_to_seq_error() { fn quantifier_preserves_inner_arity() { let query = Query::try_from("(identifier)*").unwrap(); assert!(query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root¹ Def¹ QuantifiedExpr¹ * @@ -218,7 +218,7 @@ fn quantifier_preserves_inner_arity() { fn capture_preserves_inner_arity() { let query = Query::try_from("(identifier) @name").unwrap(); assert!(query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root¹ Def¹ CapturedExpr¹ @name @@ -230,7 +230,7 @@ fn capture_preserves_inner_arity() { fn capture_on_seq() { let query = Query::try_from("{(a) (b)} @items").unwrap(); assert!(query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root¹ Def⁺ CapturedExpr⁺ @items @@ -250,7 +250,7 @@ fn complex_nested_arities() { "#}; let query = Query::try_from(input).unwrap(); assert!(query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root⁺ Def¹ Stmt Alt¹ @@ -278,7 +278,7 @@ fn tagged_alt_arities() { "#}; let query = Query::try_from(input).unwrap(); assert!(query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root¹ Def¹ 
Alt¹ @@ -290,10 +290,10 @@ fn tagged_alt_arities() { } #[test] -fn anchor_has_no_cardinality() { +fn anchor_has_no_arity() { let query = Query::try_from("(block . (statement))").unwrap(); assert!(query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root¹ Def¹ NamedNode¹ block @@ -303,10 +303,10 @@ fn anchor_has_no_cardinality() { } #[test] -fn negated_field_has_no_cardinality() { +fn negated_field_has_no_arity() { let query = Query::try_from("(function !async)").unwrap(); assert!(query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root¹ Def¹ NamedNode¹ function @@ -318,7 +318,7 @@ fn negated_field_has_no_cardinality() { fn tree_with_wildcard_type() { let query = Query::try_from("(_)").unwrap(); assert!(query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root¹ Def¹ NamedNode¹ (any) @@ -329,7 +329,7 @@ fn tree_with_wildcard_type() { fn bare_wildcard_is_one() { let query = Query::try_from("_").unwrap(); assert!(query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root¹ Def¹ AnonymousNode¹ (any) @@ -340,7 +340,7 @@ fn bare_wildcard_is_one() { fn empty_seq_is_one() { let query = Query::try_from("{}").unwrap(); assert!(query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root¹ Def¹ Seq¹ @@ -351,7 +351,7 @@ fn empty_seq_is_one() { fn literal_is_one() { let query = Query::try_from(r#""if""#).unwrap(); assert!(query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r#" + insta::assert_snapshot!(query.dump_with_arities(), @r#" Root¹ Def¹ AnonymousNode¹ "if" @@ -362,7 +362,7 @@ fn literal_is_one() { fn invalid_error_node() { let 
query = Query::try_from("(foo %)").unwrap(); assert!(!query.is_valid()); - insta::assert_snapshot!(query.dump_cst_with_cardinalities(), @r#" + insta::assert_snapshot!(query.dump_cst_with_arities(), @r#" Root¹ Def¹ Tree¹ @@ -378,7 +378,7 @@ fn invalid_error_node() { fn invalid_undefined_ref() { let query = Query::try_from("(Undefined)").unwrap(); assert!(!query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root¹ Def⁻ Ref⁻ Undefined @@ -389,11 +389,11 @@ fn invalid_undefined_ref() { fn invalid_branch_without_body() { let query = Query::try_from("[A:]").unwrap(); assert!(!query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root¹ Def¹ Alt¹ - Branch⁻ A: + Branchˣ A: "); } @@ -405,10 +405,10 @@ fn invalid_ref_to_bodyless_def() { "#}; let query = Query::try_from(input).unwrap(); assert!(!query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root⁺ - Def⁻ X - Def⁻ + Defˣ X + Defˣ Def⁻ Ref⁻ X "); @@ -419,7 +419,7 @@ fn invalid_capture_without_inner() { // Error recovery: `extra` is invalid, but `@y` still creates a Capture node let query = Query::try_from("(call extra @y)").unwrap(); assert!(!query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root¹ Def¹ NamedNode¹ call @@ -450,7 +450,7 @@ fn invalid_capture_without_inner_standalone() { fn invalid_multiple_captures_with_error() { let query = Query::try_from("(call (Undefined) @x extra @y)").unwrap(); assert!(!query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root¹ Def¹ NamedNode¹ call @@ -465,7 +465,7 @@ fn invalid_quantifier_without_inner() { // Error recovery: `extra` is invalid, but `*` still 
creates a Quantifier node let query = Query::try_from("(foo extra*)").unwrap(); assert!(!query.is_valid()); - insta::assert_snapshot!(query.dump_with_cardinalities(), @r" + insta::assert_snapshot!(query.dump_with_arities(), @r" Root¹ Def¹ NamedNode¹ foo diff --git a/crates/plotnik-lib/src/query/mod.rs b/crates/plotnik-lib/src/query/mod.rs index d1114a46..001f1881 100644 --- a/crates/plotnik-lib/src/query/mod.rs +++ b/crates/plotnik-lib/src/query/mod.rs @@ -96,15 +96,15 @@ pub struct QisTrigger<'a> { } #[derive(Debug)] -pub struct Query<'a> { - source: &'a str, +pub struct Query<'q> { + source: &'q str, ast: Root, - symbol_table: SymbolTable<'a>, + symbol_table: SymbolTable<'q>, expr_arity_table: HashMap, #[cfg(feature = "plotnik-langs")] - node_type_ids: HashMap<&'a str, Option>, + node_type_ids: HashMap<&'q str, Option>, #[cfg(feature = "plotnik-langs")] - node_field_ids: HashMap<&'a str, Option>, + node_field_ids: HashMap<&'q str, Option>, exec_fuel: Option, recursion_fuel: Option, exec_fuel_consumed: u32, @@ -116,17 +116,17 @@ pub struct Query<'a> { #[cfg(feature = "plotnik-langs")] link_diagnostics: Diagnostics, // Graph compilation fields - graph: BuildGraph<'a>, + graph: BuildGraph<'q>, dead_nodes: HashSet, - type_info: TypeInferenceResult<'a>, + type_info: TypeInferenceResult<'q>, /// QIS triggers: quantified expressions with ≥2 propagating captures. - qis_triggers: HashMap>, + qis_triggers: HashMap>, /// Definitions with exactly 1 propagating capture: def name → capture name. - single_capture_defs: HashMap<&'a str, &'a str>, + single_capture_defs: HashMap<&'q str, &'q str>, /// Definitions with 2+ propagating captures (need struct wrapping at root). - multi_capture_defs: HashSet<&'a str>, + multi_capture_defs: HashSet<&'q str>, /// Current definition name during graph construction. - current_def_name: &'a str, + current_def_name: &'q str, /// Counter for generating unique ref IDs during graph construction. 
next_ref_id: u32, } diff --git a/crates/plotnik-lib/src/query/printer.rs b/crates/plotnik-lib/src/query/printer.rs index b829bbd7..dbbf5bfb 100644 --- a/crates/plotnik-lib/src/query/printer.rs +++ b/crates/plotnik-lib/src/query/printer.rs @@ -14,7 +14,7 @@ pub struct QueryPrinter<'q, 'src> { query: &'q Query<'src>, raw: bool, trivia: bool, - cardinalities: bool, + arities: bool, spans: bool, symbols: bool, } @@ -25,7 +25,7 @@ impl<'q, 'src> QueryPrinter<'q, 'src> { query, raw: false, trivia: false, - cardinalities: false, + arities: false, spans: false, symbols: false, } @@ -41,8 +41,8 @@ impl<'q, 'src> QueryPrinter<'q, 'src> { self } - pub fn with_cardinalities(mut self, value: bool) -> Self { - self.cardinalities = value; + pub fn with_arities(mut self, value: bool) -> Self { + self.arities = value; self } @@ -120,7 +120,7 @@ impl<'q, 'src> QueryPrinter<'q, 'src> { let card = body_nodes .get(name) - .map(|n| self.cardinality_mark(n)) + .map(|n| self.arity_mark(n)) .unwrap_or(""); writeln!(w, "{}{}{}", prefix, name, card)?; visited.insert(name.to_string()); @@ -140,7 +140,7 @@ impl<'q, 'src> QueryPrinter<'q, 'src> { fn format_cst(&self, node: &SyntaxNode, indent: usize, w: &mut impl Write) -> std::fmt::Result { let prefix = " ".repeat(indent); - let card = self.cardinality_mark(node); + let card = self.arity_mark(node); let span = self.span_str(node.text_range()); writeln!(w, "{}{:?}{}{}", prefix, node.kind(), card, span)?; @@ -169,7 +169,7 @@ impl<'q, 'src> QueryPrinter<'q, 'src> { } fn format_root(&self, root: &ast::Root, w: &mut impl Write) -> std::fmt::Result { - let card = self.cardinality_mark(root.as_cst()); + let card = self.arity_mark(root.as_cst()); let span = self.span_str(root.text_range()); writeln!(w, "Root{}{}", card, span)?; @@ -186,7 +186,7 @@ impl<'q, 'src> QueryPrinter<'q, 'src> { fn format_def(&self, def: &ast::Def, indent: usize, w: &mut impl Write) -> std::fmt::Result { let prefix = " ".repeat(indent); - let card = 
self.cardinality_mark(def.as_cst()); + let card = self.arity_mark(def.as_cst()); let span = self.span_str(def.text_range()); let name = def.name().map(|t| t.text().to_string()); @@ -203,7 +203,7 @@ impl<'q, 'src> QueryPrinter<'q, 'src> { fn format_expr(&self, expr: &ast::Expr, indent: usize, w: &mut impl Write) -> std::fmt::Result { let prefix = " ".repeat(indent); - let card = self.cardinality_mark(expr.as_cst()); + let card = self.arity_mark(expr.as_cst()); let span = self.span_str(expr.text_range()); match expr { @@ -329,7 +329,7 @@ impl<'q, 'src> QueryPrinter<'q, 'src> { w: &mut impl Write, ) -> std::fmt::Result { let prefix = " ".repeat(indent); - let card = self.cardinality_mark(branch.as_cst()); + let card = self.arity_mark(branch.as_cst()); let span = self.span_str(branch.text_range()); let label = branch.label().map(|t| t.text().to_string()); @@ -344,14 +344,15 @@ impl<'q, 'src> QueryPrinter<'q, 'src> { self.format_expr(&body, indent + 1, w) } - fn cardinality_mark(&self, node: &SyntaxNode) -> &'static str { - if !self.cardinalities { + fn arity_mark(&self, node: &SyntaxNode) -> &'static str { + if !self.arities { return ""; } - match self.query.shape_arity(node) { - ExprArity::One => "¹", - ExprArity::Many => "⁺", - ExprArity::Invalid => "⁻", + match self.query.get_arity(node) { + Some(ExprArity::One) => "¹", + Some(ExprArity::Many) => "⁺", + Some(ExprArity::Invalid) => "⁻", + None => "ˣ", } } diff --git a/crates/plotnik-lib/src/query/printer_tests.rs b/crates/plotnik-lib/src/query/printer_tests.rs index b7813b96..190d4ba3 100644 --- a/crates/plotnik-lib/src/query/printer_tests.rs +++ b/crates/plotnik-lib/src/query/printer_tests.rs @@ -12,9 +12,9 @@ fn printer_with_spans() { } #[test] -fn printer_with_cardinalities() { +fn printer_with_arities() { let q = Query::try_from("(call)").unwrap(); - insta::assert_snapshot!(q.printer().with_cardinalities(true).dump(), @r" + insta::assert_snapshot!(q.printer().with_arities(true).dump(), @r" Root¹ Def¹ NamedNode¹ 
call @@ -150,14 +150,14 @@ fn printer_ref() { } #[test] -fn printer_symbols_with_cardinalities() { +fn printer_symbols_with_arities() { let input = indoc! {r#" A = (a) B = {(b) (c)} (entry (A) (B)) "#}; let q = Query::try_from(input).unwrap(); - insta::assert_snapshot!(q.printer().only_symbols(true).with_cardinalities(true).dump(), @r" + insta::assert_snapshot!(q.printer().only_symbols(true).with_arities(true).dump(), @r" A¹ B⁺ _ From e58310575fd148fd6cc15453669d1fd4e6d9c5aa Mon Sep 17 00:00:00 2001 From: Sergei Zharinov Date: Tue, 16 Dec 2025 14:16:08 -0300 Subject: [PATCH 3/5] Update symbol_table.rs --- crates/plotnik-lib/src/query/symbol_table.rs | 80 ++++++++++---------- 1 file changed, 38 insertions(+), 42 deletions(-) diff --git a/crates/plotnik-lib/src/query/symbol_table.rs b/crates/plotnik-lib/src/query/symbol_table.rs index 6ab31455..88f85b93 100644 --- a/crates/plotnik-lib/src/query/symbol_table.rs +++ b/crates/plotnik-lib/src/query/symbol_table.rs @@ -11,9 +11,10 @@ use indexmap::IndexMap; pub const UNNAMED_DEF: &str = "_"; use crate::diagnostics::DiagnosticKind; -use crate::parser::{Expr, Ref, ast, token_src}; +use crate::parser::{ast, token_src}; use super::Query; +use super::visitor::{Visitor, walk_root}; pub type SymbolTable<'src> = IndexMap<&'src str, ast::Expr>; @@ -21,65 +22,60 @@ impl<'a> Query<'a> { pub(super) fn resolve_names(&mut self) { // Pass 1: collect definitions for def in self.ast.defs() { - let (name, is_named) = match def.name() { - Some(token) => (token_src(&token, self.source), true), - None => (UNNAMED_DEF, false), - }; - - // Skip duplicate check for unnamed definitions (already diagnosed by parser) - if is_named && self.symbol_table.contains_key(name) { - let name_token = def.name().unwrap(); - self.resolve_diagnostics - .report(DiagnosticKind::DuplicateDefinition, name_token.text_range()) - .message(name) - .emit(); - continue; - } + let Some(body) = def.body() else { continue }; - // For unnamed defs, only keep the last one 
(parser already warned about others) - if !is_named && self.symbol_table.contains_key(name) { - self.symbol_table.shift_remove(name); + if let Some(token) = def.name() { + // Named definition: `Name = ...` + let name = token_src(&token, self.source); + if self.symbol_table.contains_key(name) { + self.resolve_diagnostics + .report(DiagnosticKind::DuplicateDefinition, token.text_range()) + .message(name) + .emit(); + } else { + self.symbol_table.insert(name, body); + } + } else { + // Unnamed definition: `...` (root expression) + // Parser already validates multiple unnamed defs; we keep the last one. + if self.symbol_table.contains_key(UNNAMED_DEF) { + self.symbol_table.shift_remove(UNNAMED_DEF); + } + self.symbol_table.insert(UNNAMED_DEF, body); } - - let Some(body) = def.body() else { - continue; - }; - self.symbol_table.insert(name, body); } // Pass 2: check references - let defs: Vec<_> = self.ast.defs().collect(); - for def in defs { - let Some(body) = def.body() else { continue }; - self.collect_reference_diagnostics(&body); - } + let root = self.ast.clone(); + let mut validator = ReferenceValidator { query: self }; + validator.visit_root(&root); + } +} +struct ReferenceValidator<'a, 'q> { + query: &'a mut Query<'q>, +} + +impl Visitor for ReferenceValidator<'_, '_> { + fn visit_root(&mut self, root: &ast::Root) { // Parser wraps all top-level exprs in Def nodes, so this should be empty assert!( - self.ast.exprs().next().is_none(), + root.exprs().next().is_none(), "symbol_table: unexpected bare Expr in Root (parser should wrap in Def)" ); + walk_root(self, root); } - fn collect_reference_diagnostics(&mut self, expr: &Expr) { - if let Expr::Ref(r) = expr { - self.check_ref_diagnostic(r); - } - - for child in expr.children() { - self.collect_reference_diagnostics(&child); - } - } - - fn check_ref_diagnostic(&mut self, r: &Ref) { + fn visit_ref(&mut self, r: &ast::Ref) { let Some(name_token) = r.name() else { return }; let name = name_token.text(); - if 
self.symbol_table.contains_key(name) { + if self.query.symbol_table.contains_key(name) { return; } - self.resolve_diagnostics + self.query + .resolve_diagnostics .report(DiagnosticKind::UndefinedReference, name_token.text_range()) .message(name) .emit(); From 4af9dd176c1afdd8537d2ad49adbe05b85435a43 Mon Sep 17 00:00:00 2001 From: Sergei Zharinov Date: Tue, 16 Dec 2025 14:33:30 -0300 Subject: [PATCH 4/5] Update recursion.rs --- crates/plotnik-lib/src/query/recursion.rs | 151 +++++++++++++--------- 1 file changed, 91 insertions(+), 60 deletions(-) diff --git a/crates/plotnik-lib/src/query/recursion.rs b/crates/plotnik-lib/src/query/recursion.rs index c0f39597..fd52844f 100644 --- a/crates/plotnik-lib/src/query/recursion.rs +++ b/crates/plotnik-lib/src/query/recursion.rs @@ -8,8 +8,9 @@ use indexmap::{IndexMap, IndexSet}; use rowan::TextRange; use super::Query; +use super::visitor::{Visitor, walk_expr}; use crate::diagnostics::DiagnosticKind; -use crate::parser::{Def, Expr}; +use crate::parser::{AnonymousNode, Def, Expr, NamedNode, Ref, SeqExpr}; impl Query<'_> { pub(super) fn validate_recursion(&mut self) { @@ -36,8 +37,8 @@ impl Query<'_> { if !has_escape { // Find a cycle to report. Any cycle within the SCC is an infinite recursion loop // because there are no escape paths. - if let Some(raw_chain) = self.find_cycle(&scc, &scc_set, |q, expr, target| { - q.find_ref_range(expr, target) + if let Some(raw_chain) = self.find_cycle(&scc, &scc_set, |_, expr, target| { + find_ref_range(expr, target) }) { let chain = self.format_chain(raw_chain, false); self.report_cycle(DiagnosticKind::RecursionNoEscape, &scc, chain); @@ -48,8 +49,8 @@ impl Query<'_> { // 2. Check for infinite loops (Guarded Recursion Analysis) // Even if there is an escape, every recursive cycle must consume input (be guarded). // We look for a cycle composed entirely of unguarded references. 
- if let Some(raw_chain) = self.find_cycle(&scc, &scc_set, |q, expr, target| { - q.find_unguarded_ref_range(expr, target) + if let Some(raw_chain) = self.find_cycle(&scc, &scc_set, |_, expr, target| { + find_unguarded_ref_range(expr, target) }) { let chain = self.format_chain(raw_chain, true); self.report_cycle(DiagnosticKind::DirectRecursion, &scc, chain); @@ -166,14 +167,6 @@ impl Query<'_> { .defs() .find(|d| d.name().map(|n| n.text() == name).unwrap_or(false)) } - - fn find_ref_range(&self, expr: &Expr, target: &str) -> Option { - find_ref_in_expr(expr, target) - } - - fn find_unguarded_ref_range(&self, expr: &Expr, target: &str) -> Option { - find_unguarded_ref_in_expr(expr, target) - } } struct CycleFinder<'a> { @@ -222,9 +215,6 @@ impl<'a> CycleFinder<'a> { for (target, range) in neighbors { if let Some(&start_index) = self.on_path.get(target) { // Cycle detected! - // Path: path[start_index] ... path[last] (current) - // Edges: edges[start_index] ... edges[last-1] - // Closing edge: range let mut chain = Vec::new(); for i in start_index..self.path.len() - 1 { chain.push((self.edges[i], self.path[i + 1].clone())); @@ -378,65 +368,106 @@ fn expr_guarantees_consumption(expr: &Expr) -> bool { } } +struct RefCollector<'a> { + refs: &'a mut IndexSet, +} + +impl Visitor for RefCollector<'_> { + fn visit_ref(&mut self, r: &Ref) { + if let Some(name) = r.name() { + self.refs.insert(name.text().to_string()); + } + } +} + fn collect_refs(expr: &Expr) -> IndexSet { let mut refs = IndexSet::new(); - collect_refs_into(expr, &mut refs); + let mut visitor = RefCollector { refs: &mut refs }; + visitor.visit_expr(expr); refs } -fn collect_refs_into(expr: &Expr, refs: &mut IndexSet) { - if let Expr::Ref(r) = expr - && let Some(name_token) = r.name() - { - refs.insert(name_token.text().to_string()); +struct RefFinder<'a> { + target: &'a str, + found: Option, +} + +impl Visitor for RefFinder<'_> { + fn visit_expr(&mut self, expr: &Expr) { + if self.found.is_some() { + return; + 
} + walk_expr(self, expr); } - for child in expr.children() { - collect_refs_into(&child, refs); + fn visit_ref(&mut self, r: &Ref) { + if self.found.is_some() { + return; + } + if let Some(name) = r.name() + && name.text() == self.target + { + self.found = Some(name.text_range()); + } } } -fn find_ref_in_expr(expr: &Expr, target: &str) -> Option { - if let Expr::Ref(r) = expr { - let name_token = r.name()?; - if name_token.text() == target { - return Some(name_token.text_range()); +fn find_ref_range(expr: &Expr, target: &str) -> Option { + let mut visitor = RefFinder { + target, + found: None, + }; + visitor.visit_expr(expr); + visitor.found +} + +struct UnguardedRefFinder<'a> { + target: &'a str, + found: Option, +} + +impl Visitor for UnguardedRefFinder<'_> { + fn visit_expr(&mut self, expr: &Expr) { + if self.found.is_some() { + return; } + walk_expr(self, expr); } - expr.children() - .iter() - .find_map(|child| find_ref_in_expr(child, target)) -} + fn visit_named_node(&mut self, _node: &NamedNode) { + // Guarded: stop recursion + } -fn find_unguarded_ref_in_expr(expr: &Expr, target: &str) -> Option { - match expr { - Expr::Ref(r) => r - .name() - .filter(|n| n.text() == target) - .map(|n| n.text_range()), - Expr::NamedNode(_) | Expr::AnonymousNode(_) => None, - Expr::AltExpr(_) => expr - .children() - .iter() - .find_map(|c| find_unguarded_ref_in_expr(c, target)), - Expr::SeqExpr(_) => { - for c in expr.children() { - if let Some(range) = find_unguarded_ref_in_expr(&c, target) { - return Some(range); - } - if expr_guarantees_consumption(&c) { - return None; - } + fn visit_anonymous_node(&mut self, _node: &AnonymousNode) { + // Guarded: stop recursion + } + + fn visit_ref(&mut self, r: &Ref) { + if let Some(name) = r.name() + && name.text() == self.target + { + self.found = Some(name.text_range()); + } + } + + fn visit_seq_expr(&mut self, seq: &SeqExpr) { + for child in seq.children() { + self.visit_expr(&child); + if self.found.is_some() { + return; + } + if 
expr_guarantees_consumption(&child) { + return; } - None } - Expr::QuantifiedExpr(q) => q - .inner() - .and_then(|i| find_unguarded_ref_in_expr(&i, target)), - Expr::CapturedExpr(_) | Expr::FieldExpr(_) => expr - .children() - .iter() - .find_map(|c| find_unguarded_ref_in_expr(c, target)), } } + +fn find_unguarded_ref_range(expr: &Expr, target: &str) -> Option { + let mut visitor = UnguardedRefFinder { + target, + found: None, + }; + visitor.visit_expr(expr); + visitor.found +} From 420f7ca1ea11c7fc24278fb96e80c883fb686d31 Mon Sep 17 00:00:00 2001 From: Sergei Zharinov Date: Tue, 16 Dec 2025 15:22:27 -0300 Subject: [PATCH 5/5] Refactor Query to use visitor pattern for type and field resolution Extract edit distance and find_similar utilities to a new shared utils module Replace manual recursion with NodeTypeCollector and FieldCollector visitors for cleaner code --- crates/plotnik-lib/src/query/link.rs | 243 +++++++++----------------- crates/plotnik-lib/src/query/mod.rs | 1 + crates/plotnik-lib/src/query/utils.rs | 51 ++++++ 3 files changed, 135 insertions(+), 160 deletions(-) create mode 100644 crates/plotnik-lib/src/query/utils.rs diff --git a/crates/plotnik-lib/src/query/link.rs b/crates/plotnik-lib/src/query/link.rs index 19fb2d5b..105be07d 100644 --- a/crates/plotnik-lib/src/query/link.rs +++ b/crates/plotnik-lib/src/query/link.rs @@ -12,45 +12,11 @@ use rowan::TextRange; use crate::diagnostics::DiagnosticKind; use crate::parser::ast::{self, Expr, NamedNode}; use crate::parser::cst::{SyntaxKind, SyntaxToken}; +use crate::parser::token_src; use super::Query; - -/// Simple edit distance for fuzzy matching (Levenshtein). 
-fn edit_distance(a: &str, b: &str) -> usize { - let a_len = a.chars().count(); - let b_len = b.chars().count(); - - if a_len == 0 { - return b_len; - } - if b_len == 0 { - return a_len; - } - - let mut prev: Vec = (0..=b_len).collect(); - let mut curr = vec![0; b_len + 1]; - - for (i, ca) in a.chars().enumerate() { - curr[0] = i + 1; - for (j, cb) in b.chars().enumerate() { - let cost = if ca == cb { 0 } else { 1 }; - curr[j + 1] = (prev[j + 1] + 1).min(curr[j] + 1).min(prev[j] + cost); - } - std::mem::swap(&mut prev, &mut curr); - } - - prev[b_len] -} - -/// Find the best match from candidates within a reasonable edit distance. -fn find_similar<'a>(name: &str, candidates: &[&'a str], max_distance: usize) -> Option<&'a str> { - candidates - .iter() - .map(|&c| (c, edit_distance(name, c))) - .filter(|(_, d)| *d <= max_distance) - .min_by_key(|(_, d)| *d) - .map(|(c, _)| c) -} +use super::utils::find_similar; +use super::visitor::{Visitor, walk_root}; /// Check if `child` is a subtype of `supertype`, recursively handling nested supertypes. 
#[allow(dead_code)] @@ -160,69 +126,9 @@ impl<'a> Query<'a> { } fn resolve_node_types(&mut self, lang: &Lang) { - let defs: Vec<_> = self.ast.defs().collect(); - for def in defs { - let Some(body) = def.body() else { continue }; - self.collect_node_types(&body, lang); - } - } - - fn collect_node_types(&mut self, expr: &Expr, lang: &Lang) { - match expr { - Expr::NamedNode(node) => { - self.resolve_named_node(node, lang); - for child in node.children() { - self.collect_node_types(&child, lang); - } - } - Expr::AnonymousNode(anon) => { - if anon.is_any() { - return; - } - let Some(value_token) = anon.value() else { - return; - }; - let value = value_token.text(); - if self.node_type_ids.contains_key(value) { - return; - } - let resolved = lang.resolve_anonymous_node(value); - self.node_type_ids.insert( - &self.source[text_range_to_usize(value_token.text_range())], - resolved, - ); - if resolved.is_none() { - self.link_diagnostics - .report(DiagnosticKind::UnknownNodeType, value_token.text_range()) - .message(value) - .emit(); - } - } - Expr::AltExpr(alt) => { - for branch in alt.branches() { - let Some(body) = branch.body() else { continue }; - self.collect_node_types(&body, lang); - } - } - Expr::SeqExpr(seq) => { - for child in seq.children() { - self.collect_node_types(&child, lang); - } - } - Expr::CapturedExpr(cap) => { - let Some(inner) = cap.inner() else { return }; - self.collect_node_types(&inner, lang); - } - Expr::QuantifiedExpr(q) => { - let Some(inner) = q.inner() else { return }; - self.collect_node_types(&inner, lang); - } - Expr::FieldExpr(f) => { - let Some(value) = f.value() else { return }; - self.collect_node_types(&value, lang); - } - Expr::Ref(_) => {} - } + let root = self.ast.clone(); + let mut collector = NodeTypeCollector { query: self, lang }; + collector.visit_root(&root); } fn resolve_named_node(&mut self, node: &NamedNode, lang: &Lang) { @@ -243,10 +149,8 @@ impl<'a> Query<'a> { return; } let resolved = 
lang.resolve_named_node(type_name); - self.node_type_ids.insert( - &self.source[text_range_to_usize(type_token.text_range())], - resolved, - ); + self.node_type_ids + .insert(token_src(&type_token, self.source), resolved); if resolved.is_none() { let all_types = lang.all_named_node_kinds(); let max_dist = (type_name.len() / 3).clamp(2, 4); @@ -265,51 +169,9 @@ impl<'a> Query<'a> { } fn resolve_fields(&mut self, lang: &Lang) { - let defs: Vec<_> = self.ast.defs().collect(); - for def in defs { - let Some(body) = def.body() else { continue }; - self.collect_fields(&body, lang); - } - } - - fn collect_fields(&mut self, expr: &Expr, lang: &Lang) { - match expr { - Expr::NamedNode(node) => { - for child in node.children() { - self.collect_fields(&child, lang); - } - for child in node.as_cst().children() { - if let Some(neg) = ast::NegatedField::cast(child) { - self.resolve_field_by_token(neg.name(), lang); - } - } - } - Expr::AltExpr(alt) => { - for branch in alt.branches() { - let Some(body) = branch.body() else { continue }; - self.collect_fields(&body, lang); - } - } - Expr::SeqExpr(seq) => { - for child in seq.children() { - self.collect_fields(&child, lang); - } - } - Expr::CapturedExpr(cap) => { - let Some(inner) = cap.inner() else { return }; - self.collect_fields(&inner, lang); - } - Expr::QuantifiedExpr(q) => { - let Some(inner) = q.inner() else { return }; - self.collect_fields(&inner, lang); - } - Expr::FieldExpr(f) => { - self.resolve_field_by_token(f.name(), lang); - let Some(value) = f.value() else { return }; - self.collect_fields(&value, lang); - } - Expr::AnonymousNode(_) | Expr::Ref(_) => {} - } + let root = self.ast.clone(); + let mut collector = FieldCollector { query: self, lang }; + collector.visit_root(&root); } fn resolve_field_by_token(&mut self, name_token: Option, lang: &Lang) { @@ -321,10 +183,8 @@ impl<'a> Query<'a> { return; } let resolved = lang.resolve_field(field_name); - self.node_field_ids.insert( - 
&self.source[text_range_to_usize(name_token.text_range())], - resolved, - ); + self.node_field_ids + .insert(token_src(&name_token, self.source), resolved); if resolved.is_some() { return; } @@ -505,7 +365,7 @@ impl<'a> Query<'a> { parent_name: ctx.parent_name, parent_range: ctx.parent_range, field: Some(FieldContext { - name: &self.source[text_range_to_usize(name_token.text_range())], + name: token_src(&name_token, self.source), id: field_id, range: name_token.text_range(), }), @@ -736,7 +596,7 @@ impl<'a> Query<'a> { } let type_name = type_token.text(); let type_id = self.node_type_ids.get(type_name).copied().flatten()?; - let name = &self.source[text_range_to_usize(type_token.text_range())]; + let name = token_src(&type_token, self.source); Some((type_id, name, type_token.text_range())) } Expr::AnonymousNode(anon) => { @@ -744,7 +604,7 @@ impl<'a> Query<'a> { return None; } let value_token = anon.value()?; - let value = &self.source[text_range_to_usize(value_token.text_range())]; + let value = token_src(&value_token, self.source); let type_id = self.node_type_ids.get(value).copied().flatten()?; Some((type_id, value, value_token.text_range())) } @@ -813,8 +673,71 @@ impl<'a> Query<'a> { } } -fn text_range_to_usize(range: TextRange) -> std::ops::Range { - let start: usize = range.start().into(); - let end: usize = range.end().into(); - start..end +struct NodeTypeCollector<'a, 'q> { + query: &'a mut Query<'q>, + lang: &'a Lang, +} + +impl Visitor for NodeTypeCollector<'_, '_> { + fn visit_root(&mut self, root: &ast::Root) { + walk_root(self, root); + } + + fn visit_named_node(&mut self, node: &ast::NamedNode) { + self.query.resolve_named_node(node, self.lang); + super::visitor::walk_named_node(self, node); + } + + fn visit_anonymous_node(&mut self, node: &ast::AnonymousNode) { + if node.is_any() { + return; + } + let Some(value_token) = node.value() else { + return; + }; + let value = value_token.text(); + if self.query.node_type_ids.contains_key(value) { + return; 
+ } + + let resolved = self.lang.resolve_anonymous_node(value); + self.query + .node_type_ids + .insert(token_src(&value_token, self.query.source), resolved); + + if resolved.is_none() { + self.query + .link_diagnostics + .report(DiagnosticKind::UnknownNodeType, value_token.text_range()) + .message(value) + .emit(); + } + } +} + +struct FieldCollector<'a, 'q> { + query: &'a mut Query<'q>, + lang: &'a Lang, +} + +impl Visitor for FieldCollector<'_, '_> { + fn visit_root(&mut self, root: &ast::Root) { + walk_root(self, root); + } + + fn visit_named_node(&mut self, node: &ast::NamedNode) { + for child in node.as_cst().children() { + if let Some(neg) = ast::NegatedField::cast(child) { + self.query.resolve_field_by_token(neg.name(), self.lang); + } + } + + super::visitor::walk_named_node(self, node); + } + + fn visit_field_expr(&mut self, field: &ast::FieldExpr) { + self.query.resolve_field_by_token(field.name(), self.lang); + + super::visitor::walk_field_expr(self, field); + } } diff --git a/crates/plotnik-lib/src/query/mod.rs b/crates/plotnik-lib/src/query/mod.rs index 001f1881..2a653b9a 100644 --- a/crates/plotnik-lib/src/query/mod.rs +++ b/crates/plotnik-lib/src/query/mod.rs @@ -11,6 +11,7 @@ mod dump; mod graph_qis; mod invariants; mod printer; +mod utils; pub use printer::QueryPrinter; pub mod alt_kinds; diff --git a/crates/plotnik-lib/src/query/utils.rs b/crates/plotnik-lib/src/query/utils.rs new file mode 100644 index 00000000..5f710235 --- /dev/null +++ b/crates/plotnik-lib/src/query/utils.rs @@ -0,0 +1,51 @@ +//! Small string utilities shared by query passes. +//! +//! This module intentionally stays minimal and dependency-free. +//! Only extract helpers here when they are used by 2+ modules or are clearly +//! pass-agnostic (formatting, suggestion, small string algorithms). + +/// Simple edit distance for fuzzy matching (Levenshtein). +/// +/// This is optimized for correctness and small inputs (identifiers, field names), +/// not for very large strings. 
/// Returns the minimum number of single-character insertions, deletions and
/// substitutions needed to turn `a` into `b`.
///
/// The distance is measured in Unicode scalar values (`char`s), not bytes, so
/// replacing `é` with `e` costs one edit regardless of UTF-8 encoding width.
///
/// Runs in `O(|a| * |b|)` time and `O(|b|)` extra space (two rolling rows of
/// the dynamic-programming table plus the decoded characters of `b`).
pub fn edit_distance(a: &str, b: &str) -> usize {
    // Decode `b` once up front: it is scanned in full for every char of `a`.
    let b_chars: Vec<char> = b.chars().collect();
    let b_len = b_chars.len();

    // Against an empty string the distance is simply the other string's length.
    if a.is_empty() {
        return b_len;
    }
    if b_len == 0 {
        return a.chars().count();
    }

    // `prev[j]` holds the distance between the already-processed prefix of `a`
    // and the first `j` chars of `b`; row zero is "insert `j` chars".
    let mut prev: Vec<usize> = (0..=b_len).collect();
    let mut curr: Vec<usize> = vec![0; b_len + 1];

    for (i, ca) in a.chars().enumerate() {
        // Turning the first `i + 1` chars of `a` into "" takes `i + 1` deletions.
        curr[0] = i + 1;
        for (j, &cb) in b_chars.iter().enumerate() {
            let substitution = prev[j] + usize::from(ca != cb);
            let deletion = prev[j + 1] + 1;
            let insertion = curr[j] + 1;
            curr[j + 1] = substitution.min(deletion).min(insertion);
        }
        // The freshly computed row becomes the baseline for the next char.
        std::mem::swap(&mut prev, &mut curr);
    }

    prev[b_len]
}

/// Picks the candidate closest to `name` by [`edit_distance`].
///
/// Returns `None` when `candidates` is empty or when every candidate is more
/// than `max_distance` edits away. If several candidates share the lowest
/// distance, the one that appears first in `candidates` is returned.
pub fn find_similar<'a>(
    name: &str,
    candidates: &[&'a str],
    max_distance: usize,
) -> Option<&'a str> {
    candidates
        .iter()
        .map(|&candidate| (candidate, edit_distance(name, candidate)))
        .filter(|&(_, distance)| distance <= max_distance)
        // `min_by_key` keeps the first of equally-minimal elements.
        .min_by_key(|&(_, distance)| distance)
        .map(|(candidate, _)| candidate)
}