diff --git a/crates/plotnik-lib/src/query/recursion.rs b/crates/plotnik-lib/src/query/recursion.rs index ae579bac..c0f39597 100644 --- a/crates/plotnik-lib/src/query/recursion.rs +++ b/crates/plotnik-lib/src/query/recursion.rs @@ -13,175 +13,152 @@ use crate::parser::{Def, Expr}; impl Query<'_> { pub(super) fn validate_recursion(&mut self) { - let sccs = self.find_sccs(); + let sccs = SccFinder::find(self); for scc in sccs { - let scc_set: IndexSet<&str> = scc.iter().map(|s| s.as_str()).collect(); - - // 1. Check for infinite tree structure (Escape Analysis) - // Existing logic: at least one definition must have a non-recursive path. - let has_escape = scc.iter().any(|name| { - self.symbol_table - .get(name.as_str()) - .map(|body| expr_has_escape(body, &scc_set)) - .unwrap_or(true) - }); + self.validate_scc(scc); + } + } - if !has_escape { - let chain = if scc.len() == 1 { - self.build_self_ref_chain(&scc[0]) - } else { - self.build_cycle_chain(&scc) - }; - self.emit_recursion_error(&scc[0], &scc, chain); - continue; - } + fn validate_scc(&mut self, scc: Vec) { + let scc_set: IndexSet<&str> = scc.iter().map(|s| s.as_str()).collect(); - // 2. Check for infinite loops (Guarded Recursion Analysis) - // Ensure every recursive cycle consumes at least one node. - if let Some(cycle) = self.find_unguarded_cycle(&scc, &scc_set) { - let chain = self.build_unguarded_chain(&cycle); - self.emit_direct_recursion_error(&cycle[0], &cycle, chain); + // 1. Check for infinite tree structure (Escape Analysis) + // A valid recursive definition must have a non-recursive path. + // If NO definition in the SCC has an escape path, the whole group is invalid. + let has_escape = scc.iter().any(|name| { + self.symbol_table + .get(name.as_str()) + .map(|body| expr_has_escape(body, &scc_set)) + .unwrap_or(true) + }); + + if !has_escape { + // Find a cycle to report. Any cycle within the SCC is an infinite recursion loop + // because there are no escape paths. + if let Some(raw_chain) = self.find_cycle(&scc, &scc_set, |q, expr, target| { + q.find_ref_range(expr, target) + }) { + let chain = self.format_chain(raw_chain, false); + self.report_cycle(DiagnosticKind::RecursionNoEscape, &scc, chain); } + return; } - } - fn find_sccs(&self) -> Vec> { - struct State<'a, 'src> { - query: &'a Query<'src>, - index: usize, - stack: Vec, - on_stack: IndexSet, - indices: IndexMap, - lowlinks: IndexMap, - sccs: Vec>, + // 2. Check for infinite loops (Guarded Recursion Analysis) + // Even if there is an escape, every recursive cycle must consume input (be guarded). + // We look for a cycle composed entirely of unguarded references. + if let Some(raw_chain) = self.find_cycle(&scc, &scc_set, |q, expr, target| { + q.find_unguarded_ref_range(expr, target) + }) { + let chain = self.format_chain(raw_chain, true); + self.report_cycle(DiagnosticKind::DirectRecursion, &scc, chain); } + } - fn strongconnect(name: &str, state: &mut State<'_, '_>) { - state.indices.insert(name.to_string(), state.index); - state.lowlinks.insert(name.to_string(), state.index); - state.index += 1; - state.stack.push(name.to_string()); - state.on_stack.insert(name.to_string()); - - if let Some(body) = state.query.symbol_table.get(name) { - let refs = collect_refs(body); - for ref_name in &refs { - if state.query.symbol_table.get(ref_name.as_str()).is_none() { - continue; - } - if !state.indices.contains_key(ref_name.as_str()) { - strongconnect(ref_name, state); - let ref_lowlink = state.lowlinks[ref_name.as_str()]; - let my_lowlink = state.lowlinks.get_mut(name).unwrap(); - *my_lowlink = (*my_lowlink).min(ref_lowlink); - } else if state.on_stack.contains(ref_name.as_str()) { - let ref_index = state.indices[ref_name.as_str()]; - let my_lowlink = state.lowlinks.get_mut(name).unwrap(); - *my_lowlink = (*my_lowlink).min(ref_index); - } - } - } - - if state.lowlinks[name] == state.indices[name] { - let mut scc = Vec::new(); - loop { - let w = state.stack.pop().unwrap(); - state.on_stack.swap_remove(&w); - scc.push(w.clone()); - if w == name { - break; - } - } - state.sccs.push(scc); + /// Finds a cycle within the given set of nodes (SCC). + /// `get_edge_location` returns the location of a reference from `expr` to `target`. + fn find_cycle( + &self, + nodes: &[String], + domain: &IndexSet<&str>, + get_edge_location: impl Fn(&Query, &Expr, &str) -> Option, + ) -> Option> { + let mut adj = IndexMap::new(); + for name in nodes { + if let Some(body) = self.symbol_table.get(name.as_str()) { + let neighbors = domain + .iter() + .filter_map(|target| { + get_edge_location(self, body, target) + .map(|range| (target.to_string(), range)) + }) + .collect::>(); + adj.insert(name.clone(), neighbors); } } - let mut state = State { - query: self, - index: 0, - stack: Vec::new(), - on_stack: IndexSet::new(), - indices: IndexMap::new(), - lowlinks: IndexMap::new(), - sccs: Vec::new(), - }; + CycleFinder::find(nodes, &adj) + } - for name in self.symbol_table.keys() { - if !state.indices.contains_key(*name) { - strongconnect(name, &mut state); - } + fn format_chain( + &self, + chain: Vec<(TextRange, String)>, + is_unguarded: bool, + ) -> Vec<(TextRange, String)> { + if chain.len() == 1 { + let (range, target) = &chain[0]; + let msg = if is_unguarded { + "references itself".to_string() + } else { + format!("{} references itself", target) + }; + return vec![(*range, msg)]; } - state - .sccs + let len = chain.len(); + chain .into_iter() - .filter(|scc| { - scc.len() > 1 - || self - .symbol_table - .get(scc[0].as_str()) - .map(|body| collect_refs(body).contains(scc[0].as_str())) - .unwrap_or(false) + .enumerate() + .map(|(i, (range, target))| { + let msg = if i == len - 1 { + format!("references {} (completing cycle)", target) + } else { + format!("references {}", target) + }; + (range, msg) }) .collect() } - fn find_unguarded_cycle( - &self, + fn report_cycle( + &mut self, + kind: DiagnosticKind, scc: &[String], - scc_set: &IndexSet<&str>, - ) -> Option> { - // Build dependency graph for unguarded calls within the SCC - let mut adj = IndexMap::new(); - for name in scc { - if let Some(body) = self.symbol_table.get(name.as_str()) { - let mut refs = IndexSet::new(); - collect_unguarded_refs(body, scc_set, &mut refs); - adj.insert(name.clone(), refs); - } - } - - // Detect cycle - let mut visited = IndexSet::new(); - let mut stack = IndexSet::new(); + chain: Vec<(TextRange, String)>, + ) { + let primary_loc = chain + .first() + .map(|(r, _)| *r) + .unwrap_or_else(|| TextRange::empty(0.into())); - for start_node in scc { - if let Some(target) = Self::detect_cycle(start_node, &adj, &mut visited, &mut stack) { - let index = stack.get_index_of(&target).unwrap(); - return Some(stack.iter().skip(index).cloned().collect()); - } - } + let related_def = if scc.len() > 1 { + self.find_def_info_containing(scc, primary_loc) + } else { + None + }; - None - } + let mut builder = self.recursion_diagnostics.report(kind, primary_loc); - fn detect_cycle( - node: &String, - adj: &IndexMap>, - visited: &mut IndexSet, - stack: &mut IndexSet, - ) -> Option { - if stack.contains(node) { - return Some(node.clone()); + for (range, msg) in chain { + builder = builder.related_to(msg, range); } - if visited.contains(node) { - return None; - } - - visited.insert(node.clone()); - stack.insert(node.clone()); - if let Some(neighbors) = adj.get(node) { - for neighbor in neighbors { - if let Some(target) = Self::detect_cycle(neighbor, adj, visited, stack) { - return Some(target); - } - } + if let Some((msg, range)) = related_def { + builder = builder.related_to(msg, range); } - stack.pop(); - None + builder.emit(); + } + + fn find_def_info_containing( + &self, + scc: &[String], + range: TextRange, + ) -> Option<(String, TextRange)> { + scc.iter() + .find(|name| { + self.symbol_table + .get(name.as_str()) + .map(|body| body.text_range().contains_range(range)) + .unwrap_or(false) + }) + .and_then(|name| { + self.find_def_by_name(name).and_then(|def| { + def.name() + .map(|n| (format!("{} is defined here", name), n.text_range())) + }) + }) } fn find_def_by_name(&self, name: &str) -> Option { @@ -190,177 +167,167 @@ impl Query<'_> { .find(|d| d.name().map(|n| n.text() == name).unwrap_or(false)) } - fn find_reference_location(&self, from: &str, to: &str) -> Option { - let def = self.find_def_by_name(from)?; - let body = def.body()?; - find_ref_in_expr(&body, to) + fn find_ref_range(&self, expr: &Expr, target: &str) -> Option { + find_ref_in_expr(expr, target) } - fn find_unguarded_reference_location(&self, from: &str, to: &str) -> Option { - let def = self.find_def_by_name(from)?; - let body = def.body()?; - find_unguarded_ref_in_expr(&body, to) + fn find_unguarded_ref_range(&self, expr: &Expr, target: &str) -> Option { + find_unguarded_ref_in_expr(expr, target) } +} - fn build_self_ref_chain(&self, name: &str) -> Vec<(TextRange, String)> { - self.find_reference_location(name, name) - .map(|range| vec![(range, format!("{} references itself", name))]) - .unwrap_or_default() - } +struct CycleFinder<'a> { + adj: &'a IndexMap>, + visited: IndexSet, + on_path: IndexMap, + path: Vec, + edges: Vec, +} - fn build_cycle_chain(&self, scc: &[String]) -> Vec<(TextRange, String)> { - // Since Tarjan's sccs are not guaranteed to be ordered as a cycle, - // we need to find the cycle path explicitly. - let scc_set: IndexSet<&str> = scc.iter().map(|s| s.as_str()).collect(); - let mut visited = IndexSet::new(); - let mut path = Vec::new(); - let start = &scc[0]; - - fn find_path<'a>( - current: &str, - start: &str, - scc_set: &IndexSet<&str>, - query: &Query<'a>, - visited: &mut IndexSet, - path: &mut Vec, - ) -> bool { - if visited.contains(current) { - return current == start && path.len() > 1; - } - visited.insert(current.to_string()); - path.push(current.to_string()); - - if let Some(body) = query.symbol_table.get(current) { - let refs = collect_refs(body); - for ref_name in &refs { - if scc_set.contains(ref_name.as_str()) - && find_path(ref_name, start, scc_set, query, visited, path) - { - return true; - } - } +impl<'a> CycleFinder<'a> { + fn find( + nodes: &[String], + adj: &'a IndexMap>, + ) -> Option> { + let mut finder = Self { + adj, + visited: IndexSet::new(), + on_path: IndexMap::new(), + path: Vec::new(), + edges: Vec::new(), + }; + + for start in nodes { + if let Some(chain) = finder.dfs(start) { + return Some(chain); } + } + None + } - path.pop(); - false + fn dfs(&mut self, current: &String) -> Option> { + if self.on_path.contains_key(current) { + return None; } - find_path(start, start, &scc_set, self, &mut visited, &mut path); + if self.visited.contains(current) { + return None; + } - path.iter() - .enumerate() - .filter_map(|(i, from)| { - let to = &path[(i + 1) % path.len()]; - self.find_reference_location(from, to).map(|range| { - let msg = if i == path.len() - 1 { - format!("references {} (completing cycle)", to) - } else { - format!("references {}", to) - }; - (range, msg) - }) - }) - .collect() - } + self.visited.insert(current.clone()); + self.on_path.insert(current.clone(), self.path.len()); + self.path.push(current.clone()); + + if let Some(neighbors) = self.adj.get(current) { + for (target, range) in neighbors { + if let Some(&start_index) = self.on_path.get(target) { + // Cycle detected! + // Path: path[start_index] ... path[last] (current) + // Edges: edges[start_index] ... edges[last-1] + // Closing edge: range + let mut chain = Vec::new(); + for i in start_index..self.path.len() - 1 { + chain.push((self.edges[i], self.path[i + 1].clone())); + } + chain.push((*range, target.clone())); + return Some(chain); + } - fn build_unguarded_chain(&self, cycle: &[String]) -> Vec<(TextRange, String)> { - if cycle.len() == 1 { - return self - .find_unguarded_reference_location(&cycle[0], &cycle[0]) - .map(|range| vec![(range, "references itself".to_string())]) - .unwrap_or_default(); + self.edges.push(*range); + if let Some(chain) = self.dfs(target) { + return Some(chain); + } + self.edges.pop(); + } } - self.build_chain_generic(cycle, |from, to| { - self.find_unguarded_reference_location(from, to) - }) - } - fn build_chain_generic(&self, path_nodes: &[String], find_loc: F) -> Vec<(TextRange, String)> - where - F: Fn(&str, &str) -> Option, - { - path_nodes - .iter() - .enumerate() - .filter_map(|(i, from)| { - let to = &path_nodes[(i + 1) % path_nodes.len()]; - find_loc(from, to).map(|range| { - let msg = if i == path_nodes.len() - 1 { - format!("references {} (completing cycle)", to) - } else { - format!("references {}", to) - }; - (range, msg) - }) - }) - .collect() + self.path.pop(); + self.on_path.swap_remove(current); + None } +} - fn emit_recursion_error( - &mut self, - primary_name: &str, - scc: &[String], - related: Vec<(TextRange, String)>, - ) { - let range = related - .first() - .map(|(r, _)| *r) - .unwrap_or_else(|| TextRange::empty(0.into())); +struct SccFinder<'a, 'src> { + query: &'a Query<'src>, + index: usize, + stack: Vec, + on_stack: IndexSet, + indices: IndexMap, + lowlinks: IndexMap, + sccs: Vec>, +} - let def_range = if scc.len() > 1 { - self.find_def_by_name(primary_name) - .and_then(|def| def.name()) - .map(|n| n.text_range()) - } else { - None +impl<'a, 'src> SccFinder<'a, 'src> { + fn find(query: &'a Query<'src>) -> Vec> { + let mut finder = Self { + query, + index: 0, + stack: Vec::new(), + on_stack: IndexSet::new(), + indices: IndexMap::new(), + lowlinks: IndexMap::new(), + sccs: Vec::new(), }; - let mut builder = self - .recursion_diagnostics - .report(DiagnosticKind::RecursionNoEscape, range); - - for (rel_range, rel_msg) in related { - builder = builder.related_to(rel_msg, rel_range); - } - - if let Some(range) = def_range { - builder = builder.related_to(format!("{} is defined here", primary_name), range); + for name in query.symbol_table.keys() { + if !finder.indices.contains_key(*name) { + finder.strongconnect(name); + } } - builder.emit(); + finder + .sccs + .into_iter() + .filter(|scc| { + scc.len() > 1 + || query + .symbol_table + .get(scc[0].as_str()) + .map(|body| collect_refs(body).contains(scc[0].as_str())) + .unwrap_or(false) + }) + .collect() } - fn emit_direct_recursion_error( - &mut self, - primary_name: &str, - scc: &[String], - related: Vec<(TextRange, String)>, - ) { - let range = related - .first() - .map(|(r, _)| *r) - .unwrap_or_else(|| TextRange::empty(0.into())); - let def_range = if scc.len() > 1 { - self.find_def_by_name(primary_name) - .and_then(|def| def.name()) - .map(|n| n.text_range()) - } else { - None - }; - - let mut builder = self - .recursion_diagnostics - .report(DiagnosticKind::DirectRecursion, range); + fn strongconnect(&mut self, name: &str) { + self.indices.insert(name.to_string(), self.index); + self.lowlinks.insert(name.to_string(), self.index); + self.index += 1; + self.stack.push(name.to_string()); + self.on_stack.insert(name.to_string()); + + if let Some(body) = self.query.symbol_table.get(name) { + let refs = collect_refs(body); + for ref_name in refs { + if !self.query.symbol_table.contains_key(ref_name.as_str()) { + continue; + } - for (rel_range, rel_msg) in related { - builder = builder.related_to(rel_msg, rel_range); + if !self.indices.contains_key(&ref_name) { + self.strongconnect(&ref_name); + let ref_lowlink = self.lowlinks[&ref_name]; + let my_lowlink = self.lowlinks.get_mut(name).unwrap(); + *my_lowlink = (*my_lowlink).min(ref_lowlink); + } else if self.on_stack.contains(&ref_name) { + let ref_index = self.indices[&ref_name]; + let my_lowlink = self.lowlinks.get_mut(name).unwrap(); + *my_lowlink = (*my_lowlink).min(ref_index); + } + } } - if let Some(range) = def_range { - builder = builder.related_to(format!("{} is defined here", primary_name), range); + if self.lowlinks[name] == self.indices[name] { + let mut scc = Vec::new(); + loop { + let w = self.stack.pop().unwrap(); + self.on_stack.swap_remove(&w); + scc.push(w.clone()); + if w == name { + break; + } + } + self.sccs.push(scc); } - - builder.emit(); } } @@ -429,42 +396,6 @@ fn collect_refs_into(expr: &Expr, refs: &mut IndexSet) { } } -fn collect_unguarded_refs(expr: &Expr, scc: &IndexSet<&str>, refs: &mut IndexSet) { - match expr { - Expr::Ref(r) => { - if let Some(name) = r.name().filter(|n| scc.contains(n.text())) { - refs.insert(name.text().to_string()); - } - } - Expr::NamedNode(_) | Expr::AnonymousNode(_) => { - // Consumes input, so guards recursion. Do not collect refs inside. - } - Expr::AltExpr(_) => { - for c in expr.children() { - collect_unguarded_refs(&c, scc, refs); - } - } - Expr::SeqExpr(_) => { - for c in expr.children() { - collect_unguarded_refs(&c, scc, refs); - if expr_guarantees_consumption(&c) { - break; - } - } - } - Expr::QuantifiedExpr(q) => { - if let Some(inner) = q.inner() { - collect_unguarded_refs(&inner, scc, refs); - } - } - Expr::CapturedExpr(_) | Expr::FieldExpr(_) => { - for c in expr.children() { - collect_unguarded_refs(&c, scc, refs); - } - } - } -} - fn find_ref_in_expr(expr: &Expr, target: &str) -> Option { if let Expr::Ref(r) = expr { let name_token = r.name()?;