diff --git a/crates/plotnik-lib/src/parser/cst.rs b/crates/plotnik-lib/src/parser/cst.rs index 27af3b78..0f5f6c80 100644 --- a/crates/plotnik-lib/src/parser/cst.rs +++ b/crates/plotnik-lib/src/parser/cst.rs @@ -136,7 +136,6 @@ pub enum SyntaxKind { Garbage, Error, - // --- Node kinds (non-terminals) --- Root, Tree, Ref, diff --git a/crates/plotnik-lib/src/query/dependencies.rs b/crates/plotnik-lib/src/query/dependencies.rs index 6af16937..9a24369e 100644 --- a/crates/plotnik-lib/src/query/dependencies.rs +++ b/crates/plotnik-lib/src/query/dependencies.rs @@ -21,29 +21,29 @@ use crate::query::visitor::{Visitor, walk_expr}; /// Result of dependency analysis. #[derive(Debug, Clone, Default)] -pub struct DependencyAnalysis<'q> { +pub struct DependencyAnalysis { /// Strongly connected components in reverse topological order. /// /// - `sccs[0]` has no dependencies (or depends only on things not in this list). /// - `sccs.last()` depends on everything else. /// - Definitions within an SCC are mutually recursive. /// - Every definition in the symbol table appears exactly once. - pub sccs: Vec>, + pub sccs: Vec>, } /// Analyze dependencies between definitions. /// /// Returns the SCCs in reverse topological order. -pub fn analyze_dependencies<'q>(symbol_table: &SymbolTable<'q>) -> DependencyAnalysis<'q> { +pub fn analyze_dependencies(symbol_table: &SymbolTable) -> DependencyAnalysis { let sccs = SccFinder::find(symbol_table); DependencyAnalysis { sccs } } /// Validate recursion using the pre-computed dependency analysis. -pub fn validate_recursion<'q>( - analysis: &DependencyAnalysis<'q>, +pub fn validate_recursion( + analysis: &DependencyAnalysis, ast_map: &IndexMap, - symbol_table: &SymbolTable<'q>, + symbol_table: &SymbolTable, diag: &mut Diagnostics, ) { let mut validator = RecursionValidator { @@ -54,50 +54,41 @@ pub fn validate_recursion<'q>( validator.validate(&analysis.sccs); } -// ----------------------------------------------------------------------------- -// Recursion Validator -// ----------------------------------------------------------------------------- - -struct RecursionValidator<'a, 'q, 'd> { +struct RecursionValidator<'a, 'd> { ast_map: &'a IndexMap, - symbol_table: &'a SymbolTable<'q>, + symbol_table: &'a SymbolTable, diag: &'d mut Diagnostics, } -impl<'a, 'q, 'd> RecursionValidator<'a, 'q, 'd> { - fn validate(&mut self, sccs: &[Vec<&'q str>]) { +impl<'a, 'd> RecursionValidator<'a, 'd> { + fn validate(&mut self, sccs: &[Vec]) { for scc in sccs { self.validate_scc(scc); } } - fn validate_scc(&mut self, scc: &[&'q str]) { + fn validate_scc(&mut self, scc: &[String]) { // Filter out trivial non-recursive components. // A component is recursive if it has >1 node, or 1 node that references itself. if scc.len() == 1 { - let name = scc[0]; - let is_self_recursive = self - .symbol_table - .get(name) - .map(|(_, body)| collect_refs(body, self.symbol_table).contains(name)) - .unwrap_or(false); - - if !is_self_recursive { + let name = &scc[0]; + let Some(body) = self.symbol_table.get(name) else { + return; + }; + if !collect_refs(body, self.symbol_table).contains(name.as_str()) { return; } } - let scc_set: IndexSet<&'q str> = scc.iter().copied().collect(); + let scc_set: IndexSet<&str> = scc.iter().map(String::as_str).collect(); // 1. Check for infinite tree structure (Escape Analysis) // A valid recursive definition must have a non-recursive path. // If NO definition in the SCC has an escape path, the whole group is invalid. - let has_escape = scc.iter().any(|name| { - self.symbol_table - .get(*name) - .map(|(_, body)| expr_has_escape(body, &scc_set)) - .unwrap_or(true) - }); + let has_escape = scc + .iter() + .filter_map(|name| self.symbol_table.get(name)) + .any(|body| expr_has_escape(body, &scc_set)); if !has_escape { // Find a cycle to report. Any cycle within the SCC is an infinite recursion loop @@ -124,15 +115,15 @@ impl<'a, 'q, 'd> RecursionValidator<'a, 'q, 'd> { /// Finds a cycle within the given set of nodes (SCC). /// `get_edge_location` returns the location of a reference from `expr` to `target`. - fn find_cycle( + fn find_cycle<'s>( &self, - nodes: &[&'q str], - domain: &IndexSet<&'q str>, + nodes: &'s [String], + domain: &IndexSet<&'s str>, get_edge_location: impl Fn(&Self, SourceId, &Expr, &str) -> Option, - ) -> Option> { + ) -> Option> { let mut adj = IndexMap::new(); for name in nodes { - if let Some(&(source_id, ref body)) = self.symbol_table.get(*name) { + if let Some((source_id, body)) = self.symbol_table.get_full(name) { let neighbors = domain .iter() .filter_map(|target| { @@ -140,20 +131,21 @@ impl<'a, 'q, 'd> RecursionValidator<'a, 'q, 'd> { .map(|range| (*target, source_id, range)) }) .collect::>(); - adj.insert(*name, neighbors); + adj.insert(name.as_str(), neighbors); } } - CycleFinder::find(nodes, &adj) + let node_strs: Vec<&str> = nodes.iter().map(String::as_str).collect(); + CycleFinder::find(&node_strs, &adj) } fn format_chain( &self, - chain: Vec<(SourceId, TextRange, &'q str)>, + raw_chain: Vec<(SourceId, TextRange, &str)>, is_unguarded: bool, ) -> Vec<(SourceId, TextRange, String)> { - if chain.len() == 1 { - let (source_id, range, target) = &chain[0]; + if raw_chain.len() == 1 { + let (source_id, range, target) = &raw_chain[0]; let msg = if is_unguarded { "references itself".to_string() } else { @@ -162,8 +154,8 @@ impl<'a, 'q, 'd> RecursionValidator<'a, 'q, 'd> { return vec![(*source_id, *range, msg)]; } - let len = chain.len(); - chain + let len = raw_chain.len(); + raw_chain .into_iter() .enumerate() .map(|(i, (source_id, range, target))| { @@ -180,7 +172,7 @@ impl<'a, 'q, 'd> RecursionValidator<'a, 'q, 'd> { fn report_cycle( &mut self, kind: DiagnosticKind, - scc: &[&'q str], + scc: &[String], chain: Vec<(SourceId, TextRange, String)>, ) { let (primary_source, primary_loc) = chain @@ -209,27 +201,21 @@ impl<'a, 'q, 'd> RecursionValidator<'a, 'q, 'd> { fn find_def_info_containing( &self, - scc: &[&'q str], + scc: &[String], range: TextRange, ) -> Option<(SourceId, String, TextRange)> { - scc.iter() - .find(|name| { - self.symbol_table - .get(*name) - .map(|(_, body)| body.text_range().contains_range(range)) - .unwrap_or(false) - }) - .and_then(|name| { - self.find_def_by_name(name).and_then(|(source_id, def)| { - def.name().map(|n| { - ( - source_id, - format!("{} is defined here", name), - n.text_range(), - ) - }) - }) - }) + let name = scc.iter().find(|name| { + self.symbol_table + .get(name.as_str()) + .is_some_and(|body| body.text_range().contains_range(range)) + })?; + let (source_id, def) = self.find_def_by_name(name)?; + let n = def.name()?; + Some(( + source_id, + format!("{} is defined here", name), + n.text_range(), + )) } fn find_def_by_name(&self, name: &str) -> Option<(SourceId, Def)> { @@ -241,22 +227,18 @@ impl<'a, 'q, 'd> RecursionValidator<'a, 'q, 'd> { } } -// ----------------------------------------------------------------------------- -// SCC Finder (Tarjan's Algorithm) -// ----------------------------------------------------------------------------- - -struct SccFinder<'a, 'q> { - symbol_table: &'a SymbolTable<'q>, +struct SccFinder<'a> { + symbol_table: &'a SymbolTable, index: usize, - stack: Vec<&'q str>, - on_stack: IndexSet<&'q str>, - indices: IndexMap<&'q str, usize>, - lowlinks: IndexMap<&'q str, usize>, - sccs: Vec>, + stack: Vec<&'a str>, + on_stack: IndexSet<&'a str>, + indices: IndexMap<&'a str, usize>, + lowlinks: IndexMap<&'a str, usize>, + sccs: Vec>, } -impl<'a, 'q> SccFinder<'a, 'q> { - fn find(symbol_table: &'a SymbolTable<'q>) -> Vec> { +impl<'a> SccFinder<'a> { + fn find(symbol_table: &'a SymbolTable) -> Vec> { let mut finder = Self { symbol_table, index: 0, @@ -267,27 +249,29 @@ impl<'a, 'q> SccFinder<'a, 'q> { sccs: Vec::new(), }; - for &name in symbol_table.keys() { - if !finder.indices.contains_key(name) { + for name in symbol_table.keys() { + if !finder.indices.contains_key(name as &str) { finder.strongconnect(name); } } - finder.sccs + finder + .sccs + .into_iter() + .map(|scc| scc.into_iter().map(String::from).collect()) + .collect() } - fn strongconnect(&mut self, name: &'q str) { + fn strongconnect(&mut self, name: &'a str) { self.indices.insert(name, self.index); self.lowlinks.insert(name, self.index); self.index += 1; self.stack.push(name); self.on_stack.insert(name); - if let Some((_, body)) = self.symbol_table.get(name) { + if let Some(body) = self.symbol_table.get(name) { let refs = collect_refs(body, self.symbol_table); for ref_name in refs { - // We've already resolved to canonical &'q str in collect_refs - // so we can use it directly. if !self.indices.contains_key(ref_name) { self.strongconnect(ref_name); let ref_lowlink = self.lowlinks[ref_name]; @@ -305,9 +289,10 @@ impl<'a, 'q> SccFinder<'a, 'q> { let mut scc = Vec::new(); loop { let w = self.stack.pop().unwrap(); - self.on_stack.swap_remove(w); + self.on_stack.swap_remove(&w); + let done = w == name; scc.push(w); - if w == name { + if done { break; } } @@ -316,10 +301,6 @@ impl<'a, 'q> SccFinder<'a, 'q> { } } -// ----------------------------------------------------------------------------- -// Cycle Finder -// ----------------------------------------------------------------------------- - struct CycleFinder<'a, 'q> { adj: &'a IndexMap<&'q str, Vec<(&'q str, SourceId, TextRange)>>, visited: IndexSet<&'q str>, @@ -389,35 +370,38 @@ impl<'a, 'q> CycleFinder<'a, 'q> { } } -// ----------------------------------------------------------------------------- -// Helper Visitors -// ----------------------------------------------------------------------------- - -fn expr_has_escape(expr: &Expr, scc: &IndexSet<&str>) -> bool { +fn expr_has_escape(expr: &Expr, scc_names: &IndexSet<&str>) -> bool { match expr { Expr::Ref(r) => { let Some(name_token) = r.name() else { return true; }; - !scc.contains(name_token.text()) + !scc_names.contains(name_token.text()) } Expr::NamedNode(node) => { let children: Vec<_> = node.children().collect(); - children.is_empty() || children.iter().all(|c| expr_has_escape(c, scc)) + children.is_empty() || children.iter().all(|c| expr_has_escape(c, scc_names)) } - Expr::AltExpr(_) => expr.children().iter().any(|c| expr_has_escape(c, scc)), - Expr::SeqExpr(_) => expr.children().iter().all(|c| expr_has_escape(c, scc)), + Expr::AltExpr(_) => expr + .children() + .iter() + .any(|c| expr_has_escape(c, scc_names)), + Expr::SeqExpr(_) => expr + .children() + .iter() + .all(|c| expr_has_escape(c, scc_names)), Expr::QuantifiedExpr(q) => { if q.is_optional() { return true; } q.inner() - .map(|inner| expr_has_escape(&inner, scc)) + .map(|inner| expr_has_escape(&inner, scc_names)) .unwrap_or(true) } - Expr::CapturedExpr(_) | Expr::FieldExpr(_) => { - expr.children().iter().all(|c| expr_has_escape(c, scc)) - } + Expr::CapturedExpr(_) | Expr::FieldExpr(_) => expr + .children() + .iter() + .all(|c| expr_has_escape(c, scc_names)), Expr::AnonymousNode(_) => true, } } @@ -440,29 +424,18 @@ fn expr_guarantees_consumption(expr: &Expr) -> bool { } } -struct RefCollector<'a, 'q> { - symbol_table: &'a SymbolTable<'q>, - refs: &'a mut IndexSet<&'q str>, -} - -impl<'a, 'q> Visitor for RefCollector<'a, 'q> { - fn visit_ref(&mut self, r: &Ref) { - if let Some(name) = r.name() { - // We immediately resolve to canonical &'q str keys to avoid allocations - if let Some((&k, _)) = self.symbol_table.get_key_value(name.text()) { - self.refs.insert(k); - } - } - } -} - -fn collect_refs<'q>(expr: &Expr, symbol_table: &SymbolTable<'q>) -> IndexSet<&'q str> { +fn collect_refs<'a>(expr: &Expr, symbol_table: &'a SymbolTable) -> IndexSet<&'a str> { let mut refs = IndexSet::new(); - let mut visitor = RefCollector { - symbol_table, - refs: &mut refs, - }; - visitor.visit_expr(expr); + for descendant in expr.as_cst().descendants() { + let Some(r) = Ref::cast(descendant) else { + continue; + }; + let Some(name_tok) = r.name() else { continue }; + let Some(key) = symbol_table.keys().find(|&k| k == name_tok.text()) else { + continue; + }; + refs.insert(key); + } refs } diff --git a/crates/plotnik-lib/src/query/expr_arity.rs b/crates/plotnik-lib/src/query/expr_arity.rs index d6452f38..cc6a4034 100644 --- a/crates/plotnik-lib/src/query/expr_arity.rs +++ b/crates/plotnik-lib/src/query/expr_arity.rs @@ -92,7 +92,7 @@ pub fn resolve_arity(node: &SyntaxNode, table: &ExprArityTable) -> Option { - symbol_table: &'a SymbolTable<'a>, + symbol_table: &'a SymbolTable, arity_table: ExprArityTable, diag: &'d mut Diagnostics, source_id: SourceId, @@ -158,7 +158,7 @@ impl ArityContext<'_, '_> { self.symbol_table .get(name) - .map(|(_, body)| self.compute_arity(body)) + .map(|body| self.compute_arity(body)) .unwrap_or(ExprArity::Invalid) } @@ -191,9 +191,9 @@ impl ArityContext<'_, '_> { // If value is a reference, add related info pointing to definition if let Expr::Ref(r) = &value && let Some(name_tok) = r.name() - && let Some((def_source, def_body)) = self.symbol_table.get(name_tok.text()) + && let Some((def_source, def_body)) = self.symbol_table.get_full(name_tok.text()) { - builder = builder.related_to(*def_source, def_body.text_range(), "defined here"); + builder = builder.related_to(def_source, def_body.text_range(), "defined here"); } builder.emit(); diff --git a/crates/plotnik-lib/src/query/link.rs b/crates/plotnik-lib/src/query/link.rs index d82e5e4a..cdd8358a 100644 --- a/crates/plotnik-lib/src/query/link.rs +++ b/crates/plotnik-lib/src/query/link.rs @@ -19,7 +19,7 @@ use crate::parser::token_src; use super::query::AstMap; use super::source_map::{SourceId, SourceMap}; -use super::symbol_table::SymbolTableOwned; +use super::symbol_table::SymbolTable; use super::utils::find_similar; use super::visitor::{Visitor, walk}; @@ -31,7 +31,7 @@ pub fn link<'q>( ast_map: &AstMap, source_map: &'q SourceMap, lang: &Lang, - symbol_table: &SymbolTableOwned, + symbol_table: &SymbolTable, node_type_ids: &mut HashMap<&'q str, Option>, node_field_ids: &mut HashMap<&'q str, Option>, diagnostics: &mut Diagnostics, @@ -54,7 +54,7 @@ struct Linker<'a, 'q> { source_map: &'q SourceMap, source_id: SourceId, lang: &'a Lang, - symbol_table: &'a SymbolTableOwned, + symbol_table: &'a SymbolTable, node_type_ids: &'a mut HashMap<&'q str, Option>, node_field_ids: &'a mut HashMap<&'q str, Option>, diagnostics: &'a mut Diagnostics, @@ -167,36 +167,21 @@ impl<'a, 'q> Linker<'a, 'q> { fn validate_expr_structure( &mut self, expr: &Expr, - ctx: Option>, + ctx: Option, visited: &mut IndexSet, ) { match expr { Expr::NamedNode(node) => { - // Validate this node against the context (if any) - if let Some(ref ctx) = ctx { - self.validate_terminal_type(expr, ctx, visited); - } - - // Set up context for children let child_ctx = self.make_node_context(node); for child in node.children() { - match &child { - Expr::FieldExpr(f) => { - // Fields get special handling - self.validate_field_expr(f, child_ctx.as_ref(), visited); - } - _ => { - // Non-field children: validate as non-field children - if let Some(ref ctx) = child_ctx { - self.validate_non_field_children(&child, ctx, visited); - } - self.validate_expr_structure(&child, child_ctx, visited); - } + if let Expr::FieldExpr(f) = &child { + self.validate_field_expr(f, child_ctx.as_ref(), visited); + } else { + self.validate_expr_structure(&child, child_ctx, visited); } } - // Handle negated fields if let Some(ctx) = child_ctx { for child in node.as_cst().children() { if let Some(neg) = ast::NegatedField::cast(child) { @@ -205,12 +190,7 @@ impl<'a, 'q> Linker<'a, 'q> { } } } - Expr::AnonymousNode(_) => { - // Validate this anonymous node against the context (if any) - if let Some(ref ctx) = ctx { - self.validate_terminal_type(expr, ctx, visited); - } - } + Expr::AnonymousNode(_) => {} Expr::FieldExpr(f) => { // Should be handled by parent NamedNode, but handle gracefully self.validate_field_expr(f, ctx.as_ref(), visited); @@ -240,7 +220,7 @@ impl<'a, 'q> Linker<'a, 'q> { if !visited.insert(name.to_string()) { return; } - let Some((_, body)) = self.symbol_table.get(name).cloned() else { + let Some(body) = self.symbol_table.get(name).cloned() else { visited.swap_remove(name); return; }; @@ -251,7 +231,7 @@ impl<'a, 'q> Linker<'a, 'q> { } /// Create validation context for a named node's children. - fn make_node_context(&self, node: &NamedNode) -> Option> { + fn make_node_context(&self, node: &NamedNode) -> Option { if node.is_any() { return None; } @@ -264,78 +244,48 @@ impl<'a, 'q> Linker<'a, 'q> { } let type_name = type_token.text(); let parent_id = self.node_type_ids.get(type_name).copied().flatten()?; - let parent_name = self.lang.node_type_name(parent_id)?; + // Verify the node type exists in the grammar + self.lang.node_type_name(parent_id)?; Some(ValidationContext { parent_id, - parent_name, parent_range: type_token.text_range(), }) } - /// Validate a field expression. fn validate_field_expr( &mut self, field: &ast::FieldExpr, - ctx: Option<&ValidationContext<'a>>, + ctx: Option<&ValidationContext>, visited: &mut IndexSet, ) { let Some(name_token) = field.name() else { return; }; - let field_name = name_token.text(); - - let Some(field_id) = self.node_field_ids.get(field_name).copied().flatten() else { - return; - }; - - let Some(ctx) = ctx else { + let Some(field_id) = self + .node_field_ids + .get(name_token.text()) + .copied() + .flatten() + else { return; }; + let Some(ctx) = ctx else { return }; - // Check field exists on parent if !self.lang.has_field(ctx.parent_id, field_id) { self.emit_field_not_on_node( name_token.text_range(), - field_name, + name_token.text(), ctx.parent_id, ctx.parent_range, ); return; } - let Some(value) = field.value() else { - return; - }; - - // Create context for validating the value - let field_ctx = ValidationContext { - parent_id: ctx.parent_id, - parent_name: ctx.parent_name, - parent_range: ctx.parent_range, - }; - - // Validate field value - this will traverse through alt/seq/quantifier/capture - // and validate each terminal type against the field requirements - self.validate_expr_structure(&value, Some(field_ctx), visited); - } - - fn validate_non_field_children( - &mut self, - _expr: &Expr, - _ctx: &ValidationContext<'a>, - _visited: &mut IndexSet, - ) { - } - - fn validate_terminal_type( - &mut self, - _expr: &Expr, - _ctx: &ValidationContext<'a>, - _visited: &mut IndexSet, - ) { + let Some(value) = field.value() else { return }; + self.validate_expr_structure(&value, Some(*ctx), visited); } - fn validate_negated_field(&mut self, neg: &ast::NegatedField, ctx: &ValidationContext<'a>) { + fn validate_negated_field(&mut self, neg: &ast::NegatedField, ctx: &ValidationContext) { let Some(name_token) = neg.name() else { return; }; @@ -419,11 +369,9 @@ fn format_list(items: &[&str], max_items: usize) -> String { /// Context for validating child types. #[derive(Clone, Copy)] -struct ValidationContext<'a> { +struct ValidationContext { /// The parent node type being validated against. parent_id: NodeTypeId, - /// The parent node's name for error messages. - parent_name: &'a str, /// The parent node type token range for related_to. parent_range: TextRange, } diff --git a/crates/plotnik-lib/src/query/mod.rs b/crates/plotnik-lib/src/query/mod.rs index 5de9913c..78d4944e 100644 --- a/crates/plotnik-lib/src/query/mod.rs +++ b/crates/plotnik-lib/src/query/mod.rs @@ -6,6 +6,7 @@ mod utils; pub use printer::QueryPrinter; pub use query::{Query, QueryBuilder}; pub use source_map::{SourceId, SourceMap}; +pub use symbol_table::SymbolTable; pub mod alt_kinds; mod dependencies; diff --git a/crates/plotnik-lib/src/query/printer.rs b/crates/plotnik-lib/src/query/printer.rs index d6ef4e0e..333abc5f 100644 --- a/crates/plotnik-lib/src/query/printer.rs +++ b/crates/plotnik-lib/src/query/printer.rs @@ -113,7 +113,7 @@ impl<'q> QueryPrinter<'q> { return Ok(()); } - let defined: IndexSet<&str> = symbols.keys().map(String::as_str).collect(); + let defined: IndexSet<&str> = symbols.keys().collect(); // Collect body nodes from all files let mut body_nodes: HashMap = HashMap::new(); @@ -161,7 +161,7 @@ impl<'q> QueryPrinter<'q> { writeln!(w, "{}{}{}", prefix, name, card)?; visited.insert(name.to_string()); - if let Some((_, body)) = self.query.symbol_table.get(name) { + if let Some(body) = self.query.symbol_table.get(name) { let refs_set = collect_refs(body); let mut refs: Vec<_> = refs_set.iter().map(|s| s.as_str()).collect(); refs.sort(); diff --git a/crates/plotnik-lib/src/query/query.rs b/crates/plotnik-lib/src/query/query.rs index 780b6920..30da1c61 100644 --- a/crates/plotnik-lib/src/query/query.rs +++ b/crates/plotnik-lib/src/query/query.rs @@ -14,7 +14,7 @@ use crate::query::dependencies; use crate::query::expr_arity::{ExprArity, ExprArityTable, infer_arities, resolve_arity}; use crate::query::link; use crate::query::source_map::{SourceId, SourceMap}; -use crate::query::symbol_table::{SymbolTableOwned, resolve_names}; +use crate::query::symbol_table::{SymbolTable, resolve_names}; const DEFAULT_QUERY_PARSE_FUEL: u32 = 1_000_000; const DEFAULT_QUERY_PARSE_MAX_DEPTH: u32 = 4096; @@ -116,11 +116,10 @@ impl QueryParsed { ); let arity_table = infer_arities(&self.ast_map, &symbol_table, &mut self.diag); - let symbol_table_owned = crate::query::symbol_table::to_owned(symbol_table); QueryAnalyzed { query_parsed: self, - symbol_table: symbol_table_owned, + symbol_table, arity_table, } } @@ -142,7 +141,7 @@ pub type Query = QueryAnalyzed; pub struct QueryAnalyzed { query_parsed: QueryParsed, - pub symbol_table: SymbolTableOwned, + pub symbol_table: SymbolTable, arity_table: ExprArityTable, } diff --git a/crates/plotnik-lib/src/query/symbol_table.rs b/crates/plotnik-lib/src/query/symbol_table.rs index b945f71a..617b3c0a 100644 --- a/crates/plotnik-lib/src/query/symbol_table.rs +++ b/crates/plotnik-lib/src/query/symbol_table.rs @@ -17,19 +17,98 @@ use super::visitor::Visitor; /// Code generators can emit whatever name they want for this. pub const UNNAMED_DEF: &str = "_"; -pub type SymbolTable<'src> = IndexMap<&'src str, (SourceId, ast::Expr)>; -pub type SymbolTableOwned = IndexMap; +/// Registry of named definitions in a query. +/// +/// Stores the mapping from definition names to their AST expressions, +/// along with source file information for diagnostics. +#[derive(Debug, Clone, Default)] +pub struct SymbolTable { + /// Maps symbol name to its AST expression. + table: IndexMap, + /// Maps symbol name to the source file where it's defined. + files: IndexMap, +} + +impl SymbolTable { + pub fn new() -> Self { + Self::default() + } + + /// Insert a symbol definition. + /// + /// Returns `true` if the symbol was newly inserted, `false` if it already existed + /// (in which case the old value is replaced). + pub fn insert(&mut self, name: &str, source_id: SourceId, expr: ast::Expr) -> bool { + let is_new = !self.table.contains_key(name); + self.table.insert(name.to_owned(), expr); + self.files.insert(name.to_owned(), source_id); + is_new + } + + /// Remove a symbol definition. + pub fn remove(&mut self, name: &str) -> Option<(SourceId, ast::Expr)> { + let expr = self.table.shift_remove(name)?; + let source_id = self.files.shift_remove(name)?; + Some((source_id, expr)) + } -pub fn to_owned(table: SymbolTable<'_>) -> SymbolTableOwned { - table.into_iter().map(|(k, v)| (k.to_owned(), v)).collect() + /// Check if a symbol is defined. + pub fn contains(&self, name: &str) -> bool { + self.table.contains_key(name) + } + + /// Get the expression for a symbol. + pub fn get(&self, name: &str) -> Option<&ast::Expr> { + self.table.get(name) + } + + /// Get the source file where a symbol is defined. + pub fn source_id(&self, name: &str) -> Option { + self.files.get(name).copied() + } + + /// Get both the source ID and expression for a symbol. + pub fn get_full(&self, name: &str) -> Option<(SourceId, &ast::Expr)> { + let expr = self.table.get(name)?; + let source_id = self.files.get(name).copied()?; + Some((source_id, expr)) + } + + /// Number of symbols in the symbol table. + pub fn len(&self) -> usize { + self.table.len() + } + + /// Check if the symbol table is empty. + pub fn is_empty(&self) -> bool { + self.table.is_empty() + } + + /// Iterate over symbol names in insertion order. + pub fn keys(&self) -> impl Iterator { + self.table.keys().map(String::as_str) + } + + /// Iterate over (name, expr) pairs in insertion order. + pub fn iter(&self) -> impl Iterator { + self.table.iter().map(|(k, v)| (k.as_str(), v)) + } + + /// Iterate over (name, source_id, expr) tuples in insertion order. + pub fn iter_full(&self) -> impl Iterator { + self.table.iter().map(|(k, v)| { + let source_id = self.files[k]; + (k.as_str(), source_id, v) + }) + } } -pub fn resolve_names<'q>( - source_map: &'q SourceMap, +pub fn resolve_names( + source_map: &SourceMap, ast_map: &IndexMap, diag: &mut Diagnostics, -) -> SymbolTable<'q> { - let mut symbol_table = SymbolTable::default(); +) -> SymbolTable { + let mut symbol_table = SymbolTable::new(); // Pass 1: collect definitions from all sources for (&source_id, ast) in ast_map { @@ -45,9 +124,7 @@ pub fn resolve_names<'q>( // Pass 2: validate references from all sources for (&source_id, ast) in ast_map { - let src = source_map.content(source_id); let mut validator = ReferenceValidator { - src, source_id, diag, symbol_table: &symbol_table, @@ -62,7 +139,7 @@ struct ReferenceResolver<'q, 'd, 't> { src: &'q str, source_id: SourceId, diag: &'d mut Diagnostics, - symbol_table: &'t mut SymbolTable<'q>, + symbol_table: &'t mut SymbolTable, } impl Visitor for ReferenceResolver<'_, '_, '_> { @@ -72,7 +149,7 @@ impl Visitor for ReferenceResolver<'_, '_, '_> { if let Some(token) = def.name() { // Named definition: `Name = ...` let name = token_src(&token, self.src); - if self.symbol_table.contains_key(name) { + if self.symbol_table.contains(name) { self.diag .report( self.source_id, @@ -82,34 +159,31 @@ impl Visitor for ReferenceResolver<'_, '_, '_> { .message(name) .emit(); } else { - self.symbol_table.insert(name, (self.source_id, body)); + self.symbol_table.insert(name, self.source_id, body); } } else { // Unnamed definition: `...` (root expression) // Parser already validates multiple unnamed defs; we keep the last one. - if self.symbol_table.contains_key(UNNAMED_DEF) { - self.symbol_table.shift_remove(UNNAMED_DEF); + if self.symbol_table.contains(UNNAMED_DEF) { + self.symbol_table.remove(UNNAMED_DEF); } - self.symbol_table - .insert(UNNAMED_DEF, (self.source_id, body)); + self.symbol_table.insert(UNNAMED_DEF, self.source_id, body); } } } -struct ReferenceValidator<'q, 'd, 't> { - #[allow(dead_code)] - src: &'q str, +struct ReferenceValidator<'d, 't> { source_id: SourceId, diag: &'d mut Diagnostics, - symbol_table: &'t SymbolTable<'q>, + symbol_table: &'t SymbolTable, } -impl Visitor for ReferenceValidator<'_, '_, '_> { +impl Visitor for ReferenceValidator<'_, '_> { fn visit_ref(&mut self, r: &ast::Ref) { let Some(name_token) = r.name() else { return }; let name = name_token.text(); - if self.symbol_table.contains_key(name) { + if self.symbol_table.contains(name) { return; }