From f0a617311a570abd68edd209f21c0fc61c0e7d63 Mon Sep 17 00:00:00 2001 From: Sergei Zharinov Date: Wed, 31 Dec 2025 13:14:24 -0300 Subject: [PATCH] refactor: consolidate compiler utilities and passes --- crates/plotnik-core/src/lib.rs | 1 + crates/plotnik-core/src/utils.rs | 115 ++++++++++++++++++ .../src/bytecode/emit/typescript.rs | 19 +-- .../plotnik-lib/src/parser/grammar/utils.rs | 31 +---- crates/plotnik-lib/src/query/dependencies.rs | 29 ++++- crates/plotnik-lib/src/query/link.rs | 74 ++++------- crates/plotnik-lib/src/query/mod.rs | 1 + crates/plotnik-lib/src/query/printer.rs | 12 +- crates/plotnik-lib/src/query/refs.rs | 72 +++++++++++ .../plotnik-lib/src/query/type_check/mod.rs | 48 ++------ 10 files changed, 257 insertions(+), 145 deletions(-) create mode 100644 crates/plotnik-core/src/utils.rs create mode 100644 crates/plotnik-lib/src/query/refs.rs diff --git a/crates/plotnik-core/src/lib.rs b/crates/plotnik-core/src/lib.rs index dd63cffb..8d3b8927 100644 --- a/crates/plotnik-core/src/lib.rs +++ b/crates/plotnik-core/src/lib.rs @@ -15,6 +15,7 @@ use std::num::NonZeroU16; mod interner; mod invariants; +pub mod utils; pub use interner::{Interner, Symbol}; diff --git a/crates/plotnik-core/src/utils.rs b/crates/plotnik-core/src/utils.rs new file mode 100644 index 00000000..b8cdd1d4 --- /dev/null +++ b/crates/plotnik-core/src/utils.rs @@ -0,0 +1,115 @@ +/// Convert snake_case or kebab-case to PascalCase. +/// +/// Normalizes words separated by `_`, `-`, or `.`. If the input is already +/// PascalCase (starts uppercase, no separators), it is returned unchanged. +/// +/// # Examples +/// ``` +/// use plotnik_core::utils::to_pascal_case; +/// assert_eq!(to_pascal_case("foo_bar"), "FooBar"); +/// assert_eq!(to_pascal_case("FOO_BAR"), "FooBar"); +/// assert_eq!(to_pascal_case("FooBar"), "FooBar"); // idempotent +/// ``` +pub fn to_pascal_case(s: &str) -> String { + fn is_separator(c: char) -> bool { + matches!(c, '_' | '-' | '.') + } + + let has_separator = s.chars().any(is_separator); + let has_lowercase = s.chars().any(|c| c.is_ascii_lowercase()); + let starts_uppercase = s.chars().next().is_some_and(|c| c.is_ascii_uppercase()); + + // Already PascalCase: starts uppercase, has lowercase, no separators + if starts_uppercase && has_lowercase && !has_separator { + return s.to_string(); + } + + let mut result = String::with_capacity(s.len()); + let mut capitalize_next = true; + for c in s.chars() { + if is_separator(c) { + capitalize_next = true; + continue; + } + if capitalize_next { + result.push(c.to_ascii_uppercase()); + capitalize_next = false; + } else { + result.push(c.to_ascii_lowercase()); + } + } + result +} + +/// Convert PascalCase or camelCase to snake_case. +/// +/// # Examples +/// ``` +/// use plotnik_core::utils::to_snake_case; +/// assert_eq!(to_snake_case("FooBar"), "foo_bar"); +/// assert_eq!(to_snake_case("fooBar"), "foo_bar"); +/// ``` +pub fn to_snake_case(s: &str) -> String { + let mut result = String::new(); + for (i, c) in s.chars().enumerate() { + if c.is_ascii_uppercase() { + if i > 0 && !result.ends_with('_') { + result.push('_'); + } + result.push(c.to_ascii_lowercase()); + } else { + result.push(c); + } + } + result +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn pascal_case_from_snake() { + assert_eq!(to_pascal_case("foo_bar"), "FooBar"); + assert_eq!(to_pascal_case("foo"), "Foo"); + assert_eq!(to_pascal_case("_foo"), "Foo"); + assert_eq!(to_pascal_case("foo_"), "Foo"); + } + + #[test] + fn pascal_case_normalizes() { + assert_eq!(to_pascal_case("FOO_BAR"), "FooBar"); + assert_eq!(to_pascal_case("FOO"), "Foo"); + assert_eq!(to_pascal_case("FOOBAR"), "Foobar"); + } + + #[test] + fn pascal_case_idempotent() { + assert_eq!(to_pascal_case("FooBar"), "FooBar"); + assert_eq!(to_pascal_case("QRow"), "QRow"); + assert_eq!(to_pascal_case("Q"), "Q"); + } + + #[test] + fn pascal_case_from_kebab() { + assert_eq!(to_pascal_case("foo-bar"), "FooBar"); + assert_eq!(to_pascal_case("foo-bar-baz"), "FooBarBaz"); + } + + #[test] + fn pascal_case_from_dotted() { + assert_eq!(to_pascal_case("foo.bar"), "FooBar"); + } + + #[test] + fn snake_case_from_pascal() { + assert_eq!(to_snake_case("FooBar"), "foo_bar"); + assert_eq!(to_snake_case("Foo"), "foo"); + } + + #[test] + fn snake_case_from_camel() { + assert_eq!(to_snake_case("fooBar"), "foo_bar"); + assert_eq!(to_snake_case("fooBarBaz"), "foo_bar_baz"); + } +} diff --git a/crates/plotnik-lib/src/bytecode/emit/typescript.rs b/crates/plotnik-lib/src/bytecode/emit/typescript.rs index 263255b8..a35630b9 100644 --- a/crates/plotnik-lib/src/bytecode/emit/typescript.rs +++ b/crates/plotnik-lib/src/bytecode/emit/typescript.rs @@ -6,6 +6,8 @@ use std::collections::hash_map::Entry; use std::collections::{BTreeSet, HashMap, HashSet}; +use plotnik_core::utils::to_pascal_case; + use crate::bytecode::module::{Module, StringsView, TypesView}; use crate::bytecode::type_meta::{TypeDef, TypeKind}; use crate::bytecode::{EntrypointsView, QTypeId}; @@ -770,23 +772,6 @@ struct NamingContext { field_name: Option, } -fn to_pascal_case(s: &str) -> String { - let mut result = String::with_capacity(s.len()); - let mut capitalize_next = true; - - for c in s.chars() { - if c == '_' || c == '-' || c == '.' { - capitalize_next = true; - } else if capitalize_next { - result.extend(c.to_uppercase()); - capitalize_next = false; - } else { - result.push(c); - } - } - result -} - /// Emit TypeScript from a bytecode module. pub fn emit_typescript(module: &Module) -> String { TsEmitter::new(module, EmitConfig::default()).emit() diff --git a/crates/plotnik-lib/src/parser/grammar/utils.rs b/crates/plotnik-lib/src/parser/grammar/utils.rs index e4664d67..76720a9f 100644 --- a/crates/plotnik-lib/src/parser/grammar/utils.rs +++ b/crates/plotnik-lib/src/parser/grammar/utils.rs @@ -1,33 +1,4 @@ -pub(crate) fn to_snake_case(s: &str) -> String { - let mut result = String::new(); - for (i, c) in s.chars().enumerate() { - if c.is_ascii_uppercase() { - if i > 0 && !result.ends_with('_') { - result.push('_'); - } - result.push(c.to_ascii_lowercase()); - } else { - result.push(c); - } - } - result -} - -pub(crate) fn to_pascal_case(s: &str) -> String { - let mut result = String::new(); - let mut capitalize_next = true; - for c in s.chars() { - if c == '_' || c == '-' || c == '.' { - capitalize_next = true; - } else if capitalize_next { - result.push(c.to_ascii_uppercase()); - capitalize_next = false; - } else { - result.push(c.to_ascii_lowercase()); - } - } - result -} +pub(crate) use plotnik_core::utils::{to_pascal_case, to_snake_case}; pub(crate) fn capitalize_first(s: &str) -> String { assert!(!s.is_empty(), "capitalize_first: called with empty string"); diff --git a/crates/plotnik-lib/src/query/dependencies.rs b/crates/plotnik-lib/src/query/dependencies.rs index bce3c9ce..2ee73529 100644 --- a/crates/plotnik-lib/src/query/dependencies.rs +++ b/crates/plotnik-lib/src/query/dependencies.rs @@ -8,7 +8,7 @@ //! which is useful for passes that need to process dependencies before //! dependents (like type inference). -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use indexmap::{IndexMap, IndexSet}; use plotnik_core::{Interner, Symbol}; @@ -39,6 +39,12 @@ pub struct DependencyAnalysis { /// Maps DefId to definition name Symbol (indexed by DefId). def_names: Vec, + + /// Set of recursive definition names. + /// + /// A definition is recursive if it's in an SCC with >1 member, + /// or it's a single-member SCC that references itself. + recursive_defs: HashSet, } impl DependencyAnalysis { @@ -82,6 +88,14 @@ impl DependencyAnalysis { pub fn name_to_def(&self) -> &HashMap { &self.name_to_def } + + /// Returns true if this definition is recursive. + /// + /// A definition is recursive if it's part of a mutual recursion group (SCC > 1), + /// or it's a single definition that references itself. + pub fn is_recursive(&self, name: &str) -> bool { + self.recursive_defs.contains(name) + } } /// Analyze dependencies between definitions. @@ -97,8 +111,20 @@ pub fn analyze_dependencies( // Assign DefIds in SCC order (leaves first, so dependencies get lower IDs) let mut name_to_def = HashMap::new(); let mut def_names = Vec::new(); + let mut recursive_defs = HashSet::new(); for scc in &sccs { + // Mark recursive definitions + if scc.len() > 1 { + // Mutual recursion: all members are recursive + recursive_defs.extend(scc.iter().cloned()); + } else if let Some(name) = scc.first() + && let Some(body) = symbol_table.get(name) + && super::refs::contains_ref(body, name) + { + recursive_defs.insert(name.clone()); + } + for name in scc { let sym = interner.intern(name); let def_id = DefId::from_raw(def_names.len() as u32); @@ -111,6 +137,7 @@ pub fn analyze_dependencies( sccs, name_to_def, def_names, + recursive_defs, } } diff --git a/crates/plotnik-lib/src/query/link.rs b/crates/plotnik-lib/src/query/link.rs index 0b654dfb..84107723 100644 --- a/crates/plotnik-lib/src/query/link.rs +++ b/crates/plotnik-lib/src/query/link.rs @@ -1,9 +1,8 @@ //! Link pass: resolve node types and fields against tree-sitter grammar. //! -//! Three-phase approach: -//! 1. Collect and resolve all node type names (NamedNode, AnonymousNode) -//! 2. Collect and resolve all field names (FieldExpr, NegatedField) -//! 3. Validate structural constraints (field on node type, child type for field) +//! Two-phase approach: +//! 1. Resolve all symbols (node types and fields) against grammar +//! 2. Validate structural constraints (field on node type, child type for field) use std::collections::HashMap; @@ -84,14 +83,13 @@ impl<'a, 'q> Linker<'a, 'q> { } fn link(&mut self, root: &ast::Root) { - self.resolve_node_types(root); - self.resolve_fields(root); + self.resolve_symbols(root); self.validate_structure(root); } - fn resolve_node_types(&mut self, root: &ast::Root) { - let mut collector = NodeTypeCollector { linker: self }; - collector.visit(root); + fn resolve_symbols(&mut self, root: &ast::Root) { + let mut resolver = SymbolResolver { linker: self }; + resolver.visit(root); } fn resolve_named_node(&mut self, node: &NamedNode) { @@ -139,11 +137,6 @@ impl<'a, 'q> Linker<'a, 'q> { } } - fn resolve_fields(&mut self, root: &ast::Root) { - let mut collector = FieldCollector { linker: self }; - collector.visit(root); - } - fn resolve_field_by_token(&mut self, name_token: Option) { let Some(name_token) = name_token else { return; @@ -403,17 +396,23 @@ struct ValidationContext { parent_range: TextRange, } -struct NodeTypeCollector<'l, 'a, 'q> { +/// Combined symbol resolver for node types and fields. +struct SymbolResolver<'l, 'a, 'q> { linker: &'l mut Linker<'a, 'q>, } -impl Visitor for NodeTypeCollector<'_, '_, '_> { +impl Visitor for SymbolResolver<'_, '_, '_> { fn visit(&mut self, root: &ast::Root) { walk(self, root); } fn visit_named_node(&mut self, node: &ast::NamedNode) { self.linker.resolve_named_node(node); + + for neg in node.as_cst().children().filter_map(ast::NegatedField::cast) { + self.linker.resolve_field_by_token(neg.name()); + } + super::visitor::walk_named_node(self, node); } @@ -433,47 +432,26 @@ impl Visitor for NodeTypeCollector<'_, '_, '_> { self.linker .node_type_ids .insert(token_src(&value_token, self.linker.source()), resolved); + if let Some(id) = resolved { let sym = self.linker.interner.intern(value); self.linker.output.node_type_ids.entry(sym).or_insert(id); + return; } - if resolved.is_none() { - self.linker - .diagnostics - .report( - self.linker.source_id, - DiagnosticKind::UnknownNodeType, - value_token.text_range(), - ) - .message(value) - .emit(); - } - } -} - -struct FieldCollector<'l, 'a, 'q> { - linker: &'l mut Linker<'a, 'q>, -} - -impl Visitor for FieldCollector<'_, '_, '_> { - fn visit(&mut self, root: &ast::Root) { - walk(self, root); - } - - fn visit_named_node(&mut self, node: &ast::NamedNode) { - for child in node.as_cst().children() { - if let Some(neg) = ast::NegatedField::cast(child) { - self.linker.resolve_field_by_token(neg.name()); - } - } - - super::visitor::walk_named_node(self, node); + self.linker + .diagnostics + .report( + self.linker.source_id, + DiagnosticKind::UnknownNodeType, + value_token.text_range(), + ) + .message(value) + .emit(); } fn visit_field_expr(&mut self, field: &ast::FieldExpr) { self.linker.resolve_field_by_token(field.name()); - super::visitor::walk_field_expr(self, field); } } diff --git a/crates/plotnik-lib/src/query/mod.rs b/crates/plotnik-lib/src/query/mod.rs index 39dc0fdb..703da0cd 100644 --- a/crates/plotnik-lib/src/query/mod.rs +++ b/crates/plotnik-lib/src/query/mod.rs @@ -1,6 +1,7 @@ mod dump; mod invariants; mod printer; +mod refs; mod source_map; mod utils; pub use printer::QueryPrinter; diff --git a/crates/plotnik-lib/src/query/printer.rs b/crates/plotnik-lib/src/query/printer.rs index 53f6cf7e..3386dee5 100644 --- a/crates/plotnik-lib/src/query/printer.rs +++ b/crates/plotnik-lib/src/query/printer.rs @@ -5,7 +5,7 @@ use std::fmt::Write; use indexmap::IndexSet; use rowan::NodeOrToken; -use crate::parser::{self as ast, Expr, SyntaxNode}; +use crate::parser::{self as ast, SyntaxNode}; use super::Query; use super::source_map::SourceKind; @@ -167,7 +167,7 @@ impl<'q> QueryPrinter<'q> { visited.insert(name.to_string()); if let Some(body) = self.query.symbol_table.get(name) { - let refs_set = collect_refs(body); + let refs_set = super::refs::collect_ref_names(body); let mut refs: Vec<_> = refs_set.iter().map(|s| s.as_str()).collect(); refs.sort(); for r in refs { @@ -413,11 +413,3 @@ impl Query { } } -fn collect_refs(expr: &Expr) -> IndexSet { - expr.as_cst() - .descendants() - .filter_map(ast::Ref::cast) - .filter_map(|r| r.name()) - .map(|tok| tok.text().to_string()) - .collect() -} diff --git a/crates/plotnik-lib/src/query/refs.rs b/crates/plotnik-lib/src/query/refs.rs new file mode 100644 index 00000000..89285516 --- /dev/null +++ b/crates/plotnik-lib/src/query/refs.rs @@ -0,0 +1,72 @@ +//! Utilities for working with definition references in expressions. + +use indexmap::IndexSet; + +use crate::parser::ast::{self, Expr}; + +/// Iterate over all Ref nodes in an expression tree. +pub fn ref_nodes(expr: &Expr) -> impl Iterator + '_ { + expr.as_cst().descendants().filter_map(ast::Ref::cast) +} + +/// Collect all reference names as owned strings. +pub fn collect_ref_names(expr: &Expr) -> IndexSet { + ref_nodes(expr) + .filter_map(|r| r.name()) + .map(|tok| tok.text().to_string()) + .collect() +} + +/// Check if expression contains a reference to the given name. +pub fn contains_ref(expr: &Expr, name: &str) -> bool { + ref_nodes(expr) + .filter_map(|r| r.name()) + .any(|tok| tok.text() == name) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Query; + + #[test] + fn collect_refs_from_simple_ref() { + let q = Query::expect("Q = (Foo)"); + let expr = q.symbol_table.get("Q").unwrap(); + let refs = collect_ref_names(expr); + assert_eq!(refs.len(), 1); + assert!(refs.contains("Foo")); + } + + #[test] + fn collect_refs_from_nested() { + let q = Query::expect("Q = (x (Foo) (Bar))"); + let expr = q.symbol_table.get("Q").unwrap(); + let refs = collect_ref_names(expr); + assert_eq!(refs.len(), 2); + assert!(refs.contains("Foo")); + assert!(refs.contains("Bar")); + } + + #[test] + fn collect_refs_deduplicates() { + let q = Query::expect("Q = {(Foo) (Foo)}"); + let expr = q.symbol_table.get("Q").unwrap(); + let refs = collect_ref_names(expr); + assert_eq!(refs.len(), 1); + } + + #[test] + fn contains_ref_positive() { + let q = Query::expect("Q = (x (Foo))"); + let expr = q.symbol_table.get("Q").unwrap(); + assert!(contains_ref(expr, "Foo")); + } + + #[test] + fn contains_ref_negative() { + let q = Query::expect("Q = (x (Foo))"); + let expr = q.symbol_table.get("Q").unwrap(); + assert!(!contains_ref(expr, "Bar")); + } +} diff --git a/crates/plotnik-lib/src/query/type_check/mod.rs b/crates/plotnik-lib/src/query/type_check/mod.rs index 2010af76..66932018 100644 --- a/crates/plotnik-lib/src/query/type_check/mod.rs +++ b/crates/plotnik-lib/src/query/type_check/mod.rs @@ -25,7 +25,7 @@ use std::collections::BTreeMap; use indexmap::IndexMap; use crate::diagnostics::Diagnostics; -use crate::parser::ast::{self, Root}; +use crate::parser::ast::Root; use crate::query::dependencies::DependencyAnalysis; use crate::query::source_map::SourceId; use crate::query::symbol_table::{SymbolTable, UNNAMED_DEF}; @@ -80,16 +80,17 @@ impl<'a> InferencePass<'a> { } /// Identify and mark recursive definitions. - /// A def is recursive if it's in an SCC with >1 member, or it references itself directly. fn mark_recursion(&mut self) { for scc in &self.dependency_analysis.sccs { - if self.is_scc_recursive(scc) { - for def_name in scc { - let sym = self.interner.intern(def_name); - if let Some(def_id) = self.ctx.get_def_id_sym(sym) { - self.ctx.mark_recursive(def_id); - } + for def_name in scc { + if !self.dependency_analysis.is_recursive(def_name) { + continue; } + let sym = self.interner.intern(def_name); + let Some(def_id) = self.ctx.get_def_id_sym(sym) else { + continue; + }; + self.ctx.mark_recursive(def_id); } } } @@ -140,22 +141,6 @@ impl<'a> InferencePass<'a> { } } - fn is_scc_recursive(&self, scc: &[String]) -> bool { - if scc.len() > 1 { - return true; - } - - let Some(name) = scc.first() else { - return false; - }; - - let Some(body) = self.symbol_table.get(name) else { - return false; - }; - - body_references_self(body, name) - } - fn flow_to_type_id(&mut self, flow: &TypeFlow) -> TypeId { match flow { TypeFlow::Void => self.ctx.intern_struct(BTreeMap::new()), @@ -164,21 +149,6 @@ impl<'a> InferencePass<'a> { } } -/// Check if an expression body contains a reference to the given name. -fn body_references_self(body: &ast::Expr, name: &str) -> bool { - body.as_cst().descendants().any(|descendant| { - let Some(r) = ast::Ref::cast(descendant) else { - return false; - }; - - let Some(name_tok) = r.name() else { - return false; - }; - - name_tok.text() == name - }) -} - /// Get the primary definition name (first non-underscore, or underscore if none). pub fn primary_def_name(symbol_table: &SymbolTable) -> &str { for name in symbol_table.keys() {