diff --git a/crates/plotnik-lib/src/analyze/dependencies.rs b/crates/plotnik-lib/src/analyze/dependencies.rs index 2437bb67..22a5146d 100644 --- a/crates/plotnik-lib/src/analyze/dependencies.rs +++ b/crates/plotnik-lib/src/analyze/dependencies.rs @@ -1,26 +1,19 @@ -//! Dependency analysis and recursion validation. +//! Dependency analysis for definitions. //! -//! This module computes the dependency graph of definitions, identifies -//! Strongly Connected Components (SCCs), and validates that recursive -//! definitions are well-formed (guarded and escapable). -//! -//! The computed SCCs are exposed in reverse topological order (leaves first), -//! which is useful for passes that need to process dependencies before -//! dependents (like type inference). +//! Computes the dependency graph of definitions and identifies Strongly Connected +//! Components (SCCs). The computed SCCs are exposed in reverse topological order +//! (leaves first), which is useful for passes that need to process dependencies +//! before dependents (like type inference). use std::collections::{HashMap, HashSet}; use indexmap::{IndexMap, IndexSet}; use plotnik_core::{Interner, Symbol}; -use rowan::TextRange; use super::symbol_table::SymbolTable; use super::type_check::DefId; -use super::visitor::{Visitor, walk_expr}; -use crate::Diagnostics; -use crate::diagnostics::DiagnosticKind; -use crate::parser::{AnonymousNode, Def, Expr, NamedNode, Ref, Root, SeqExpr}; -use crate::query::source_map::SourceId; +use crate::parser::Ref; +use crate::parser::ast::Expr; /// Result of dependency analysis. #[derive(Clone, Debug, Default)] @@ -140,195 +133,6 @@ pub fn analyze_dependencies( } } -/// Validate recursion using the pre-computed dependency analysis. -pub fn validate_recursion( - analysis: &DependencyAnalysis, - ast_map: &IndexMap, - symbol_table: &SymbolTable, - diag: &mut Diagnostics, -) { - let mut validator = RecursionValidator { - ast_map, - symbol_table, - diag, - }; - validator.validate(&analysis.sccs); -} - -struct RecursionValidator<'a, 'd> { - ast_map: &'a IndexMap, - symbol_table: &'a SymbolTable, - diag: &'d mut Diagnostics, -} - -impl<'a, 'd> RecursionValidator<'a, 'd> { - fn validate(&mut self, sccs: &[Vec]) { - for scc in sccs { - self.validate_scc(scc); - } - } - - fn validate_scc(&mut self, scc: &[String]) { - // Filter out trivial non-recursive components. - // A component is recursive if it has >1 node, or 1 node that references itself. - if scc.len() == 1 { - let name = &scc[0]; - let body = self - .symbol_table - .get(name) - .expect("node in SCC must exist in symbol table"); - if !collect_refs(body, self.symbol_table).contains(name.as_str()) { - return; - } - } - - let scc_set: IndexSet<&str> = scc.iter().map(String::as_str).collect(); - - // 1. Check for infinite tree structure (Escape Analysis) - // A valid recursive definition must have a non-recursive path. - // If NO definition in the SCC has an escape path, the whole group is invalid. - let has_escape = scc - .iter() - .filter_map(|name| self.symbol_table.get(name)) - .any(|body| expr_has_escape(body, &scc_set)); - - if !has_escape { - // Find a cycle to report. Any cycle within the SCC is an infinite recursion loop - // because there are no escape paths. - if let Some(raw_chain) = self.find_cycle(scc, &scc_set, |_, _, expr, target| { - find_ref_range(expr, target) - }) { - let chain = self.format_chain(raw_chain, false); - self.report_cycle(DiagnosticKind::RecursionNoEscape, scc, chain); - } - return; - } - - // 2. Check for infinite loops (Guarded Recursion Analysis) - // Even if there is an escape, every recursive cycle must consume input (be guarded). - // We look for a cycle composed entirely of unguarded references. - if let Some(raw_chain) = self.find_cycle(scc, &scc_set, |_, _, expr, target| { - find_unguarded_ref_range(expr, target) - }) { - let chain = self.format_chain(raw_chain, true); - self.report_cycle(DiagnosticKind::DirectRecursion, scc, chain); - } - } - - /// Finds a cycle within the given set of nodes (SCC). - /// `get_edge_location` returns the location of a reference from `expr` to `target`. - fn find_cycle<'s>( - &self, - nodes: &'s [String], - domain: &IndexSet<&'s str>, - get_edge_location: impl Fn(&Self, SourceId, &Expr, &str) -> Option, - ) -> Option> { - let mut adj = IndexMap::new(); - for name in nodes { - if let Some((source_id, body)) = self.symbol_table.get_full(name) { - let neighbors = domain - .iter() - .filter_map(|target| { - get_edge_location(self, source_id, body, target) - .map(|range| (*target, source_id, range)) - }) - .collect::>(); - adj.insert(name.as_str(), neighbors); - } - } - - let node_strs: Vec<&str> = nodes.iter().map(String::as_str).collect(); - CycleFinder::find(&node_strs, &adj) - } - - fn format_chain( - &self, - raw_chain: Vec<(SourceId, TextRange, &str)>, - is_unguarded: bool, - ) -> Vec<(SourceId, TextRange, String)> { - if raw_chain.len() == 1 { - let (source_id, range, target) = &raw_chain[0]; - let msg = if is_unguarded { - "references itself".to_string() - } else { - format!("{} references itself", target) - }; - return vec![(*source_id, *range, msg)]; - } - - let len = raw_chain.len(); - raw_chain - .into_iter() - .enumerate() - .map(|(i, (source_id, range, target))| { - let msg = if i == len - 1 { - format!("references {} (completing cycle)", target) - } else { - format!("references {}", target) - }; - (source_id, range, msg) - }) - .collect() - } - - fn report_cycle( - &mut self, - kind: DiagnosticKind, - scc: &[String], - chain: Vec<(SourceId, TextRange, String)>, - ) { - let (primary_source, primary_loc) = chain - .first() - .map(|(s, r, _)| (*s, *r)) - .unwrap_or_else(|| (SourceId::default(), TextRange::empty(0.into()))); - - let related_def = if scc.len() > 1 { - self.find_def_info_containing(scc, primary_loc) - } else { - None - }; - - let mut builder = self.diag.report(primary_source, kind, primary_loc); - - for (source_id, range, msg) in chain { - builder = builder.related_to(source_id, range, msg); - } - - if let Some((source_id, msg, range)) = related_def { - builder = builder.related_to(source_id, range, msg); - } - - builder.emit(); - } - - fn find_def_info_containing( - &self, - scc: &[String], - range: TextRange, - ) -> Option<(SourceId, String, TextRange)> { - let name = scc.iter().find(|name| { - self.symbol_table - .get(name.as_str()) - .is_some_and(|body| body.text_range().contains_range(range)) - })?; - let (source_id, def) = self.find_def_by_name(name)?; - let n = def.name()?; - Some(( - source_id, - format!("{} is defined here", name), - n.text_range(), - )) - } - - fn find_def_by_name(&self, name: &str) -> Option<(SourceId, Def)> { - self.ast_map.iter().find_map(|(source_id, ast)| { - ast.defs() - .find(|d| d.name().map(|n| n.text() == name).unwrap_or(false)) - .map(|def| (*source_id, def)) - }) - } -} - struct SccFinder<'a> { symbol_table: &'a SymbolTable, index: usize, @@ -403,130 +207,10 @@ impl<'a> SccFinder<'a> { } } -struct CycleFinder<'a, 'q> { - adj: &'a IndexMap<&'q str, Vec<(&'q str, SourceId, TextRange)>>, - visited: IndexSet<&'q str>, - on_path: IndexMap<&'q str, usize>, - path: Vec<&'q str>, - edges: Vec<(SourceId, TextRange)>, -} - -impl<'a, 'q> CycleFinder<'a, 'q> { - fn find( - nodes: &[&'q str], - adj: &'a IndexMap<&'q str, Vec<(&'q str, SourceId, TextRange)>>, - ) -> Option> { - let mut finder = Self { - adj, - visited: IndexSet::new(), - on_path: IndexMap::new(), - path: Vec::new(), - edges: Vec::new(), - }; - - for start in nodes { - if let Some(chain) = finder.dfs(start) { - return Some(chain); - } - } - None - } - - fn dfs(&mut self, current: &'q str) -> Option> { - if self.on_path.contains_key(current) { - return None; - } - - if self.visited.contains(current) { - return None; - } - - self.visited.insert(current); - self.on_path.insert(current, self.path.len()); - self.path.push(current); - - if let Some(neighbors) = self.adj.get(current) { - for (target, source_id, range) in neighbors { - if let Some(&start_index) = self.on_path.get(target) { - // Cycle detected! - let mut chain = Vec::new(); - for i in start_index..self.path.len() - 1 { - let (src, rng) = self.edges[i]; - chain.push((src, rng, self.path[i + 1])); - } - chain.push((*source_id, *range, *target)); - return Some(chain); - } - - self.edges.push((*source_id, *range)); - if let Some(chain) = self.dfs(target) { - return Some(chain); - } - self.edges.pop(); - } - } - - self.path.pop(); - self.on_path.swap_remove(current); - None - } -} - -fn expr_has_escape(expr: &Expr, scc_names: &IndexSet<&str>) -> bool { - match expr { - Expr::Ref(r) => { - let Some(name_token) = r.name() else { - return true; - }; - !scc_names.contains(name_token.text()) - } - Expr::NamedNode(node) => { - let children: Vec<_> = node.children().collect(); - children.is_empty() || children.iter().all(|c| expr_has_escape(c, scc_names)) - } - Expr::AltExpr(_) => expr - .children() - .iter() - .any(|c| expr_has_escape(c, scc_names)), - Expr::SeqExpr(_) => expr - .children() - .iter() - .all(|c| expr_has_escape(c, scc_names)), - Expr::QuantifiedExpr(q) => { - if q.is_optional() { - return true; - } - q.inner() - .map(|inner| expr_has_escape(&inner, scc_names)) - .unwrap_or(true) - } - Expr::CapturedExpr(_) | Expr::FieldExpr(_) => expr - .children() - .iter() - .all(|c| expr_has_escape(c, scc_names)), - Expr::AnonymousNode(_) => true, - } -} - -fn expr_guarantees_consumption(expr: &Expr) -> bool { - match expr { - Expr::NamedNode(_) | Expr::AnonymousNode(_) => true, - Expr::Ref(_) => false, - Expr::AltExpr(_) => expr.children().iter().all(expr_guarantees_consumption), - Expr::SeqExpr(_) => expr.children().iter().any(expr_guarantees_consumption), - Expr::QuantifiedExpr(q) => { - !q.is_optional() - && q.inner() - .map(|i| expr_guarantees_consumption(&i)) - .unwrap_or(false) - } - Expr::CapturedExpr(_) | Expr::FieldExpr(_) => { - expr.children().iter().all(expr_guarantees_consumption) - } - } -} - -fn collect_refs<'a>(expr: &Expr, symbol_table: &'a SymbolTable) -> IndexSet<&'a str> { +/// Collect references to definitions within the symbol table. +/// +/// Returns only refs that point to defined names (filters out node type references). +pub(super) fn collect_refs<'a>(expr: &Expr, symbol_table: &'a SymbolTable) -> IndexSet<&'a str> { let mut refs = IndexSet::new(); for descendant in expr.as_cst().descendants() { let Some(r) = Ref::cast(descendant) else { @@ -540,82 +224,3 @@ fn collect_refs<'a>(expr: &Expr, symbol_table: &'a SymbolTable) -> IndexSet<&'a } refs } - -/// Whether to search for any reference or only unguarded ones. -#[derive(Clone, Copy, PartialEq, Eq)] -enum RefSearchMode { - /// Find any reference to the target. - Any, - /// Find only unguarded references (not inside a NamedNode/AnonymousNode). - Unguarded, -} - -struct RefFinder<'a> { - target: &'a str, - found: Option, - mode: RefSearchMode, -} - -impl Visitor for RefFinder<'_> { - fn visit_expr(&mut self, expr: &Expr) { - if self.found.is_some() { - return; - } - walk_expr(self, expr); - } - - fn visit_named_node(&mut self, node: &NamedNode) { - if self.mode == RefSearchMode::Unguarded { - return; // Guarded: stop recursion - } - super::visitor::walk_named_node(self, node); - } - - fn visit_anonymous_node(&mut self, _node: &AnonymousNode) { - // AnonymousNode has no child expressions, so nothing to walk. - // In Unguarded mode this also acts as a guard (stops recursion). - } - - fn visit_ref(&mut self, r: &Ref) { - if self.found.is_some() { - return; - } - if let Some(name) = r.name() - && name.text() == self.target - { - self.found = Some(name.text_range()); - } - } - - fn visit_seq_expr(&mut self, seq: &SeqExpr) { - for child in seq.children() { - self.visit_expr(&child); - if self.found.is_some() { - return; - } - if self.mode == RefSearchMode::Unguarded && expr_guarantees_consumption(&child) { - return; - } - } - } -} - -fn find_ref_range(expr: &Expr, target: &str) -> Option { - let mut visitor = RefFinder { - target, - found: None, - mode: RefSearchMode::Any, - }; - visitor.visit_expr(expr); - visitor.found -} - -fn find_unguarded_ref_range(expr: &Expr, target: &str) -> Option { - let mut visitor = RefFinder { - target, - found: None, - mode: RefSearchMode::Unguarded, - }; - visitor.visit_expr(expr); - visitor.found -} diff --git a/crates/plotnik-lib/src/analyze/mod.rs b/crates/plotnik-lib/src/analyze/mod.rs index f487cede..93a9c445 100644 --- a/crates/plotnik-lib/src/analyze/mod.rs +++ b/crates/plotnik-lib/src/analyze/mod.rs @@ -10,6 +10,7 @@ pub mod dependencies; mod invariants; pub mod link; +mod recursion; pub mod refs; pub mod symbol_table; pub mod type_check; @@ -26,6 +27,7 @@ mod symbol_table_tests; pub use dependencies::DependencyAnalysis; pub use link::LinkOutput; +pub use recursion::validate_recursion; pub use symbol_table::{SymbolTable, UNNAMED_DEF}; pub use type_check::{TypeContext, infer_types, primary_def_name}; pub use validation::{validate_alt_kinds, validate_anchors}; diff --git a/crates/plotnik-lib/src/analyze/recursion.rs b/crates/plotnik-lib/src/analyze/recursion.rs new file mode 100644 index 00000000..e2783459 --- /dev/null +++ b/crates/plotnik-lib/src/analyze/recursion.rs @@ -0,0 +1,407 @@ +//! Recursion validation for definitions. +//! +//! Validates that recursive definitions are well-formed: +//! - Escapable: at least one non-recursive path exists +//! - Guarded: every recursive cycle consumes input + +use indexmap::{IndexMap, IndexSet}; +use rowan::TextRange; + +use super::dependencies::{DependencyAnalysis, collect_refs}; +use super::symbol_table::SymbolTable; +use super::visitor::{Visitor, walk_expr, walk_named_node}; +use crate::Diagnostics; +use crate::diagnostics::DiagnosticKind; +use crate::parser::{AnonymousNode, Def, Expr, NamedNode, Ref, Root, SeqExpr}; +use crate::query::source_map::SourceId; + +/// Validate recursion using the pre-computed dependency analysis. +pub fn validate_recursion( + analysis: &DependencyAnalysis, + ast_map: &IndexMap, + symbol_table: &SymbolTable, + diag: &mut Diagnostics, +) { + let mut validator = RecursionValidator { + ast_map, + symbol_table, + diag, + }; + validator.validate(&analysis.sccs); +} + +struct RecursionValidator<'a, 'd> { + ast_map: &'a IndexMap, + symbol_table: &'a SymbolTable, + diag: &'d mut Diagnostics, +} + +impl<'a, 'd> RecursionValidator<'a, 'd> { + fn validate(&mut self, sccs: &[Vec]) { + for scc in sccs { + self.validate_scc(scc); + } + } + + fn validate_scc(&mut self, scc: &[String]) { + // Filter out trivial non-recursive components. + // A component is recursive if it has >1 node, or 1 node that references itself. + if scc.len() == 1 { + let name = &scc[0]; + let body = self + .symbol_table + .get(name) + .expect("node in SCC must exist in symbol table"); + if !collect_refs(body, self.symbol_table).contains(name.as_str()) { + return; + } + } + + let scc_set: IndexSet<&str> = scc.iter().map(String::as_str).collect(); + + // 1. Check for infinite tree structure (Escape Analysis) + // A valid recursive definition must have a non-recursive path. + // If NO definition in the SCC has an escape path, the whole group is invalid. + let has_escape = scc + .iter() + .filter_map(|name| self.symbol_table.get(name)) + .any(|body| expr_has_escape(body, &scc_set)); + + if !has_escape { + // Find a cycle to report. Any cycle within the SCC is an infinite recursion loop + // because there are no escape paths. + if let Some(raw_chain) = self.find_cycle(scc, &scc_set, |_, _, expr, target| { + find_ref_range(expr, target) + }) { + let chain = self.format_chain(raw_chain, false); + self.report_cycle(DiagnosticKind::RecursionNoEscape, scc, chain); + } + return; + } + + // 2. Check for infinite loops (Guarded Recursion Analysis) + // Even if there is an escape, every recursive cycle must consume input (be guarded). + // We look for a cycle composed entirely of unguarded references. + if let Some(raw_chain) = self.find_cycle(scc, &scc_set, |_, _, expr, target| { + find_unguarded_ref_range(expr, target) + }) { + let chain = self.format_chain(raw_chain, true); + self.report_cycle(DiagnosticKind::DirectRecursion, scc, chain); + } + } + + /// Finds a cycle within the given set of nodes (SCC). + /// `get_edge_location` returns the location of a reference from `expr` to `target`. + fn find_cycle<'s>( + &self, + nodes: &'s [String], + domain: &IndexSet<&'s str>, + get_edge_location: impl Fn(&Self, SourceId, &Expr, &str) -> Option, + ) -> Option> { + let mut adj = IndexMap::new(); + for name in nodes { + if let Some((source_id, body)) = self.symbol_table.get_full(name) { + let neighbors = domain + .iter() + .filter_map(|target| { + get_edge_location(self, source_id, body, target) + .map(|range| (*target, source_id, range)) + }) + .collect::>(); + adj.insert(name.as_str(), neighbors); + } + } + + let node_strs: Vec<&str> = nodes.iter().map(String::as_str).collect(); + CycleFinder::find(&node_strs, &adj) + } + + fn format_chain( + &self, + raw_chain: Vec<(SourceId, TextRange, &str)>, + is_unguarded: bool, + ) -> Vec<(SourceId, TextRange, String)> { + if raw_chain.len() == 1 { + let (source_id, range, target) = &raw_chain[0]; + let msg = if is_unguarded { + "references itself".to_string() + } else { + format!("{} references itself", target) + }; + return vec![(*source_id, *range, msg)]; + } + + let len = raw_chain.len(); + raw_chain + .into_iter() + .enumerate() + .map(|(i, (source_id, range, target))| { + let msg = if i == len - 1 { + format!("references {} (completing cycle)", target) + } else { + format!("references {}", target) + }; + (source_id, range, msg) + }) + .collect() + } + + fn report_cycle( + &mut self, + kind: DiagnosticKind, + scc: &[String], + chain: Vec<(SourceId, TextRange, String)>, + ) { + let (primary_source, primary_loc) = chain + .first() + .map(|(s, r, _)| (*s, *r)) + .unwrap_or_else(|| (SourceId::default(), TextRange::empty(0.into()))); + + let related_def = if scc.len() > 1 { + self.find_def_info_containing(scc, primary_loc) + } else { + None + }; + + let mut builder = self.diag.report(primary_source, kind, primary_loc); + + for (source_id, range, msg) in chain { + builder = builder.related_to(source_id, range, msg); + } + + if let Some((source_id, msg, range)) = related_def { + builder = builder.related_to(source_id, range, msg); + } + + builder.emit(); + } + + fn find_def_info_containing( + &self, + scc: &[String], + range: TextRange, + ) -> Option<(SourceId, String, TextRange)> { + let name = scc.iter().find(|name| { + self.symbol_table + .get(name.as_str()) + .is_some_and(|body| body.text_range().contains_range(range)) + })?; + let (source_id, def) = self.find_def_by_name(name)?; + let n = def.name()?; + Some(( + source_id, + format!("{} is defined here", name), + n.text_range(), + )) + } + + fn find_def_by_name(&self, name: &str) -> Option<(SourceId, Def)> { + self.ast_map.iter().find_map(|(source_id, ast)| { + ast.defs() + .find(|d| d.name().map(|n| n.text() == name).unwrap_or(false)) + .map(|def| (*source_id, def)) + }) + } +} + +struct CycleFinder<'a, 'q> { + adj: &'a IndexMap<&'q str, Vec<(&'q str, SourceId, TextRange)>>, + visited: IndexSet<&'q str>, + on_path: IndexMap<&'q str, usize>, + path: Vec<&'q str>, + edges: Vec<(SourceId, TextRange)>, +} + +impl<'a, 'q> CycleFinder<'a, 'q> { + fn find( + nodes: &[&'q str], + adj: &'a IndexMap<&'q str, Vec<(&'q str, SourceId, TextRange)>>, + ) -> Option> { + let mut finder = Self { + adj, + visited: IndexSet::new(), + on_path: IndexMap::new(), + path: Vec::new(), + edges: Vec::new(), + }; + + for start in nodes { + if let Some(chain) = finder.dfs(start) { + return Some(chain); + } + } + None + } + + fn dfs(&mut self, current: &'q str) -> Option> { + if self.on_path.contains_key(current) { + return None; + } + + if self.visited.contains(current) { + return None; + } + + self.visited.insert(current); + self.on_path.insert(current, self.path.len()); + self.path.push(current); + + if let Some(neighbors) = self.adj.get(current) { + for (target, source_id, range) in neighbors { + if let Some(&start_index) = self.on_path.get(target) { + // Cycle detected! + let mut chain = Vec::new(); + for i in start_index..self.path.len() - 1 { + let (src, rng) = self.edges[i]; + chain.push((src, rng, self.path[i + 1])); + } + chain.push((*source_id, *range, *target)); + return Some(chain); + } + + self.edges.push((*source_id, *range)); + if let Some(chain) = self.dfs(target) { + return Some(chain); + } + self.edges.pop(); + } + } + + self.path.pop(); + self.on_path.swap_remove(current); + None + } +} + +fn expr_has_escape(expr: &Expr, scc_names: &IndexSet<&str>) -> bool { + match expr { + Expr::Ref(r) => { + let Some(name_token) = r.name() else { + return true; + }; + !scc_names.contains(name_token.text()) + } + Expr::NamedNode(node) => { + let children: Vec<_> = node.children().collect(); + children.is_empty() || children.iter().all(|c| expr_has_escape(c, scc_names)) + } + Expr::AltExpr(_) => expr + .children() + .iter() + .any(|c| expr_has_escape(c, scc_names)), + Expr::SeqExpr(_) => expr + .children() + .iter() + .all(|c| expr_has_escape(c, scc_names)), + Expr::QuantifiedExpr(q) => { + if q.is_optional() { + return true; + } + q.inner() + .map(|inner| expr_has_escape(&inner, scc_names)) + .unwrap_or(true) + } + Expr::CapturedExpr(_) | Expr::FieldExpr(_) => expr + .children() + .iter() + .all(|c| expr_has_escape(c, scc_names)), + Expr::AnonymousNode(_) => true, + } +} + +fn expr_guarantees_consumption(expr: &Expr) -> bool { + match expr { + Expr::NamedNode(_) | Expr::AnonymousNode(_) => true, + Expr::Ref(_) => false, + Expr::AltExpr(_) => expr.children().iter().all(expr_guarantees_consumption), + Expr::SeqExpr(_) => expr.children().iter().any(expr_guarantees_consumption), + Expr::QuantifiedExpr(q) => { + !q.is_optional() + && q.inner() + .map(|i| expr_guarantees_consumption(&i)) + .unwrap_or(false) + } + Expr::CapturedExpr(_) | Expr::FieldExpr(_) => { + expr.children().iter().all(expr_guarantees_consumption) + } + } +} + +/// Whether to search for any reference or only unguarded ones. +#[derive(Clone, Copy, PartialEq, Eq)] +enum RefSearchMode { + /// Find any reference to the target. + Any, + /// Find only unguarded references (not inside a NamedNode/AnonymousNode). + Unguarded, +} + +struct RefFinder<'a> { + target: &'a str, + found: Option, + mode: RefSearchMode, +} + +impl Visitor for RefFinder<'_> { + fn visit_expr(&mut self, expr: &Expr) { + if self.found.is_some() { + return; + } + walk_expr(self, expr); + } + + fn visit_named_node(&mut self, node: &NamedNode) { + if self.mode == RefSearchMode::Unguarded { + return; // Guarded: stop recursion + } + walk_named_node(self, node); + } + + fn visit_anonymous_node(&mut self, _node: &AnonymousNode) { + // AnonymousNode has no child expressions, so nothing to walk. + // In Unguarded mode this also acts as a guard (stops recursion). + } + + fn visit_ref(&mut self, r: &Ref) { + if self.found.is_some() { + return; + } + if let Some(name) = r.name() + && name.text() == self.target + { + self.found = Some(name.text_range()); + } + } + + fn visit_seq_expr(&mut self, seq: &SeqExpr) { + for child in seq.children() { + self.visit_expr(&child); + if self.found.is_some() { + return; + } + if self.mode == RefSearchMode::Unguarded && expr_guarantees_consumption(&child) { + return; + } + } + } +} + +fn find_ref_range(expr: &Expr, target: &str) -> Option { + let mut visitor = RefFinder { + target, + found: None, + mode: RefSearchMode::Any, + }; + visitor.visit_expr(expr); + visitor.found +} + +fn find_unguarded_ref_range(expr: &Expr, target: &str) -> Option { + let mut visitor = RefFinder { + target, + found: None, + mode: RefSearchMode::Unguarded, + }; + visitor.visit_expr(expr); + visitor.found +} diff --git a/crates/plotnik-lib/src/analyze/refs.rs b/crates/plotnik-lib/src/analyze/refs.rs index 89285516..43f84a2f 100644 --- a/crates/plotnik-lib/src/analyze/refs.rs +++ b/crates/plotnik-lib/src/analyze/refs.rs @@ -10,7 +10,7 @@ pub fn ref_nodes(expr: &Expr) -> impl Iterator + '_ { } /// Collect all reference names as owned strings. -pub fn collect_ref_names(expr: &Expr) -> IndexSet { +pub fn ref_names(expr: &Expr) -> IndexSet { ref_nodes(expr) .filter_map(|r| r.name()) .map(|tok| tok.text().to_string()) @@ -33,7 +33,7 @@ mod tests { fn collect_refs_from_simple_ref() { let q = Query::expect("Q = (Foo)"); let expr = q.symbol_table.get("Q").unwrap(); - let refs = collect_ref_names(expr); + let refs = ref_names(expr); assert_eq!(refs.len(), 1); assert!(refs.contains("Foo")); } @@ -42,7 +42,7 @@ mod tests { fn collect_refs_from_nested() { let q = Query::expect("Q = (x (Foo) (Bar))"); let expr = q.symbol_table.get("Q").unwrap(); - let refs = collect_ref_names(expr); + let refs = ref_names(expr); assert_eq!(refs.len(), 2); assert!(refs.contains("Foo")); assert!(refs.contains("Bar")); @@ -52,7 +52,7 @@ mod tests { fn collect_refs_deduplicates() { let q = Query::expect("Q = {(Foo) (Foo)}"); let expr = q.symbol_table.get("Q").unwrap(); - let refs = collect_ref_names(expr); + let refs = ref_names(expr); assert_eq!(refs.len(), 1); } diff --git a/crates/plotnik-lib/src/analyze/type_check/infer.rs b/crates/plotnik-lib/src/analyze/type_check/infer.rs index ab216f4d..65d9cd73 100644 --- a/crates/plotnik-lib/src/analyze/type_check/infer.rs +++ b/crates/plotnik-lib/src/analyze/type_check/infer.rs @@ -6,6 +6,7 @@ use std::collections::BTreeMap; use std::collections::btree_map::Entry; +use indexmap::IndexMap; use plotnik_core::Interner; use rowan::TextRange; @@ -17,6 +18,7 @@ use super::types::{ }; use super::unify::{UnifyError, unify_flows}; +use crate::analyze::dependencies::DependencyAnalysis; use crate::analyze::symbol_table::SymbolTable; use crate::analyze::visitor::{Visitor, walk_alt_expr, walk_def, walk_named_node, walk_seq_expr}; use crate::diagnostics::{DiagnosticKind, Diagnostics}; @@ -874,7 +876,7 @@ impl Visitor for InferenceVisitor<'_, '_> { } /// Run inference on all definitions in a root. -pub fn infer_root( +fn infer_root( ctx: &mut TypeContext, interner: &mut Interner, symbol_table: &SymbolTable, @@ -885,3 +887,115 @@ pub fn infer_root( let mut visitor = InferenceVisitor::new(ctx, interner, symbol_table, source_id, diag); visitor.visit(root); } + +/// Orchestrates type inference across all definitions in dependency order. +pub(super) struct InferencePass<'a> { + ctx: TypeContext, + interner: &'a mut Interner, + ast_map: &'a IndexMap, + symbol_table: &'a SymbolTable, + dependency_analysis: &'a DependencyAnalysis, + diag: &'a mut Diagnostics, +} + +impl<'a> InferencePass<'a> { + pub fn new( + interner: &'a mut Interner, + ast_map: &'a IndexMap, + symbol_table: &'a SymbolTable, + dependency_analysis: &'a DependencyAnalysis, + diag: &'a mut Diagnostics, + ) -> Self { + Self { + ctx: TypeContext::new(), + interner, + ast_map, + symbol_table, + dependency_analysis, + diag, + } + } + + pub fn run(mut self) -> TypeContext { + // Avoid re-registration of definitions + self.ctx.seed_defs( + self.dependency_analysis.def_names(), + self.dependency_analysis.name_to_def(), + ); + + self.mark_recursion(); + self.process_sccs(); + self.process_orphans(); + + self.ctx + } + + /// Identify and mark recursive definitions. + fn mark_recursion(&mut self) { + for scc in &self.dependency_analysis.sccs { + for def_name in scc { + if !self.dependency_analysis.is_recursive(def_name) { + continue; + } + let sym = self.interner.intern(def_name); + let Some(def_id) = self.ctx.get_def_id_sym(sym) else { + continue; + }; + self.ctx.mark_recursive(def_id); + } + } + } + + /// Process definitions in SCC order (leaves first). + fn process_sccs(&mut self) { + for scc in &self.dependency_analysis.sccs { + for def_name in scc { + if let Some(source_id) = self.symbol_table.source_id(def_name) { + self.infer_and_register(def_name, source_id); + } + } + } + } + + /// Handle any definitions not in an SCC (safety net). + fn process_orphans(&mut self) { + for (name, source_id, _body) in self.symbol_table.iter_full() { + // Skip if already processed + if self.ctx.get_def_type_by_name(self.interner, name).is_some() { + continue; + } + self.infer_and_register(name, source_id); + } + } + + fn infer_and_register(&mut self, def_name: &str, source_id: SourceId) { + let Some(root) = self.ast_map.get(&source_id) else { + return; + }; + + infer_root( + &mut self.ctx, + self.interner, + self.symbol_table, + source_id, + root, + self.diag, + ); + + // Register the definition's output type based on the inferred body flow + if let Some(body) = self.symbol_table.get(def_name) + && let Some(info) = self.ctx.get_term_info(body).cloned() + { + let type_id = self.flow_to_type_id(&info.flow); + self.ctx + .set_def_type_by_name(self.interner, def_name, type_id); + } + } + + fn flow_to_type_id(&mut self, flow: &TypeFlow) -> TypeId { + match flow { + TypeFlow::Void => TYPE_VOID, + TypeFlow::Scalar(id) | TypeFlow::Bubble(id) => *id, + } + } +} diff --git a/crates/plotnik-lib/src/analyze/type_check/mod.rs b/crates/plotnik-lib/src/analyze/type_check/mod.rs index afda02b8..fda92fb2 100644 --- a/crates/plotnik-lib/src/analyze/type_check/mod.rs +++ b/crates/plotnik-lib/src/analyze/type_check/mod.rs @@ -28,8 +28,6 @@ use crate::diagnostics::Diagnostics; use crate::parser::ast::Root; use crate::query::source_map::SourceId; -use infer::infer_root; - /// Run type inference on all definitions. /// /// Processes definitions in dependency order (leaves first) to handle @@ -41,110 +39,7 @@ pub fn infer_types( dependency_analysis: &DependencyAnalysis, diag: &mut Diagnostics, ) -> TypeContext { - let ctx = TypeContext::new(); - InferencePass { - ctx, - interner, - ast_map, - symbol_table, - dependency_analysis, - diag, - } - .run() -} - -struct InferencePass<'a> { - ctx: TypeContext, - interner: &'a mut Interner, - ast_map: &'a IndexMap, - symbol_table: &'a SymbolTable, - dependency_analysis: &'a DependencyAnalysis, - diag: &'a mut Diagnostics, -} - -impl<'a> InferencePass<'a> { - fn run(mut self) -> TypeContext { - // Avoid re-registration of definitions - self.ctx.seed_defs( - self.dependency_analysis.def_names(), - self.dependency_analysis.name_to_def(), - ); - - self.mark_recursion(); - self.process_sccs(); - self.process_orphans(); - - self.ctx - } - - /// Identify and mark recursive definitions. - fn mark_recursion(&mut self) { - for scc in &self.dependency_analysis.sccs { - for def_name in scc { - if !self.dependency_analysis.is_recursive(def_name) { - continue; - } - let sym = self.interner.intern(def_name); - let Some(def_id) = self.ctx.get_def_id_sym(sym) else { - continue; - }; - self.ctx.mark_recursive(def_id); - } - } - } - - /// Process definitions in SCC order (leaves first). - fn process_sccs(&mut self) { - for scc in &self.dependency_analysis.sccs { - for def_name in scc { - if let Some(source_id) = self.symbol_table.source_id(def_name) { - self.infer_and_register(def_name, source_id); - } - } - } - } - - /// Handle any definitions not in an SCC (safety net). - fn process_orphans(&mut self) { - for (name, source_id, _body) in self.symbol_table.iter_full() { - // Skip if already processed - if self.ctx.get_def_type_by_name(self.interner, name).is_some() { - continue; - } - self.infer_and_register(name, source_id); - } - } - - fn infer_and_register(&mut self, def_name: &str, source_id: SourceId) { - let Some(root) = self.ast_map.get(&source_id) else { - return; - }; - - infer_root( - &mut self.ctx, - self.interner, - self.symbol_table, - source_id, - root, - self.diag, - ); - - // Register the definition's output type based on the inferred body flow - if let Some(body) = self.symbol_table.get(def_name) - && let Some(info) = self.ctx.get_term_info(body).cloned() - { - let type_id = self.flow_to_type_id(&info.flow); - self.ctx - .set_def_type_by_name(self.interner, def_name, type_id); - } - } - - fn flow_to_type_id(&mut self, flow: &TypeFlow) -> TypeId { - match flow { - TypeFlow::Void => TYPE_VOID, - TypeFlow::Scalar(id) | TypeFlow::Bubble(id) => *id, - } - } + infer::InferencePass::new(interner, ast_map, symbol_table, dependency_analysis, diag).run() } /// Get the primary definition name (first non-underscore, or underscore if none). diff --git a/crates/plotnik-lib/src/query/printer.rs b/crates/plotnik-lib/src/query/printer.rs index 584662ed..572a640a 100644 --- a/crates/plotnik-lib/src/query/printer.rs +++ b/crates/plotnik-lib/src/query/printer.rs @@ -167,7 +167,7 @@ impl<'q> QueryPrinter<'q> { visited.insert(name.to_string()); if let Some(body) = self.query.symbol_table.get(name) { - let refs_set = crate::analyze::refs::collect_ref_names(body); + let refs_set = crate::analyze::refs::ref_names(body); let mut refs: Vec<_> = refs_set.iter().map(|s| s.as_str()).collect(); refs.sort(); for r in refs { diff --git a/crates/plotnik-lib/src/query/query.rs b/crates/plotnik-lib/src/query/query.rs index 27c0e128..4671b33a 100644 --- a/crates/plotnik-lib/src/query/query.rs +++ b/crates/plotnik-lib/src/query/query.rs @@ -6,11 +6,11 @@ use plotnik_core::{Interner, NodeFieldId, NodeTypeId, Symbol}; use plotnik_langs::Lang; use crate::Diagnostics; -use crate::analyze::dependencies; use crate::analyze::link; use crate::analyze::symbol_table::{SymbolTable, resolve_names}; use crate::analyze::type_check::{self, Arity, TypeContext}; use crate::analyze::validation::{validate_alt_kinds, validate_anchors}; +use crate::analyze::{dependencies, validate_recursion}; use crate::parser::{Parser, Root, SyntaxNode, lexer::lex}; use crate::query::source_map::{SourceId, SourceMap}; @@ -110,7 +110,7 @@ impl QueryParsed { let symbol_table = resolve_names(&self.source_map, &self.ast_map, &mut self.diag); let dependency_analysis = dependencies::analyze_dependencies(&symbol_table, &mut interner); - dependencies::validate_recursion( + validate_recursion( &dependency_analysis, &self.ast_map, &symbol_table,