Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 32 additions & 38 deletions crates/plotnik-lib/src/query/dependencies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -515,9 +515,19 @@ fn collect_refs<'a>(expr: &Expr, symbol_table: &'a SymbolTable) -> IndexSet<&'a
refs
}

/// Whether to search for any reference or only unguarded ones.
#[derive(Clone, Copy, PartialEq, Eq)]
enum RefSearchMode {
/// Find any reference to the target.
Any,
/// Find only unguarded references (not inside a NamedNode/AnonymousNode).
Unguarded,
}

struct RefFinder<'a> {
target: &'a str,
found: Option<TextRange>,
mode: RefSearchMode,
}

impl Visitor for RefFinder<'_> {
Expand All @@ -528,49 +538,22 @@ impl Visitor for RefFinder<'_> {
walk_expr(self, expr);
}

fn visit_ref(&mut self, r: &Ref) {
if self.found.is_some() {
return;
}
if let Some(name) = r.name()
&& name.text() == self.target
{
self.found = Some(name.text_range());
}
}
}

fn find_ref_range(expr: &Expr, target: &str) -> Option<TextRange> {
let mut visitor = RefFinder {
target,
found: None,
};
visitor.visit_expr(expr);
visitor.found
}

struct UnguardedRefFinder<'a> {
target: &'a str,
found: Option<TextRange>,
}

impl Visitor for UnguardedRefFinder<'_> {
fn visit_expr(&mut self, expr: &Expr) {
if self.found.is_some() {
return;
fn visit_named_node(&mut self, node: &NamedNode) {
if self.mode == RefSearchMode::Unguarded {
return; // Guarded: stop recursion
}
walk_expr(self, expr);
}

fn visit_named_node(&mut self, _node: &NamedNode) {
// Guarded: stop recursion
super::visitor::walk_named_node(self, node);
}

fn visit_anonymous_node(&mut self, _node: &AnonymousNode) {
// Guarded: stop recursion
// AnonymousNode has no child expressions, so nothing to walk.
// In Unguarded mode this also acts as a guard (stops recursion).
}

fn visit_ref(&mut self, r: &Ref) {
if self.found.is_some() {
return;
}
if let Some(name) = r.name()
&& name.text() == self.target
{
Expand All @@ -584,17 +567,28 @@ impl Visitor for UnguardedRefFinder<'_> {
if self.found.is_some() {
return;
}
if expr_guarantees_consumption(&child) {
if self.mode == RefSearchMode::Unguarded && expr_guarantees_consumption(&child) {
return;
}
}
}
}

fn find_ref_range(expr: &Expr, target: &str) -> Option<TextRange> {
let mut visitor = RefFinder {
target,
found: None,
mode: RefSearchMode::Any,
};
visitor.visit_expr(expr);
visitor.found
}

fn find_unguarded_ref_range(expr: &Expr, target: &str) -> Option<TextRange> {
let mut visitor = UnguardedRefFinder {
let mut visitor = RefFinder {
target,
found: None,
mode: RefSearchMode::Unguarded,
};
visitor.visit_expr(expr);
visitor.found
Expand Down