diff --git a/AGENTS.md b/AGENTS.md index 365da2b5..cea75848 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -34,9 +34,19 @@ crates/ recursion.rs # Escape analysis (recursion validation) shapes.rs # Shape inference *_tests.rs # Test files per module + infer/ # Type inference and emission + mod.rs # Re-exports, TypePrinter builder + types.rs # Type IR (TypeDef, Field, etc.) + tyton.rs # Tyton → TypeDef conversion + printer.rs # TypePrinter API + emit/ # Language-specific emitters + mod.rs # Emitter trait, common utilities + rust.rs # Rust type emission + typescript.rs # TypeScript type emission + *_tests.rs # Test files per module lib.rs # Re-exports Query, Diagnostics, Error plotnik-cli/ # CLI tool - src/commands/ # Subcommands (debug, docs, langs) + src/commands/ # Subcommands (debug, docs, infer, langs) plotnik-langs/ # Tree-sitter language bindings docs/ REFERENCE.md # Language specification @@ -59,6 +69,7 @@ Module = "what", function = "action". Run: `cargo run -p plotnik-cli -- ` - `debug` — Inspect queries/sources +- `infer` — Generate type definitions from queries - `docs [topic]` — Print docs (reference, examples) - `langs` — List supported languages @@ -76,6 +87,27 @@ cargo run -p plotnik-cli -- debug -s app.ts --raw cargo run -p plotnik-cli -- debug -q '(function_declaration) @fn' -s app.ts -l typescript ``` +### infer options + +Input: `-q/--query `, `--query-file ` + +Output language: `-l/--lang ` + +Common: `--entry-name `, `--color ` + +Rust-specific: `--indirection `, `--derive `, `--no-derive` + +TypeScript-specific: `--optional `, `--export`, `--readonly`, `--type-alias`, `--node-type `, `--nested` + +```sh +cargo run -p plotnik-cli -- infer -q '(identifier) @id' -l rust +cargo run -p plotnik-cli -- infer -q '(fn)' -l rust --derive debug,clone +cargo run -p plotnik-cli -- infer -q '(fn)' -l rust --no-derive +cargo run -p plotnik-cli -- infer -q '(identifier)' -l ts --export +cargo run -p plotnik-cli -- infer -q '(identifier)?' -l ts --optional undefined +cargo run -p plotnik-cli -- infer -q '(fn)' -l ts --readonly --type-alias +``` + ## Syntax Grammar: `(type)`, `[a b]` (alt), `{a b}` (seq), `_` (wildcard), `@name`, `::Type`, `field:`, `*+?`, `"lit"`/`'lit'`, `(a/b)` (supertype), `(ERROR)`, `Name = expr` (def), `[A: ... B: ...]` (tagged alt) diff --git a/crates/plotnik-cli/src/cli.rs b/crates/plotnik-cli/src/cli.rs index 395fad1f..55599c86 100644 --- a/crates/plotnik-cli/src/cli.rs +++ b/crates/plotnik-cli/src/cli.rs @@ -2,6 +2,31 @@ use std::path::PathBuf; use clap::{Args, Parser, Subcommand, ValueEnum}; +#[derive(Clone, Copy, Debug, Default, ValueEnum)] +pub enum OutputLang { + #[default] + Rust, + Typescript, + Ts, +} + +#[derive(Clone, Copy, Debug, Default, ValueEnum)] +pub enum IndirectionChoice { + #[default] + Box, + Rc, + Arc, +} + +#[derive(Clone, Copy, Debug, Default, ValueEnum)] +pub enum OptionalChoice { + #[default] + Null, + Undefined, + #[value(name = "questionmark")] + QuestionMark, +} + #[derive(Clone, Copy, Debug, Default, ValueEnum)] pub enum ColorChoice { #[default] @@ -52,6 +77,29 @@ pub enum Command { output: OutputArgs, }, + /// Infer and emit types from a query + #[command(after_help = r#"EXAMPLES: + plotnik infer -q '(identifier) @id' -l rust + plotnik infer -q '(function_declaration name: (identifier) @name) @fn' -l ts --export + plotnik infer --query-file query.plot -l rust --derive debug,clone,partialeq"#)] + Infer { + #[command(flatten)] + query: QueryArgs, + + /// Output language + #[arg(short = 'l', long, value_name = "LANG")] + lang: OutputLang, + + #[command(flatten)] + common: InferCommonArgs, + + #[command(flatten)] + rust: RustArgs, + + #[command(flatten)] + typescript: TypeScriptArgs, + }, + /// Print documentation Docs { /// Topic to display (e.g., "reference", "examples") @@ -112,3 +160,56 @@ pub struct OutputArgs { #[arg(long)] pub cardinalities: bool, } + +#[derive(Args)] +pub struct InferCommonArgs { + /// Name for the entry point type (default: QueryResult) + #[arg(long, value_name = "NAME")] + pub entry_name: Option, + + /// Colorize diagnostics output + #[arg(long, default_value = "auto", value_name = "WHEN")] + pub color: ColorChoice, +} + +#[derive(Args)] +pub struct RustArgs { + /// Indirection type for cyclic references + #[arg(long, value_name = "TYPE")] + pub indirection: Option, + + /// Derive macros (comma-separated: debug, clone, partialeq) + #[arg(long, value_name = "TRAITS", value_delimiter = ',')] + pub derive: Option>, + + /// Emit no derive macros + #[arg(long)] + pub no_derive: bool, +} + +#[derive(Args)] +pub struct TypeScriptArgs { + /// How to represent optional values + #[arg(long, value_name = "STYLE")] + pub optional: Option, + + /// Add export keyword to types + #[arg(long)] + pub export: bool, + + /// Make fields readonly + #[arg(long)] + pub readonly: bool, + + /// Use type aliases instead of interfaces + #[arg(long)] + pub type_alias: bool, + + /// Name for the Node type (default: SyntaxNode) + #[arg(long, value_name = "NAME")] + pub node_type: Option, + + /// Emit nested synthetic types instead of inlining + #[arg(long)] + pub nested: bool, +} diff --git a/crates/plotnik-cli/src/commands/infer.rs b/crates/plotnik-cli/src/commands/infer.rs new file mode 100644 index 00000000..7c119793 --- /dev/null +++ b/crates/plotnik-cli/src/commands/infer.rs @@ -0,0 +1,146 @@ +use std::fs; +use std::io::{self, Read}; + +use plotnik_lib::Query; +use plotnik_lib::infer::{Indirection, OptionalStyle}; + +use crate::cli::{IndirectionChoice, OptionalChoice, OutputLang}; + +pub struct InferArgs { + pub query_text: Option, + pub query_file: Option, + pub lang: OutputLang, + pub entry_name: Option, + pub color: bool, + // Rust options + pub indirection: Option, + pub derive: Option>, + pub no_derive: bool, + // TypeScript options + pub optional: Option, + pub export: bool, + pub readonly: bool, + pub type_alias: bool, + pub node_type: Option, + pub nested: bool, +} + +pub fn run(args: InferArgs) { + if let Err(msg) = validate(&args) { + eprintln!("error: {}", msg); + std::process::exit(1); + } + + let query_source = load_query(&args); + + let query = Query::try_from(query_source.as_str()).unwrap_or_else(|e| { + eprintln!("error: {}", e); + std::process::exit(1); + }); + + if !query.is_valid() { + eprint!( + "{}", + query + .diagnostics() + .render_colored(&query_source, args.color) + ); + std::process::exit(1); + } + + let output = emit_types(&query, &args); + println!("{}", output); + + if query.diagnostics().has_warnings() { + eprint!( + "{}", + query + .diagnostics() + .render_colored(&query_source, args.color) + ); + } +} + +fn validate(args: &InferArgs) -> Result<(), &'static str> { + if args.query_text.is_none() && args.query_file.is_none() { + return Err("query input required: -q/--query or --query-file"); + } + + Ok(()) +} + +fn load_query(args: &InferArgs) -> String { + if let Some(ref text) = args.query_text { + return text.clone(); + } + if let Some(ref path) = args.query_file { + if path.as_os_str() == "-" { + let mut buf = String::new(); + io::stdin() + .read_to_string(&mut buf) + .expect("failed to read stdin"); + return buf; + } + return fs::read_to_string(path).expect("failed to read query file"); + } + unreachable!() +} + +fn emit_types(query: &Query<'_>, args: &InferArgs) -> String { + let mut printer = query.type_printer(); + + if let Some(ref name) = args.entry_name { + printer = printer.entry_name(name); + } + + match args.lang { + OutputLang::Rust => emit_rust(printer, args), + OutputLang::Typescript | OutputLang::Ts => emit_typescript(printer, args), + } +} + +fn emit_rust(printer: plotnik_lib::infer::TypePrinter<'_>, args: &InferArgs) -> String { + let mut rust = printer.rust(); + + if let Some(ind) = args.indirection { + let indirection = match ind { + IndirectionChoice::Box => Indirection::Box, + IndirectionChoice::Rc => Indirection::Rc, + IndirectionChoice::Arc => Indirection::Arc, + }; + rust = rust.indirection(indirection); + } + + if args.no_derive { + rust = rust.derive(&[]); + } else if let Some(ref traits) = args.derive { + let trait_refs: Vec<&str> = traits.iter().map(|s| s.as_str()).collect(); + rust = rust.derive(&trait_refs); + } + + rust.render() +} + +fn emit_typescript(printer: plotnik_lib::infer::TypePrinter<'_>, args: &InferArgs) -> String { + let mut ts = printer.typescript(); + + if let Some(opt) = args.optional { + let style = match opt { + OptionalChoice::Null => OptionalStyle::Null, + OptionalChoice::Undefined => OptionalStyle::Undefined, + OptionalChoice::QuestionMark => OptionalStyle::QuestionMark, + }; + ts = ts.optional(style); + } + + ts = ts.export(args.export); + ts = ts.readonly(args.readonly); + ts = ts.type_alias(args.type_alias); + ts = ts.nested(args.nested); + + if let Some(ref name) = args.node_type { + ts = ts.node_type(name); + } + + ts.render() +} diff --git a/crates/plotnik-cli/src/commands/mod.rs b/crates/plotnik-cli/src/commands/mod.rs index 37b04dfb..f33f5594 100644 --- a/crates/plotnik-cli/src/commands/mod.rs +++ b/crates/plotnik-cli/src/commands/mod.rs @@ -1,3 +1,4 @@ pub mod debug; pub mod docs; +pub mod infer; pub mod langs; diff --git a/crates/plotnik-cli/src/main.rs b/crates/plotnik-cli/src/main.rs index b67e3465..41693a23 100644 --- a/crates/plotnik-cli/src/main.rs +++ b/crates/plotnik-cli/src/main.rs @@ -3,6 +3,7 @@ mod commands; use cli::{Cli, Command}; use commands::debug::DebugArgs; +use commands::infer::InferArgs; fn main() { let cli = ::parse(); @@ -28,6 +29,30 @@ fn main() { color: output.color.should_colorize(), }); } + Command::Infer { + query, + lang, + common, + rust, + typescript, + } => { + commands::infer::run(InferArgs { + query_text: query.query_text, + query_file: query.query_file, + lang, + entry_name: common.entry_name, + color: common.color.should_colorize(), + indirection: rust.indirection, + derive: rust.derive, + no_derive: rust.no_derive, + optional: typescript.optional, + export: typescript.export, + readonly: typescript.readonly, + type_alias: typescript.type_alias, + node_type: typescript.node_type, + nested: typescript.nested, + }); + } Command::Docs { topic } => { commands::docs::run(topic.as_deref()); } diff --git a/crates/plotnik-core/src/lib.rs b/crates/plotnik-core/src/lib.rs index 99dd2988..e3ab0a97 100644 --- a/crates/plotnik-core/src/lib.rs +++ b/crates/plotnik-core/src/lib.rs @@ -15,10 +15,6 @@ use std::num::NonZeroU16; mod invariants; -// ============================================================================ -// Deserialization Layer -// ============================================================================ - /// Raw node definition from `node-types.json`. #[derive(Debug, Clone, serde::Deserialize)] pub struct RawNode { @@ -56,10 +52,6 @@ pub fn parse_node_types(json: &str) -> Result, serde_json::Error> { serde_json::from_str(json) } -// ============================================================================ -// Common Types -// ============================================================================ - /// Node type ID (tree-sitter uses u16). pub type NodeTypeId = u16; @@ -73,10 +65,6 @@ pub struct Cardinality { pub required: bool, } -// ============================================================================ -// NodeTypes Trait -// ============================================================================ - /// Trait for node type constraint lookups. /// /// Provides only what tree-sitter's `Language` API doesn't: @@ -156,10 +144,6 @@ impl NodeTypes for &T { } } -// ============================================================================ -// Static Analysis Layer (zero runtime init) -// ============================================================================ - /// Field info for static storage. #[derive(Debug, Clone, Copy)] pub struct StaticFieldInfo { @@ -325,10 +309,6 @@ impl NodeTypes for StaticNodeTypes { } } -// ============================================================================ -// Dynamic Analysis Layer (runtime construction) -// ============================================================================ - /// Information about a single field on a node type. #[derive(Debug, Clone)] pub struct FieldInfo { diff --git a/crates/plotnik-lib/src/diagnostics/message.rs b/crates/plotnik-lib/src/diagnostics/message.rs index dfad769a..5147b6ac 100644 --- a/crates/plotnik-lib/src/diagnostics/message.rs +++ b/crates/plotnik-lib/src/diagnostics/message.rs @@ -55,11 +55,17 @@ pub enum DiagnosticKind { // Valid syntax, invalid semantics DuplicateDefinition, + DuplicateCaptureInScope, UndefinedReference, MixedAltBranches, RecursionNoEscape, FieldSequenceValue, + // Type inference errors + TypeConflictInMerge, + MergeAltRequiresAnnotation, + IncompatibleTaggedAlternations, + // Link pass - grammar validation UnknownNodeType, UnknownField, @@ -159,11 +165,21 @@ impl DiagnosticKind { // Semantic errors Self::DuplicateDefinition => "name already defined", + Self::DuplicateCaptureInScope => "duplicate capture in same scope", Self::UndefinedReference => "undefined reference", Self::MixedAltBranches => "cannot mix labeled and unlabeled branches", Self::RecursionNoEscape => "infinite recursion detected", Self::FieldSequenceValue => "field must match exactly one node", + // Type inference errors + Self::TypeConflictInMerge => "capture has conflicting types across branches", + Self::MergeAltRequiresAnnotation => { + "merged alternation with captures requires type annotation" + } + Self::IncompatibleTaggedAlternations => { + "tagged alternations with different variants cannot be merged" + } + // Link pass - grammar validation Self::UnknownNodeType => "unknown node type", Self::UnknownField => "unknown field", @@ -189,6 +205,7 @@ impl DiagnosticKind { // Semantic errors with name context Self::DuplicateDefinition => "`{}` is already defined".to_string(), + Self::DuplicateCaptureInScope => "capture `@{}` already used in this scope".to_string(), Self::UndefinedReference => "`{}` is not defined".to_string(), // Link pass errors with context @@ -201,6 +218,14 @@ impl DiagnosticKind { // Recursion with cycle path Self::RecursionNoEscape => "infinite recursion: {}".to_string(), + // Type inference + Self::TypeConflictInMerge => { + "capture `{}` has conflicting types across branches".to_string() + } + Self::MergeAltRequiresAnnotation => { + "merged alternation requires `:: {}` type annotation".to_string() + } + // Alternation mixing Self::MixedAltBranches => "cannot mix labeled and unlabeled branches: {}".to_string(), diff --git a/crates/plotnik-lib/src/diagnostics/tests.rs b/crates/plotnik-lib/src/diagnostics/tests.rs index 0f921aeb..92f30d7b 100644 --- a/crates/plotnik-lib/src/diagnostics/tests.rs +++ b/crates/plotnik-lib/src/diagnostics/tests.rs @@ -340,8 +340,6 @@ fn diagnostic_kind_message_rendering() { ); } -// === Filtering/suppression tests === - #[test] fn filtered_no_suppression_disjoint_spans() { let mut diagnostics = Diagnostics::new(); diff --git a/crates/plotnik-lib/src/infer/emit/rust.rs b/crates/plotnik-lib/src/infer/emit/rust.rs index b3680273..a91e12e3 100644 --- a/crates/plotnik-lib/src/infer/emit/rust.rs +++ b/crates/plotnik-lib/src/infer/emit/rust.rs @@ -4,24 +4,47 @@ use indexmap::IndexMap; +/// Rust keywords that must be escaped with `r#` prefix when used as identifiers. +const RUST_KEYWORDS: &[&str] = &[ + // Strict keywords + "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum", "extern", + "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", + "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true", "type", + "unsafe", "use", "where", "while", // Reserved keywords + "abstract", "become", "box", "do", "final", "macro", "override", "priv", "try", "typeof", + "unsized", "virtual", "yield", +]; + +/// Escape a name if it's a Rust keyword by prefixing with `r#`. +fn escape_keyword(name: &str) -> String { + if RUST_KEYWORDS.contains(&name) { + format!("r#{}", name) + } else { + name.to_string() + } +} + use super::super::types::{TypeKey, TypeTable, TypeValue}; /// Configuration for Rust emission. #[derive(Debug, Clone)] pub struct RustEmitConfig { + /// Name for the entry point type (default: "QueryResult"). + pub entry_name: String, /// Indirection type for cyclic references. pub indirection: Indirection, - /// Whether to derive common traits. + /// Whether to derive Debug. pub derive_debug: bool, + /// Whether to derive Clone. pub derive_clone: bool, + /// Whether to derive PartialEq. pub derive_partial_eq: bool, - /// Name for the default (unnamed) query entry point type. - pub default_query_name: String, } /// How to handle cyclic type references. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum Indirection { + #[default] Box, Rc, Arc, @@ -30,11 +53,11 @@ pub enum Indirection { impl Default for RustEmitConfig { fn default() -> Self { Self { + entry_name: "QueryResult".to_string(), indirection: Indirection::Box, derive_debug: true, derive_clone: true, derive_partial_eq: false, - default_query_name: "QueryResult".to_string(), } } } @@ -70,10 +93,7 @@ fn emit_type_def( table: &TypeTable<'_>, config: &RustEmitConfig, ) -> String { - let name = match key { - TypeKey::DefaultQuery => config.default_query_name.clone(), - _ => key.to_pascal_case(), - }; + let name = key.to_pascal_case_with_entry_name(&config.entry_name); match value { TypeValue::Node | TypeValue::String | TypeValue::Unit | TypeValue::Invalid => String::new(), @@ -86,7 +106,8 @@ fn emit_type_def( out.push_str(&format!("pub struct {} {{\n", name)); for (field_name, field_type) in fields { let type_str = emit_type_ref(field_type, table, config); - out.push_str(&format!(" pub {}: {},\n", field_name, type_str)); + let escaped_name = escape_keyword(field_name); + out.push_str(&format!(" pub {}: {},\n", escaped_name, type_str)); } out.push('}'); } @@ -107,7 +128,8 @@ fn emit_type_def( out.push_str(&format!(" {} {{\n", variant_name)); for (field_name, field_type) in f { let type_str = emit_type_ref(field_type, table, config); - out.push_str(&format!(" {}: {},\n", field_name, type_str)); + let escaped_name = escape_keyword(field_name); + out.push_str(&format!(" {}: {},\n", escaped_name, type_str)); } out.push_str(" },\n"); } @@ -154,10 +176,9 @@ pub(crate) fn emit_type_ref( format!("Vec<{}>", inner_str) } // Struct, TaggedUnion, or undefined forward reference - use pascal-cased name - Some(TypeValue::Struct(_)) | Some(TypeValue::TaggedUnion(_)) | None => match key { - TypeKey::DefaultQuery => config.default_query_name.clone(), - _ => key.to_pascal_case(), - }, + Some(TypeValue::Struct(_)) | Some(TypeValue::TaggedUnion(_)) | None => { + key.to_pascal_case_with_entry_name(&config.entry_name) + } }; if is_cyclic { diff --git a/crates/plotnik-lib/src/infer/emit/rust_tests.rs b/crates/plotnik-lib/src/infer/emit/rust_tests.rs index 932d64a1..9d5de72a 100644 --- a/crates/plotnik-lib/src/infer/emit/rust_tests.rs +++ b/crates/plotnik-lib/src/infer/emit/rust_tests.rs @@ -20,8 +20,6 @@ fn emit_cyclic(input: &str, cyclic_types: &[&str]) -> String { emit_rust(&table, &RustEmitConfig::default()) } -// --- Simple Structs --- - #[test] fn emit_struct_single_field() { let input = "Foo = { #Node @value }"; @@ -86,8 +84,6 @@ fn emit_struct_nested_refs() { "); } -// --- Tagged Unions --- - #[test] fn emit_tagged_union_simple() { let input = indoc! {r#" @@ -168,8 +164,6 @@ fn emit_tagged_union_with_builtins() { "); } -// --- Wrapper Types --- - #[test] fn emit_optional() { let input = "MaybeNode = #Node?"; @@ -239,8 +233,6 @@ fn emit_nested_wrappers() { "); } -// --- Cyclic Types --- - #[test] fn emit_cyclic_box() { let input = indoc! {r#" @@ -292,8 +284,6 @@ fn emit_cyclic_arc() { "); } -// --- Config Variations --- - #[test] fn emit_no_derives() { let input = "Foo = { #Node @value }"; @@ -344,8 +334,6 @@ fn emit_all_derives() { "); } -// --- Complex Scenarios --- - #[test] fn emit_complex_program() { let input = indoc! {r#" @@ -455,8 +443,6 @@ fn emit_mixed_wrappers_and_structs() { "); } -// --- Edge Cases --- - #[test] fn emit_single_variant_union() { let input = indoc! {r#" @@ -538,8 +524,6 @@ fn emit_builtin_value_with_named_key() { insta::assert_snapshot!(emit(input), @""); } -// --- DefaultQuery --- - #[test] fn emit_default_query_struct() { let input = "#DefaultQuery = { #Node @value }"; @@ -556,7 +540,7 @@ fn emit_default_query_struct() { fn emit_default_query_custom_name() { let input = "#DefaultQuery = { #Node @value }"; let config = RustEmitConfig { - default_query_name: "MyResult".to_string(), + entry_name: "MyResult".to_string(), ..Default::default() }; @@ -590,3 +574,46 @@ fn emit_default_query_referenced() { } "); } + +#[test] +fn emit_struct_with_keyword_fields() { + let input = "Foo = { #Node @type #Node @fn #Node @match }"; + insta::assert_snapshot!(emit(input), @r" + #[derive(Debug, Clone)] + pub struct Foo { + pub r#type: Node, + pub r#fn: Node, + pub r#match: Node, + } + "); +} + +#[test] +fn emit_keyword_field_in_enum() { + let input = indoc! {r#" + TypeVariant = { #Node @type } + FnVariant = { #Node @fn } + E = [ Type: TypeVariant Fn: FnVariant ] + "#}; + insta::assert_snapshot!(emit(input), @r" + #[derive(Debug, Clone)] + pub struct TypeVariant { + pub r#type: Node, + } + + #[derive(Debug, Clone)] + pub struct FnVariant { + pub r#fn: Node, + } + + #[derive(Debug, Clone)] + pub enum E { + Type { + r#type: Node, + }, + Fn { + r#fn: Node, + }, + } + "); +} diff --git a/crates/plotnik-lib/src/infer/emit/typescript.rs b/crates/plotnik-lib/src/infer/emit/typescript.rs index 72621fd1..20804d40 100644 --- a/crates/plotnik-lib/src/infer/emit/typescript.rs +++ b/crates/plotnik-lib/src/infer/emit/typescript.rs @@ -10,19 +10,19 @@ use super::super::types::{TypeKey, TypeTable, TypeValue}; #[derive(Debug, Clone)] pub struct TypeScriptEmitConfig { /// How to represent optional values. - pub optional_style: OptionalStyle, + pub optional: OptionalStyle, /// Whether to export types. pub export: bool, /// Whether to make fields readonly. pub readonly: bool, - /// Whether to inline synthetic types. - pub inline_synthetic: bool, + /// Whether to emit nested synthetic types instead of inlining them. + pub nested: bool, /// Name for the Node type. - pub node_type_name: String, + pub node_type: String, /// Whether to emit `type Foo = ...` instead of `interface Foo { ... }`. - pub use_type_alias: bool, + pub type_alias: bool, /// Name for the default (unnamed) query entry point. - pub default_query_name: String, + pub entry_name: String, } /// How to represent optional types. @@ -39,13 +39,13 @@ pub enum OptionalStyle { impl Default for TypeScriptEmitConfig { fn default() -> Self { Self { - optional_style: OptionalStyle::Null, + optional: OptionalStyle::Null, export: false, readonly: false, - inline_synthetic: true, - node_type_name: "SyntaxNode".to_string(), - use_type_alias: false, - default_query_name: "QueryResult".to_string(), + nested: false, + node_type: "SyntaxNode".to_string(), + type_alias: false, + entry_name: "QueryResult".to_string(), } } } @@ -65,8 +65,8 @@ pub fn emit_typescript(table: &TypeTable<'_>, config: &TypeScriptEmitConfig) -> continue; } - // Skip synthetic types if inlining - if config.inline_synthetic && matches!(key, TypeKey::Synthetic(_)) { + // Skip synthetic types if not nested (i.e., inlining) + if !config.nested && matches!(key, TypeKey::Synthetic { .. }) { continue; } @@ -87,7 +87,7 @@ fn emit_type_def( config: &TypeScriptEmitConfig, ) -> String { let name = type_name(key, config); - let export_prefix = if config.export && !matches!(key, TypeKey::Synthetic(_)) { + let export_prefix = if config.export && !matches!(key, TypeKey::Synthetic { .. }) { "export " } else { "" @@ -97,7 +97,7 @@ fn emit_type_def( TypeValue::Node | TypeValue::String | TypeValue::Unit | TypeValue::Invalid => String::new(), TypeValue::Struct(fields) => { - if config.use_type_alias { + if config.type_alias { let inline = emit_inline_struct(fields, table, config); format!("{}type {} = {};", export_prefix, name, inline) } else if fields.is_empty() { @@ -107,12 +107,12 @@ fn emit_type_def( for (field_name, field_type) in fields { let (type_str, is_optional) = emit_field_type(field_type, table, config); let readonly = if config.readonly { "readonly " } else { "" }; - let optional = - if is_optional && config.optional_style == OptionalStyle::QuestionMark { - "?" - } else { - "" - }; + let optional = if is_optional && config.optional == OptionalStyle::QuestionMark + { + "?" + } else { + "" + }; out.push_str(&format!( " {}{}{}: {};\n", readonly, field_name, optional, type_str @@ -134,13 +134,12 @@ fn emit_type_def( if let Some(TypeValue::Struct(fields)) = table.get(variant_key) { for (field_name, field_type) in fields { let (type_str, is_optional) = emit_field_type(field_type, table, config); - let optional = if is_optional - && config.optional_style == OptionalStyle::QuestionMark - { - "?" - } else { - "" - }; + let optional = + if is_optional && config.optional == OptionalStyle::QuestionMark { + "?" + } else { + "" + }; out.push_str(&format!("; {}{}: {}", field_name, optional, type_str)); } } @@ -167,13 +166,13 @@ pub(crate) fn emit_field_type( config: &TypeScriptEmitConfig, ) -> (String, bool) { match table.get(key) { - Some(TypeValue::Node) => (config.node_type_name.clone(), false), + Some(TypeValue::Node) => (config.node_type.clone(), false), Some(TypeValue::String) => ("string".to_string(), false), Some(TypeValue::Unit) | Some(TypeValue::Invalid) => ("{}".to_string(), false), Some(TypeValue::Optional(inner)) => { let (inner_str, _) = emit_field_type(inner, table, config); - let type_str = match config.optional_style { + let type_str = match config.optional { OptionalStyle::Null => format!("{} | null", inner_str), OptionalStyle::Undefined => format!("{} | undefined", inner_str), OptionalStyle::QuestionMark => inner_str, @@ -192,7 +191,7 @@ pub(crate) fn emit_field_type( } Some(TypeValue::Struct(fields)) => { - if config.inline_synthetic && matches!(key, TypeKey::Synthetic(_)) { + if !config.nested && matches!(key, TypeKey::Synthetic { .. }) { (emit_inline_struct(fields, table, config), false) } else { (type_name(key, config), false) @@ -217,7 +216,7 @@ pub(crate) fn emit_inline_struct( let mut out = String::from("{ "); for (i, (field_name, field_type)) in fields.iter().enumerate() { let (type_str, is_optional) = emit_field_type(field_type, table, config); - let optional = if is_optional && config.optional_style == OptionalStyle::QuestionMark { + let optional = if is_optional && config.optional == OptionalStyle::QuestionMark { "?" } else { "" @@ -235,11 +234,7 @@ pub(crate) fn emit_inline_struct( } fn type_name(key: &TypeKey<'_>, config: &TypeScriptEmitConfig) -> String { - if key.is_default_query() { - config.default_query_name.clone() - } else { - key.to_pascal_case() - } + key.to_pascal_case_with_entry_name(&config.entry_name) } pub(crate) fn wrap_if_union(type_str: &str) -> String { diff --git a/crates/plotnik-lib/src/infer/emit/typescript_tests.rs b/crates/plotnik-lib/src/infer/emit/typescript_tests.rs index 5aae21dc..4efdc466 100644 --- a/crates/plotnik-lib/src/infer/emit/typescript_tests.rs +++ b/crates/plotnik-lib/src/infer/emit/typescript_tests.rs @@ -12,8 +12,6 @@ fn emit_with_config(input: &str, config: &TypeScriptEmitConfig) -> String { emit_typescript(&table, config) } -// --- Simple Structs (Interfaces) --- - #[test] fn emit_interface_single_field() { let input = "Foo = { #Node @value }"; @@ -70,8 +68,6 @@ fn emit_interface_nested_refs() { "); } -// --- Tagged Unions --- - #[test] fn emit_tagged_union_simple() { let input = indoc! {r#" @@ -134,8 +130,6 @@ fn emit_tagged_union_with_builtins() { "#); } -// --- Wrapper Types --- - #[test] fn emit_optional_null() { let input = "MaybeNode = #Node?"; @@ -146,7 +140,7 @@ fn emit_optional_null() { fn emit_optional_undefined() { let input = "MaybeNode = #Node?"; let config = TypeScriptEmitConfig { - optional_style: OptionalStyle::Undefined, + optional: OptionalStyle::Undefined, ..Default::default() }; insta::assert_snapshot!(emit_with_config(input, &config), @"type MaybeNode = SyntaxNode | undefined;"); @@ -159,7 +153,7 @@ fn emit_optional_question_mark() { Foo = { MaybeNode @maybe } "#}; let config = TypeScriptEmitConfig { - optional_style: OptionalStyle::QuestionMark, + optional: OptionalStyle::QuestionMark, ..Default::default() }; insta::assert_snapshot!(emit_with_config(input, &config), @r" @@ -249,8 +243,6 @@ fn emit_list_of_optionals() { "); } -// --- Config Variations --- - #[test] fn emit_with_export() { let input = "Foo = { #Node @value }"; @@ -284,7 +276,7 @@ fn emit_readonly_fields() { fn emit_custom_node_type() { let input = "Foo = { #Node @value }"; let config = TypeScriptEmitConfig { - node_type_name: "TSNode".to_string(), + node_type: "TSNode".to_string(), ..Default::default() }; insta::assert_snapshot!(emit_with_config(input, &config), @r" @@ -298,7 +290,7 @@ fn emit_custom_node_type() { fn emit_type_alias_instead_of_interface() { let input = "Foo = { #Node @value #string @name }"; let config = TypeScriptEmitConfig { - use_type_alias: true, + type_alias: true, ..Default::default() }; insta::assert_snapshot!(emit_with_config(input, &config), @"type Foo = { value: SyntaxNode; name: string };"); @@ -308,7 +300,7 @@ fn emit_type_alias_instead_of_interface() { fn emit_type_alias_empty() { let input = "Empty = {}"; let config = TypeScriptEmitConfig { - use_type_alias: true, + type_alias: true, ..Default::default() }; insta::assert_snapshot!(emit_with_config(input, &config), @"type Empty = {};"); @@ -321,7 +313,7 @@ fn emit_type_alias_nested() { Outer = { Inner @inner #string @label } "#}; let config = TypeScriptEmitConfig { - use_type_alias: true, + type_alias: true, ..Default::default() }; insta::assert_snapshot!(emit_with_config(input, &config), @r" @@ -337,7 +329,7 @@ fn emit_no_inline_synthetic() { Container = { @inner } "#}; let config = TypeScriptEmitConfig { - inline_synthetic: false, + nested: true, ..Default::default() }; insta::assert_snapshot!(emit_with_config(input, &config), @r" @@ -359,8 +351,6 @@ fn emit_inline_synthetic() { "); } -// --- Complex Scenarios --- - #[test] fn emit_complex_program() { let input = indoc! {r#" @@ -441,13 +431,13 @@ fn emit_all_config_options() { Items = Item* "#}; let config = TypeScriptEmitConfig { - optional_style: OptionalStyle::QuestionMark, + optional: OptionalStyle::QuestionMark, export: true, readonly: true, - inline_synthetic: true, - node_type_name: "ASTNode".to_string(), - use_type_alias: false, - default_query_name: "QueryResult".to_string(), + nested: false, + node_type: "ASTNode".to_string(), + type_alias: false, + entry_name: "QueryResult".to_string(), }; insta::assert_snapshot!(emit_with_config(input, &config), @r" export type MaybeNode = ASTNode; @@ -461,8 +451,6 @@ fn emit_all_config_options() { "); } -// --- Edge Cases --- - #[test] fn emit_single_variant_union() { let input = indoc! {r#" @@ -554,7 +542,7 @@ fn emit_optional_in_struct_undefined_style() { Container = { MaybeNode @item #string @name } "#}; let config = TypeScriptEmitConfig { - optional_style: OptionalStyle::Undefined, + optional: OptionalStyle::Undefined, ..Default::default() }; insta::assert_snapshot!(emit_with_config(input, &config), @r" @@ -576,7 +564,7 @@ fn emit_tagged_union_with_optional_field_question_mark() { Choice = [ A: VariantA B: VariantB ] "#}; let config = TypeScriptEmitConfig { - optional_style: OptionalStyle::QuestionMark, + optional: OptionalStyle::QuestionMark, ..Default::default() }; insta::assert_snapshot!(emit_with_config(input, &config), @r#" @@ -645,7 +633,7 @@ fn emit_struct_with_forward_ref() { fn emit_synthetic_type_no_inline() { let input = " = { #Node @value }"; let config = TypeScriptEmitConfig { - inline_synthetic: false, + nested: true, ..Default::default() }; insta::assert_snapshot!(emit_with_config(input, &config), @r" @@ -659,7 +647,7 @@ fn emit_synthetic_type_no_inline() { fn emit_synthetic_type_with_inline() { let input = " = { #Node @value }"; let config = TypeScriptEmitConfig { - inline_synthetic: true, + nested: false, ..Default::default() }; insta::assert_snapshot!(emit_with_config(input, &config), @""); @@ -706,7 +694,7 @@ fn emit_field_referencing_unknown_type() { fn emit_empty_interface_no_type_alias() { let input = "Empty = {}"; let config = TypeScriptEmitConfig { - use_type_alias: false, + type_alias: false, ..Default::default() }; insta::assert_snapshot!(emit_with_config(input, &config), @"interface Empty {}"); @@ -720,8 +708,8 @@ fn emit_inline_synthetic_struct_with_optional_field() { Container = { @inner } "#}; let config = TypeScriptEmitConfig { - inline_synthetic: true, - optional_style: OptionalStyle::QuestionMark, + nested: false, + optional: OptionalStyle::QuestionMark, ..Default::default() }; insta::assert_snapshot!(emit_with_config(input, &config), @r" @@ -743,8 +731,6 @@ fn emit_builtin_value_with_named_key() { insta::assert_snapshot!(emit(input), @""); } -// --- DefaultQuery --- - #[test] fn emit_default_query_interface() { let input = "#DefaultQuery = { #Node @value }"; @@ -760,7 +746,7 @@ fn emit_default_query_interface() { fn emit_default_query_custom_name() { let input = "#DefaultQuery = { #Node @value }"; let config = TypeScriptEmitConfig { - default_query_name: "MyResult".to_string(), + entry_name: "MyResult".to_string(), ..Default::default() }; diff --git a/crates/plotnik-lib/src/infer/mod.rs b/crates/plotnik-lib/src/infer/mod.rs index 46471372..32da41d3 100644 --- a/crates/plotnik-lib/src/infer/mod.rs +++ b/crates/plotnik-lib/src/infer/mod.rs @@ -3,10 +3,10 @@ //! This module provides: //! - `TypeTable`: collection of inferred types //! - `TypeKey` / `TypeValue`: type representation -//! - `emit_rust`: Rust code emitter -//! - `emit_typescript`: TypeScript code emitter +//! - `TypePrinter`: builder for emitting types as code pub mod emit; +mod printer; mod types; pub mod tyton; @@ -15,7 +15,6 @@ mod types_tests; #[cfg(test)] mod tyton_tests; -pub use emit::{ - Indirection, OptionalStyle, RustEmitConfig, TypeScriptEmitConfig, emit_rust, emit_typescript, -}; -pub use types::{TypeKey, TypeTable, TypeValue}; +pub use emit::{Indirection, OptionalStyle, RustEmitConfig, TypeScriptEmitConfig}; +pub use printer::{RustPrinter, TypePrinter, TypeScriptPrinter}; +pub use types::{MergedField, TypeKey, TypeTable, TypeValue}; diff --git a/crates/plotnik-lib/src/infer/printer.rs b/crates/plotnik-lib/src/infer/printer.rs new file mode 100644 index 00000000..4044001c --- /dev/null +++ b/crates/plotnik-lib/src/infer/printer.rs @@ -0,0 +1,151 @@ +//! Builder-pattern printer for emitting inferred types as code. +//! +//! # Example +//! +//! ```ignore +//! let code = query.type_printer() +//! .entry_name("MyQuery") +//! .rust() +//! .derive(&["debug", "clone"]) +//! .render(); +//! ``` + +use super::TypeTable; +use super::emit::{ + Indirection, OptionalStyle, RustEmitConfig, TypeScriptEmitConfig, emit_rust, emit_typescript, +}; + +/// Builder for type emission. Use [`rust()`](Self::rust) or [`typescript()`](Self::typescript) +/// to select the target language. +pub struct TypePrinter<'src> { + table: TypeTable<'src>, + entry_name: String, +} + +impl<'src> TypePrinter<'src> { + /// Create a new type printer from a type table. + pub fn new(table: TypeTable<'src>) -> Self { + Self { + table, + entry_name: "QueryResult".to_string(), + } + } + + /// Set the name for the entry point type (default: "QueryResult"). + pub fn entry_name(mut self, name: impl Into) -> Self { + self.entry_name = name.into(); + self + } + + /// Configure Rust output. + pub fn rust(self) -> RustPrinter<'src> { + let config = RustEmitConfig { + entry_name: self.entry_name, + ..Default::default() + }; + RustPrinter { + table: self.table, + config, + } + } + + /// Configure TypeScript output. + pub fn typescript(self) -> TypeScriptPrinter<'src> { + let config = TypeScriptEmitConfig { + entry_name: self.entry_name, + ..Default::default() + }; + TypeScriptPrinter { + table: self.table, + config, + } + } +} + +/// Builder for Rust code emission. +pub struct RustPrinter<'src> { + table: TypeTable<'src>, + config: RustEmitConfig, +} + +impl<'src> RustPrinter<'src> { + /// Set indirection type for cyclic references (default: Box). + pub fn indirection(mut self, ind: Indirection) -> Self { + self.config.indirection = ind; + self + } + + /// Set derive macros from a list of trait names. + /// + /// Recognized names: "debug", "clone", "partialeq" (case-insensitive). + /// Unrecognized names are ignored. + pub fn derive(mut self, traits: &[&str]) -> Self { + self.config.derive_debug = false; + self.config.derive_clone = false; + self.config.derive_partial_eq = false; + + for t in traits { + match t.to_lowercase().as_str() { + "debug" => self.config.derive_debug = true, + "clone" => self.config.derive_clone = true, + "partialeq" => self.config.derive_partial_eq = true, + _ => {} + } + } + self + } + + /// Render the type definitions as Rust code. + pub fn render(&self) -> String { + emit_rust(&self.table, &self.config) + } +} + +/// Builder for TypeScript code emission. +pub struct TypeScriptPrinter<'src> { + table: TypeTable<'src>, + config: TypeScriptEmitConfig, +} + +impl<'src> TypeScriptPrinter<'src> { + /// Set how optional values are represented (default: Null). + pub fn optional(mut self, style: OptionalStyle) -> Self { + self.config.optional = style; + self + } + + /// Whether to add `export` keyword to types (default: false). + pub fn export(mut self, value: bool) -> Self { + self.config.export = value; + self + } + + /// Whether to make fields readonly (default: false). + pub fn readonly(mut self, value: bool) -> Self { + self.config.readonly = value; + self + } + + /// Whether to emit nested synthetic types instead of inlining (default: false). + pub fn nested(mut self, value: bool) -> Self { + self.config.nested = value; + self + } + + /// Set the name for the Node type (default: "SyntaxNode"). + pub fn node_type(mut self, name: impl Into) -> Self { + self.config.node_type = name.into(); + self + } + + /// Whether to use `type Foo = ...` instead of `interface Foo { ... }` (default: false). + pub fn type_alias(mut self, value: bool) -> Self { + self.config.type_alias = value; + self + } + + /// Render the type definitions as TypeScript code. + pub fn render(&self) -> String { + emit_typescript(&self.table, &self.config) + } +} diff --git a/crates/plotnik-lib/src/infer/types.rs b/crates/plotnik-lib/src/infer/types.rs index 6e9081bd..aaad9c02 100644 --- a/crates/plotnik-lib/src/infer/types.rs +++ b/crates/plotnik-lib/src/infer/types.rs @@ -49,6 +49,7 @@ //! name collisions while keeping names readable. use indexmap::IndexMap; +use rowan::TextRange; /// Identity of a type in the type table. #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -67,21 +68,36 @@ pub enum TypeKey<'src> { DefaultQuery, /// User-provided type name via `:: TypeName` Named(&'src str), - /// Path-based synthetic name: ["Foo", "bar"] → FooBar - Synthetic(Vec<&'src str>), + /// Synthetic type derived from parent + capture name. + /// Parent can be Named, DefaultQuery, or another Synthetic. + /// Emitter resolves parent to name, then appends capture name in PascalCase. + Synthetic { + parent: Box>, + name: &'src str, + }, } impl TypeKey<'_> { /// Render as PascalCase type name. + /// For Synthetic keys with DefaultQuery parent, uses "DefaultQuery" as the parent name. + /// Use `to_pascal_case_with_entry_name` to customize the DefaultQuery name. pub fn to_pascal_case(&self) -> String { + self.to_pascal_case_with_entry_name("DefaultQuery") + } + + /// Render as PascalCase type name, using the given entry_name for DefaultQuery. + pub fn to_pascal_case_with_entry_name(&self, entry_name: &str) -> String { match self { TypeKey::Node => "Node".to_string(), TypeKey::String => "String".to_string(), TypeKey::Unit => "Unit".to_string(), TypeKey::Invalid => "Unit".to_string(), // Invalid emits as Unit - TypeKey::DefaultQuery => "DefaultQuery".to_string(), + TypeKey::DefaultQuery => entry_name.to_string(), TypeKey::Named(name) => (*name).to_string(), - TypeKey::Synthetic(segments) => segments.iter().map(|s| to_pascal(s)).collect(), + TypeKey::Synthetic { parent, name } => { + let parent_name = parent.to_pascal_case_with_entry_name(entry_name); + format!("{}{}", parent_name, to_pascal(name)) + } } } @@ -155,6 +171,8 @@ pub struct TypeTable<'src> { pub types: IndexMap, TypeValue<'src>>, /// Types that contain cyclic references (need Box in Rust). pub cyclic: Vec>, + /// Source spans where each type was first defined. + definition_spans: IndexMap, TextRange>, } impl<'src> TypeTable<'src> { @@ -168,6 +186,7 @@ impl<'src> TypeTable<'src> { Self { types, cyclic: Vec::new(), + definition_spans: IndexMap::new(), } } @@ -177,6 +196,40 @@ impl<'src> TypeTable<'src> { key } + /// Insert a type definition with a source span, detecting conflicts. + /// + /// Returns `Ok(key)` if inserted successfully (no conflict). + /// Returns `Err(existing_span)` if there was an existing incompatible type. + /// + /// On conflict, the existing type is NOT overwritten - caller should use Invalid. + pub fn try_insert( + &mut self, + key: TypeKey<'src>, + value: TypeValue<'src>, + span: TextRange, + ) -> Result, TextRange> { + if let Some(existing) = self.types.get(&key) { + if !self.values_are_compatible(existing, &value) { + let existing_span = self.definition_spans.get(&key).copied().unwrap_or(span); + return Err(existing_span); + } + // Compatible - keep existing, don't overwrite + return Ok(key); + } + self.types.insert(key.clone(), value); + self.definition_spans.insert(key.clone(), span); + Ok(key) + } + + /// Insert without span tracking. Returns true if inserted, false if key existed. + pub fn try_insert_untracked(&mut self, key: TypeKey<'src>, value: TypeValue<'src>) -> bool { + if self.types.contains_key(&key) { + return false; + } + self.types.insert(key, value); + true + } + /// Mark a type as cyclic (requires indirection in Rust). pub fn mark_cyclic(&mut self, key: TypeKey<'src>) { if !self.cyclic.contains(&key) { @@ -194,17 +247,245 @@ impl<'src> TypeTable<'src> { self.types.get(key) } + /// Check if two type keys are structurally compatible. + /// + /// For built-in types, this is simple equality. + /// For synthetic types, we compare the underlying TypeValue structure. + /// Two synthetic keys pointing to different TaggedUnions or Structs are incompatible. + pub fn types_are_compatible(&self, a: &TypeKey<'src>, b: &TypeKey<'src>) -> bool { + if a == b { + return true; + } + + // Invalid is compatible with anything - don't cascade errors + if *a == TypeKey::Invalid || *b == TypeKey::Invalid { + return true; + } + + // Different built-in types are incompatible + if a.is_builtin() || b.is_builtin() { + return false; + } + + // For synthetic/named types, compare the underlying values + let val_a = self.get(a); + let val_b = self.get(b); + + match (val_a, val_b) { + (Some(va), Some(vb)) => self.values_are_compatible(va, vb), + // If either is missing, consider incompatible (shouldn't happen in practice) + _ => false, + } + } + + /// Check if two type values are structurally compatible. + fn values_are_compatible(&self, a: &TypeValue<'src>, b: &TypeValue<'src>) -> bool { + use TypeValue::*; + match (a, b) { + (Node, Node) => true, + (String, String) => true, + (Unit, Unit) => true, + (Invalid, Invalid) => true, + (Optional(ka), Optional(kb)) => self.types_are_compatible(ka, kb), + (List(ka), List(kb)) => self.types_are_compatible(ka, kb), + (NonEmptyList(ka), NonEmptyList(kb)) => self.types_are_compatible(ka, kb), + // List and NonEmptyList are compatible if inner types match - merge to List + (List(ka), NonEmptyList(kb)) | (NonEmptyList(ka), List(kb)) => { + self.types_are_compatible(ka, kb) + } + // Optional and T are compatible - merge to Optional + (Optional(k), other) | (other, Optional(k)) => { + let other_as_key = match other { + Node => TypeKey::Node, + String => TypeKey::String, + _ => return false, + }; + self.types_are_compatible(k, &other_as_key) + } + (Struct(fa), Struct(fb)) => { + // Structs must have exactly the same fields with compatible types + if fa.len() != fb.len() { + return false; + } + for (name, key_a) in fa { + match fb.get(name) { + Some(key_b) => { + if !self.types_are_compatible(key_a, key_b) { + return false; + } + } + None => return false, + } + } + true + } + (TaggedUnion(va), TaggedUnion(vb)) => { + // TaggedUnions must have exactly the same variants + if va.len() != vb.len() { + return false; + } + for (name, key_a) in va { + match vb.get(name) { + Some(key_b) => { + if !self.types_are_compatible(key_a, key_b) { + return false; + } + } + None => return false, + } + } + true + } + // Different type constructors are incompatible + _ => false, + } + } + /// Iterate over all types in insertion order. pub fn iter(&self) -> impl Iterator, &TypeValue<'src>)> { self.types.iter() } + /// Try to merge List and NonEmptyList types into List. + /// + /// Returns `Some(List(inner))` if one is List and other is NonEmptyList with compatible inner types. + /// Returns `None` otherwise. + fn try_merge_list_types( + &mut self, + a: &TypeKey<'src>, + b: &TypeKey<'src>, + ) -> Option> { + let val_a = self.get(a)?; + let val_b = self.get(b)?; + + let inner = match (val_a, val_b) { + (TypeValue::List(ka), TypeValue::NonEmptyList(kb)) + | (TypeValue::NonEmptyList(ka), TypeValue::List(kb)) => { + if self.types_are_compatible(ka, kb) { + ka.clone() + } else { + return None; + } + } + _ => return None, + }; + + // Return or create a List type with the inner type + let list_key = TypeKey::Named(Box::leak("ListMerged".to_string().into_boxed_str())); + self.insert(list_key.clone(), TypeValue::List(inner)); + Some(list_key) + } + + /// Try to merge Optional and T into Optional. + /// + /// Returns `Some(Optional(inner))` if one is Optional and other is the unwrapped type. + /// Returns `None` otherwise. + fn try_merge_optional_types( + &mut self, + a: &TypeKey<'src>, + b: &TypeKey<'src>, + ) -> Option> { + let val_a = self.get(a); + let val_b = self.get(b); + + // Handle cases where one is a wrapper type (Optional) around the other + match (val_a, val_b) { + (Some(TypeValue::Optional(ka)), Some(TypeValue::Optional(kb))) => { + // Both optional - check inner compatibility + if self.types_are_compatible(ka, kb) { + return Some(a.clone()); + } + None + } + (Some(TypeValue::Optional(k)), _) => { + if self.types_are_compatible(k, b) { + return Some(a.clone()); + } + None + } + (_, Some(TypeValue::Optional(k))) => { + if self.types_are_compatible(a, k) { + return Some(b.clone()); + } + None + } + _ => None, + } + } + + /// Try to merge two struct types into one, returning the merged fields. + /// + /// Returns `Some(merged_fields)` if both types are structs (regardless of field shape). + /// Returns `None` if either type is not a struct. + /// + /// The merge rules: + /// - Fields present in both structs with compatible types keep that type + /// - Fields present in only one struct become Optional + /// - Fields with conflicting types become Invalid + fn try_merge_struct_fields( + &self, + a: &TypeKey<'src>, + b: &TypeKey<'src>, + ) -> Option>> { + let val_a = self.get(a)?; + let val_b = self.get(b)?; + + let (fields_a, fields_b) = match (val_a, val_b) { + (TypeValue::Struct(fa), TypeValue::Struct(fb)) => (fa, fb), + _ => return None, + }; + + // Collect all field names from both structs + let mut all_fields: IndexMap<&'src str, ()> = IndexMap::new(); + for name in fields_a.keys() { + all_fields.entry(*name).or_insert(()); + } + for name in fields_b.keys() { + all_fields.entry(*name).or_insert(()); + } + + let mut result = IndexMap::new(); + for field_name in all_fields.keys() { + let type_a = fields_a.get(field_name); + let type_b = fields_b.get(field_name); + + let merged = match (type_a, type_b) { + (Some(ta), Some(tb)) => { + if self.types_are_compatible(ta, tb) { + MergedField::Same(ta.clone()) + } else { + // Recursively try to merge nested structs + if let Some(nested_merged) = self.try_merge_struct_fields(ta, tb) { + if nested_merged + .values() + .any(|m| matches!(m, MergedField::Conflict)) + { + MergedField::Conflict + } else { + // Both are structs - they can be merged (caller handles actual merge) + MergedField::Same(ta.clone()) + } + } else { + MergedField::Conflict + } + } + } + (Some(t), None) | (None, Some(t)) => MergedField::Optional(t.clone()), + (None, None) => continue, + }; + result.insert(*field_name, merged); + } + + Some(result) + } + /// Merge fields from multiple struct branches (for untagged unions). /// /// Given a list of field maps (one per branch), produces a merged field map where: /// - Fields present in all branches with the same type keep that type /// - Fields present in only some branches become Optional /// - Fields with conflicting types across branches become Invalid + /// - Fields that are both structs get recursively merged /// /// # Example /// @@ -213,6 +494,13 @@ impl<'src> TypeTable<'src> { /// /// Merged: `{ name: String, value: Optional, extra: Optional }` /// + /// # Struct Merge Example + /// + /// Branch 1: `{ x: { y: Node } }` + /// Branch 2: `{ x: { z: Node } }` + /// + /// Merged: `{ x: { y: Optional, z: Optional } }` + /// /// # Type Conflict Example /// /// Branch 1: `{ x: String }` @@ -220,6 +508,7 @@ impl<'src> TypeTable<'src> { /// /// Merged: `{ x: Invalid }` (with diagnostic warning) pub fn merge_fields( + &mut self, branches: &[IndexMap<&'src str, TypeKey<'src>>], ) -> IndexMap<&'src str, MergedField<'src>> { if branches.is_empty() { @@ -251,13 +540,83 @@ impl<'src> TypeTable<'src> { continue; } - // Check if all occurrences have the same type + // Check if all occurrences have compatible types (structural comparison) let first_type = type_occurrences[0]; - let all_same_type = type_occurrences.iter().all(|t| *t == first_type); + let all_same_type = type_occurrences + .iter() + .all(|t| self.types_are_compatible(t, first_type)); - let merged = if !all_same_type { - // Type conflict - MergedField::Conflict + // Check for List/NonEmptyList merge case + let list_merge_key = if type_occurrences.len() == 2 { + self.try_merge_list_types(type_occurrences[0], type_occurrences[1]) + } else { + None + }; + + // Check for Optional/Required merge case + let optional_merge_key = if type_occurrences.len() == 2 && list_merge_key.is_none() { + self.try_merge_optional_types(type_occurrences[0], type_occurrences[1]) + } else { + None + }; + + let merged = if let Some(merged_key) = optional_merge_key { + // Optional merge result is already Optional - don't double-wrap + MergedField::Same(merged_key) + } else if let Some(merged_key) = list_merge_key { + // List and NonEmptyList merged to List + if present_count == branch_count { + MergedField::Same(merged_key) + } else { + MergedField::Optional(merged_key) + } + } else if !all_same_type { + // Types differ - try to merge if both are structs + if type_occurrences.len() == 2 { + if let Some(struct_merged) = + self.try_merge_struct_fields(type_occurrences[0], type_occurrences[1]) + { + // Both are structs - create a merged struct type + let merged_fields: IndexMap<&'src str, TypeKey<'src>> = struct_merged + .into_iter() + .map(|(name, mf)| { + let key = match mf { + MergedField::Same(k) => k, + MergedField::Optional(k) => { + let wrapper_name = format!( + "{}{}Opt", + to_pascal(field_name), + to_pascal(name) + ); + let wrapper_key = TypeKey::Named(Box::leak( + wrapper_name.into_boxed_str(), + )); + self.insert(wrapper_key.clone(), TypeValue::Optional(k)); + wrapper_key + } + MergedField::Conflict => TypeKey::Invalid, + }; + (name, key) + }) + .collect(); + + // Create a new merged struct type + let merged_name = format!("{}Merged", to_pascal(field_name)); + let merged_key = TypeKey::Named(Box::leak(merged_name.into_boxed_str())); + self.insert(merged_key.clone(), TypeValue::Struct(merged_fields)); + + if present_count == branch_count { + MergedField::Same(merged_key) + } else { + MergedField::Optional(merged_key) + } + } else { + MergedField::Conflict + } + } else { + // More than 2 branches with different struct types - TODO: support N-way merge + MergedField::Conflict + } } else if present_count == branch_count { // Present in all branches with same type MergedField::Same(first_type.clone()) diff --git a/crates/plotnik-lib/src/infer/types_tests.rs b/crates/plotnik-lib/src/infer/types_tests.rs index 32299deb..f81d046b 100644 --- a/crates/plotnik-lib/src/infer/types_tests.rs +++ b/crates/plotnik-lib/src/infer/types_tests.rs @@ -20,13 +20,23 @@ fn type_key_to_pascal_case_named() { #[test] fn type_key_to_pascal_case_synthetic() { - assert_eq!(TypeKey::Synthetic(vec!["Foo"]).to_pascal_case(), "Foo"); assert_eq!( - TypeKey::Synthetic(vec!["Foo", "bar"]).to_pascal_case(), + TypeKey::Synthetic { + parent: Box::new(TypeKey::Named("Foo")), + name: "bar" + } + .to_pascal_case(), "FooBar" ); assert_eq!( - TypeKey::Synthetic(vec!["Foo", "bar", "baz"]).to_pascal_case(), + TypeKey::Synthetic { + parent: Box::new(TypeKey::Synthetic { + parent: Box::new(TypeKey::Named("Foo")), + name: "bar" + }), + name: "baz" + } + .to_pascal_case(), "FooBarBaz" ); } @@ -34,11 +44,19 @@ fn type_key_to_pascal_case_synthetic() { #[test] fn type_key_to_pascal_case_snake_case_segments() { assert_eq!( - TypeKey::Synthetic(vec!["Foo", "bar_baz"]).to_pascal_case(), + TypeKey::Synthetic { + parent: Box::new(TypeKey::Named("Foo")), + name: "bar_baz" + } + .to_pascal_case(), "FooBarBaz" ); assert_eq!( - TypeKey::Synthetic(vec!["function_info", "params"]).to_pascal_case(), + TypeKey::Synthetic { + parent: Box::new(TypeKey::Named("FunctionInfo")), + name: "params" + } + .to_pascal_case(), "FunctionInfoParams" ); } @@ -127,21 +145,23 @@ fn type_value_tagged_union() { let mut assign_fields = IndexMap::new(); assign_fields.insert("target", TypeKey::String); - table.insert( - TypeKey::Synthetic(vec!["Stmt", "Assign"]), - TypeValue::Struct(assign_fields), - ); + let assign_key = TypeKey::Synthetic { + parent: Box::new(TypeKey::Named("Stmt")), + name: "Assign", + }; + table.insert(assign_key.clone(), TypeValue::Struct(assign_fields)); let mut call_fields = IndexMap::new(); call_fields.insert("func", TypeKey::String); - table.insert( - TypeKey::Synthetic(vec!["Stmt", "Call"]), - TypeValue::Struct(call_fields), - ); + let call_key = TypeKey::Synthetic { + parent: Box::new(TypeKey::Named("Stmt")), + name: "Call", + }; + table.insert(call_key.clone(), TypeValue::Struct(call_fields)); let mut variants = IndexMap::new(); - variants.insert("Assign", TypeKey::Synthetic(vec!["Stmt", "Assign"])); - variants.insert("Call", TypeKey::Synthetic(vec!["Stmt", "Call"])); + variants.insert("Assign", assign_key); + variants.insert("Call", call_key); let union = TypeValue::TaggedUnion(variants); table.insert(TypeKey::Named("Stmt"), union); @@ -201,12 +221,24 @@ fn type_key_equality() { assert_eq!(TypeKey::Named("Foo"), TypeKey::Named("Foo")); assert_ne!(TypeKey::Named("Foo"), TypeKey::Named("Bar")); assert_eq!( - TypeKey::Synthetic(vec!["a", "b"]), - TypeKey::Synthetic(vec!["a", "b"]) + TypeKey::Synthetic { + parent: Box::new(TypeKey::Named("A")), + name: "b" + }, + TypeKey::Synthetic { + parent: Box::new(TypeKey::Named("A")), + name: "b" + } ); assert_ne!( - TypeKey::Synthetic(vec!["a", "b"]), - TypeKey::Synthetic(vec!["a", "c"]) + TypeKey::Synthetic { + parent: Box::new(TypeKey::Named("A")), + name: "b" + }, + TypeKey::Synthetic { + parent: Box::new(TypeKey::Named("A")), + name: "c" + } ); } @@ -216,11 +248,17 @@ fn type_key_hash_consistency() { let mut set = HashSet::new(); set.insert(TypeKey::Node); set.insert(TypeKey::Named("Foo")); - set.insert(TypeKey::Synthetic(vec!["a", "b"])); + set.insert(TypeKey::Synthetic { + parent: Box::new(TypeKey::Named("A")), + name: "b", + }); assert!(set.contains(&TypeKey::Node)); assert!(set.contains(&TypeKey::Named("Foo"))); - assert!(set.contains(&TypeKey::Synthetic(vec!["a", "b"]))); + assert!(set.contains(&TypeKey::Synthetic { + parent: Box::new(TypeKey::Named("A")), + name: "b" + })); assert!(!set.contains(&TypeKey::String)); } @@ -231,7 +269,13 @@ fn type_key_is_builtin() { assert!(TypeKey::Unit.is_builtin()); assert!(TypeKey::Invalid.is_builtin()); assert!(!TypeKey::Named("Foo").is_builtin()); - assert!(!TypeKey::Synthetic(vec!["a"]).is_builtin()); + assert!( + !TypeKey::Synthetic { + parent: Box::new(TypeKey::Named("A")), + name: "b" + } + .is_builtin() + ); } #[test] @@ -242,20 +286,22 @@ fn type_value_invalid() { #[test] fn merge_fields_empty_branches() { + let mut table = TypeTable::new(); let branches: Vec> = vec![]; - let merged = TypeTable::merge_fields(&branches); + let merged = table.merge_fields(&branches); assert!(merged.is_empty()); } #[test] fn merge_fields_single_branch() { + let mut table = TypeTable::new(); let mut branch = IndexMap::new(); branch.insert("name", TypeKey::String); branch.insert("value", TypeKey::Node); - let merged = TypeTable::merge_fields(&[branch]); + let merged = table.merge_fields(&[branch]); assert_eq!(merged.len(), 2); assert_eq!(merged["name"], MergedField::Same(TypeKey::String)); @@ -264,13 +310,14 @@ fn merge_fields_single_branch() { #[test] fn merge_fields_identical_branches() { + let mut table = TypeTable::new(); let mut branch1 = IndexMap::new(); branch1.insert("name", TypeKey::String); let mut branch2 = IndexMap::new(); branch2.insert("name", TypeKey::String); - let merged = TypeTable::merge_fields(&[branch1, branch2]); + let merged = table.merge_fields(&[branch1, branch2]); assert_eq!(merged.len(), 1); assert_eq!(merged["name"], MergedField::Same(TypeKey::String)); @@ -278,6 +325,7 @@ fn merge_fields_identical_branches() { #[test] fn merge_fields_missing_in_some_branches() { + let mut table = TypeTable::new(); let mut branch1 = IndexMap::new(); branch1.insert("name", TypeKey::String); branch1.insert("value", TypeKey::Node); @@ -286,7 +334,7 @@ fn merge_fields_missing_in_some_branches() { branch2.insert("name", TypeKey::String); // value missing - let merged = TypeTable::merge_fields(&[branch1, branch2]); + let merged = table.merge_fields(&[branch1, branch2]); assert_eq!(merged.len(), 2); assert_eq!(merged["name"], MergedField::Same(TypeKey::String)); @@ -295,13 +343,14 @@ fn merge_fields_missing_in_some_branches() { #[test] fn merge_fields_disjoint_branches() { + let mut table = TypeTable::new(); let mut branch1 = IndexMap::new(); branch1.insert("a", TypeKey::String); let mut branch2 = IndexMap::new(); branch2.insert("b", TypeKey::Node); - let merged = TypeTable::merge_fields(&[branch1, branch2]); + let merged = table.merge_fields(&[branch1, branch2]); assert_eq!(merged.len(), 2); assert_eq!(merged["a"], MergedField::Optional(TypeKey::String)); @@ -310,13 +359,14 @@ fn merge_fields_disjoint_branches() { #[test] fn merge_fields_type_conflict() { + let mut table = TypeTable::new(); let mut branch1 = IndexMap::new(); branch1.insert("x", TypeKey::String); let mut branch2 = IndexMap::new(); branch2.insert("x", TypeKey::Node); - let merged = TypeTable::merge_fields(&[branch1, branch2]); + let merged = table.merge_fields(&[branch1, branch2]); assert_eq!(merged.len(), 1); assert_eq!(merged["x"], MergedField::Conflict); @@ -324,6 +374,7 @@ fn merge_fields_type_conflict() { #[test] fn merge_fields_partial_conflict() { + let mut table = TypeTable::new(); // Three branches: x is String in branch 1 and 2, Node in branch 3 let mut branch1 = IndexMap::new(); branch1.insert("x", TypeKey::String); @@ -334,13 +385,14 @@ fn merge_fields_partial_conflict() { let mut branch3 = IndexMap::new(); branch3.insert("x", TypeKey::Node); - let merged = TypeTable::merge_fields(&[branch1, branch2, branch3]); + let merged = table.merge_fields(&[branch1, branch2, branch3]); assert_eq!(merged["x"], MergedField::Conflict); } #[test] fn merge_fields_complex_scenario() { + let mut table = TypeTable::new(); // Branch 1: { name: String, value: Node } // Branch 2: { name: String, extra: Node } // Result: { name: String, value: Optional, extra: Optional } @@ -352,7 +404,7 @@ fn merge_fields_complex_scenario() { branch2.insert("name", TypeKey::String); branch2.insert("extra", TypeKey::Node); - let merged = TypeTable::merge_fields(&[branch1, branch2]); + let merged = table.merge_fields(&[branch1, branch2]); assert_eq!(merged.len(), 3); assert_eq!(merged["name"], MergedField::Same(TypeKey::String)); @@ -362,6 +414,7 @@ fn merge_fields_complex_scenario() { #[test] fn merge_fields_preserves_order() { + let mut table = TypeTable::new(); let mut branch1 = IndexMap::new(); branch1.insert("z", TypeKey::String); branch1.insert("a", TypeKey::String); @@ -369,7 +422,7 @@ fn merge_fields_preserves_order() { let mut branch2 = IndexMap::new(); branch2.insert("m", TypeKey::String); - let merged = TypeTable::merge_fields(&[branch1, branch2]); + let merged = table.merge_fields(&[branch1, branch2]); let keys: Vec<_> = merged.keys().collect(); // Order follows first occurrence across branches diff --git a/crates/plotnik-lib/src/infer/tyton.rs b/crates/plotnik-lib/src/infer/tyton.rs index e3a89364..f026d38a 100644 --- a/crates/plotnik-lib/src/infer/tyton.rs +++ b/crates/plotnik-lib/src/infer/tyton.rs @@ -239,8 +239,34 @@ impl<'src> Parser<'src> { fn parse_synthetic_key(&mut self) -> Result, ParseError> { self.expect(Token::LAngle)?; - let mut segments = Vec::new(); + // Parse parent key first + let parent_span = self.current_span(); + let parent: TypeKey<'src> = match self.peek() { + Some(Token::DefaultQuery) => { + self.advance(); + TypeKey::DefaultQuery + } + Some(Token::UpperIdent(s)) => { + let s = *s; + self.advance(); + TypeKey::Named(s) + } + Some(Token::LAngle) => { + // Nested synthetic key + self.parse_synthetic_key()? + } + _ => { + return Err(ParseError { + message: "expected parent key (uppercase name, #DefaultQuery, or <...>)" + .to_string(), + span: parent_span, + }); + } + }; + + // Parse path segments, building nested Synthetic keys + let mut result = parent; loop { let span = self.current_span(); match self.peek() { @@ -248,33 +274,24 @@ impl<'src> Parser<'src> { self.advance(); break; } - Some(Token::UpperIdent(s)) => { - let s = *s; - self.advance(); - segments.push(s); - } Some(Token::LowerIdent(s)) => { let s = *s; self.advance(); - segments.push(s); + result = TypeKey::Synthetic { + parent: Box::new(result), + name: s, + }; } _ => { return Err(ParseError { - message: "expected identifier or '>'".to_string(), + message: "expected path segment (lowercase) or '>'".to_string(), span, }); } } } - if segments.is_empty() { - return Err(ParseError { - message: "synthetic key cannot be empty".to_string(), - span: self.current_span(), - }); - } - - Ok(TypeKey::Synthetic(segments)) + Ok(result) } fn parse_type_value(&mut self) -> Result, ParseError> { @@ -490,12 +507,20 @@ fn emit_key(out: &mut String, key: &TypeKey<'_>) { TypeKey::Unit => out.push_str("()"), TypeKey::DefaultQuery => out.push_str("#DefaultQuery"), TypeKey::Named(name) => out.push_str(name), - TypeKey::Synthetic(segments) => { + TypeKey::Synthetic { parent, name } => { + // Flatten nested Synthetic keys into + let mut segments = vec![*name]; + let mut current = parent.as_ref(); + while let TypeKey::Synthetic { parent: p, name: n } = current { + segments.push(*n); + current = p.as_ref(); + } + segments.reverse(); + out.push('<'); - for (i, seg) in segments.iter().enumerate() { - if i > 0 { - out.push(' '); - } + emit_key(out, current); + for seg in segments { + out.push(' '); out.push_str(seg); } out.push('>'); diff --git a/crates/plotnik-lib/src/infer/tyton_tests.rs b/crates/plotnik-lib/src/infer/tyton_tests.rs index c948f295..4140204d 100644 --- a/crates/plotnik-lib/src/infer/tyton_tests.rs +++ b/crates/plotnik-lib/src/infer/tyton_tests.rs @@ -176,7 +176,7 @@ fn parse_synthetic_key_simple() { String = String Unit = Unit Invalid = Invalid - Named("Wrapper") = Optional(Synthetic(["Foo", "bar"])) + Named("Wrapper") = Optional(Synthetic { parent: Named("Foo"), name: "bar" }) "#); } @@ -188,7 +188,7 @@ fn parse_synthetic_key_multiple_segments() { String = String Unit = Unit Invalid = Invalid - Named("Wrapper") = List(Synthetic(["Foo", "bar", "baz"])) + Named("Wrapper") = List(Synthetic { parent: Synthetic { parent: Named("Foo"), name: "bar" }, name: "baz" }) "#); } @@ -200,7 +200,7 @@ fn parse_struct_with_synthetic() { String = String Unit = Unit Invalid = Invalid - Named("Container") = Struct({"inner": Synthetic(["Inner", "field"])}) + Named("Container") = Struct({"inner": Synthetic { parent: Named("Inner"), name: "field" }}) "#); } @@ -212,7 +212,7 @@ fn parse_union_with_synthetic() { String = String Unit = Unit Invalid = Invalid - Named("Choice") = TaggedUnion({"First": Synthetic(["Choice", "first"]), "Second": Synthetic(["Choice", "second"])}) + Named("Choice") = TaggedUnion({"First": Synthetic { parent: Named("Choice"), name: "first" }, "Second": Synthetic { parent: Named("Choice"), name: "second" }}) "#); } @@ -327,7 +327,7 @@ fn error_missing_colon_in_union() { #[test] fn error_empty_synthetic() { let input = "Foo = <>?"; - insta::assert_snapshot!(dump_table(input), @"ERROR: synthetic key cannot be empty at 8..9"); + insta::assert_snapshot!(dump_table(input), @"ERROR: expected parent key (uppercase name, #DefaultQuery, or <...>) at 7..8"); } #[test] @@ -410,7 +410,7 @@ fn parse_synthetic_definition_struct() { String = String Unit = Unit Invalid = Invalid - Synthetic(["Foo", "bar"]) = Struct({"value": Node}) + Synthetic { parent: Named("Foo"), name: "bar" } = Struct({"value": Node}) "#); } @@ -422,7 +422,7 @@ fn parse_synthetic_definition_union() { String = String Unit = Unit Invalid = Invalid - Synthetic(["Choice", "first"]) = TaggedUnion({"A": Node, "B": String}) + Synthetic { parent: Named("Choice"), name: "first" } = TaggedUnion({"A": Node, "B": String}) "#); } @@ -434,7 +434,7 @@ fn parse_synthetic_definition_wrapper() { String = String Unit = Unit Invalid = Invalid - Synthetic(["Inner", "nested"]) = Optional(Node) + Synthetic { parent: Named("Inner"), name: "nested" } = Optional(Node) "#); } @@ -459,7 +459,7 @@ fn error_eof_expecting_colon() { #[test] fn error_invalid_token_in_synthetic() { let input = "Foo = ?"; - insta::assert_snapshot!(dump_table(input), @"ERROR: expected identifier or '>' at 9..10"); + insta::assert_snapshot!(dump_table(input), @"ERROR: expected path segment (lowercase) or '>' at 9..10"); } #[test] @@ -486,8 +486,6 @@ fn error_unprefixed_string() { insta::assert_snapshot!(dump_table(input), @"ERROR: expected type value at 6..12"); } -// === emit tests === - #[test] fn emit_empty() { let table = parse("").unwrap(); diff --git a/crates/plotnik-lib/src/parser/ast.rs b/crates/plotnik-lib/src/parser/ast.rs index 420aa78b..2bd9b11b 100644 --- a/crates/plotnik-lib/src/parser/ast.rs +++ b/crates/plotnik-lib/src/parser/ast.rs @@ -330,6 +330,20 @@ impl QuantifiedExpr { }) .unwrap_or(false) } + + /// Returns true if quantifier is a list (*, *?). + pub fn is_list(&self) -> bool { + self.operator() + .map(|op| matches!(op.kind(), SyntaxKind::Star | SyntaxKind::StarQuestion)) + .unwrap_or(false) + } + + /// Returns true if quantifier is a non-empty list (+, +?). + pub fn is_non_empty_list(&self) -> bool { + self.operator() + .map(|op| matches!(op.kind(), SyntaxKind::Plus | SyntaxKind::PlusQuestion)) + .unwrap_or(false) + } } impl FieldExpr { diff --git a/crates/plotnik-lib/src/parser/core.rs b/crates/plotnik-lib/src/parser/core.rs index a451f03f..4691c2b6 100644 --- a/crates/plotnik-lib/src/parser/core.rs +++ b/crates/plotnik-lib/src/parser/core.rs @@ -220,6 +220,7 @@ impl<'src> Parser<'src> { pub(super) fn bump(&mut self) { assert!(!self.eof(), "bump called at EOF"); + self.drain_trivia(); self.reset_debug_fuel(); self.consume_exec_fuel(); diff --git a/crates/plotnik-lib/src/parser/cst.rs b/crates/plotnik-lib/src/parser/cst.rs index e82020ee..fbe36765 100644 --- a/crates/plotnik-lib/src/parser/cst.rs +++ b/crates/plotnik-lib/src/parser/cst.rs @@ -136,7 +136,6 @@ pub enum SyntaxKind { Garbage, Error, - // --- Node kinds (non-terminals) --- Root, Tree, Ref, diff --git a/crates/plotnik-lib/src/parser/lexer_tests.rs b/crates/plotnik-lib/src/parser/lexer_tests.rs index 9fd47c3c..9bd26b1e 100644 --- a/crates/plotnik-lib/src/parser/lexer_tests.rs +++ b/crates/plotnik-lib/src/parser/lexer_tests.rs @@ -5,6 +5,28 @@ fn snapshot(input: &str) -> String { format_tokens(input, false) } +/// Format tokens with spans for debugging +#[allow(dead_code)] +fn snapshot_with_spans(input: &str) -> String { + let tokens = lex(input); + let mut out = String::new(); + for token in tokens { + if !token.kind.is_trivia() { + let start: usize = token.span.start().into(); + let end: usize = token.span.end().into(); + out.push_str(&format!( + "{:?} {:?} @ {}..{} (source: {:?})\n", + token.kind, + token_text(input, &token), + start, + end, + &input[start..end] + )); + } + } + out +} + /// Format tokens with trivia included fn snapshot_raw(input: &str) -> String { format_tokens(input, true) @@ -161,6 +183,14 @@ fn capture_simple() { "#); } +#[test] +fn capture_spans_debug() { + let input = "(identifier) @name :: string"; + eprintln!("Input: {:?}", input); + eprintln!("Tokens with spans:"); + eprintln!("{}", snapshot_with_spans(input)); +} + #[test] fn capture_with_underscores() { insta::assert_snapshot!(snapshot("@my_capture_name"), @r#" diff --git a/crates/plotnik-lib/src/parser/tests/grammar/trivia_tests.rs b/crates/plotnik-lib/src/parser/tests/grammar/trivia_tests.rs index 886def5f..d5916971 100644 --- a/crates/plotnik-lib/src/parser/tests/grammar/trivia_tests.rs +++ b/crates/plotnik-lib/src/parser/tests/grammar/trivia_tests.rs @@ -17,9 +17,9 @@ fn whitespace_preserved() { ParenOpen "(" Id "identifier" ParenClose ")" + Whitespace " " At "@" Id "name" - Whitespace " " Newline "\n" "#); } @@ -127,9 +127,9 @@ fn trivia_between_alternation_items() { ParenOpen "(" Id "b" ParenClose ")" + Newline "\n" BracketClose "]" Newline "\n" - Newline "\n" "#); } diff --git a/crates/plotnik-lib/src/query/dump.rs b/crates/plotnik-lib/src/query/dump.rs index 9f2f7219..baae64e2 100644 --- a/crates/plotnik-lib/src/query/dump.rs +++ b/crates/plotnik-lib/src/query/dump.rs @@ -2,7 +2,7 @@ #[cfg(test)] mod test_helpers { - use crate::Query; + use crate::{Query, infer::OptionalStyle}; impl Query<'_> { pub fn dump_cst(&self) -> String { @@ -36,5 +36,14 @@ mod test_helpers { pub fn dump_diagnostics_raw(&self) -> String { self.diagnostics_raw().render(self.source) } + + pub fn dump_types(&self) -> String { + self.type_printer() + .typescript() + .optional(OptionalStyle::QuestionMark) + .type_alias(true) + .nested(true) + .render() + } } } diff --git a/crates/plotnik-lib/src/query/mod.rs b/crates/plotnik-lib/src/query/mod.rs index 203b4197..084ad11c 100644 --- a/crates/plotnik-lib/src/query/mod.rs +++ b/crates/plotnik-lib/src/query/mod.rs @@ -7,8 +7,11 @@ mod dump; mod invariants; mod printer; +mod types; pub use printer::QueryPrinter; +use crate::infer::TypePrinter; + pub mod alt_kinds; #[cfg(feature = "plotnik-langs")] pub mod link; @@ -30,6 +33,8 @@ mod recursion_tests; mod shapes_tests; #[cfg(test)] mod symbol_table_tests; +#[cfg(test)] +mod types_tests; use std::collections::HashMap; @@ -40,6 +45,7 @@ use rowan::GreenNodeBuilder; use crate::Result; use crate::diagnostics::Diagnostics; +use crate::infer::TypeTable; use crate::parser::cst::SyntaxKind; use crate::parser::lexer::lex; use crate::parser::{ParseResult, Parser, Root, SyntaxNode, ast}; @@ -63,6 +69,7 @@ pub struct Query<'a> { ast: Root, symbol_table: SymbolTable<'a>, shape_cardinality_table: HashMap, + type_table: TypeTable<'a>, #[cfg(feature = "plotnik-langs")] node_type_ids: HashMap<&'a str, Option>, #[cfg(feature = "plotnik-langs")] @@ -75,6 +82,7 @@ pub struct Query<'a> { resolve_diagnostics: Diagnostics, recursion_diagnostics: Diagnostics, shapes_diagnostics: Diagnostics, + type_diagnostics: Diagnostics, #[cfg(feature = "plotnik-langs")] link_diagnostics: Diagnostics, } @@ -97,6 +105,7 @@ impl<'a> Query<'a> { ast: empty_root(), symbol_table: SymbolTable::default(), shape_cardinality_table: HashMap::new(), + type_table: TypeTable::new(), #[cfg(feature = "plotnik-langs")] node_type_ids: HashMap::new(), #[cfg(feature = "plotnik-langs")] @@ -109,6 +118,7 @@ impl<'a> Query<'a> { resolve_diagnostics: Diagnostics::new(), recursion_diagnostics: Diagnostics::new(), shapes_diagnostics: Diagnostics::new(), + type_diagnostics: Diagnostics::new(), #[cfg(feature = "plotnik-langs")] link_diagnostics: Diagnostics::new(), } @@ -142,6 +152,7 @@ impl<'a> Query<'a> { self.resolve_names(); self.validate_recursion(); self.infer_shapes(); + self.infer_types(); Ok(self) } @@ -218,6 +229,7 @@ impl<'a> Query<'a> { all.extend(self.resolve_diagnostics.clone()); all.extend(self.recursion_diagnostics.clone()); all.extend(self.shapes_diagnostics.clone()); + all.extend(self.type_diagnostics.clone()); #[cfg(feature = "plotnik-langs")] all.extend(self.link_diagnostics.clone()); all @@ -239,6 +251,7 @@ impl<'a> Query<'a> { && !self.resolve_diagnostics.has_errors() && !self.recursion_diagnostics.has_errors() && !self.shapes_diagnostics.has_errors() + && !self.type_diagnostics.has_errors() && !self.link_diagnostics.has_errors() } @@ -250,6 +263,14 @@ impl<'a> Query<'a> { && !self.resolve_diagnostics.has_errors() && !self.recursion_diagnostics.has_errors() && !self.shapes_diagnostics.has_errors() + && !self.type_diagnostics.has_errors() + } + + /// Get a type printer for emitting inferred types as code. + /// + /// Returns a builder that can be configured for Rust or TypeScript output. + pub fn type_printer(&self) -> TypePrinter<'a> { + TypePrinter::new(self.type_table.clone()) } } diff --git a/crates/plotnik-lib/src/query/types.rs b/crates/plotnik-lib/src/query/types.rs new file mode 100644 index 00000000..a22984ac --- /dev/null +++ b/crates/plotnik-lib/src/query/types.rs @@ -0,0 +1,715 @@ +//! Type inference pass: AST → TypeTable. +//! +//! Walks definitions and infers output types from capture patterns. +//! Produces a `TypeTable` containing all inferred types. + +use indexmap::{IndexMap, IndexSet}; +use rowan::TextRange; + +use crate::diagnostics::DiagnosticKind; +use crate::infer::{MergedField, TypeKey, TypeTable, TypeValue}; +use crate::parser::{AltKind, Expr, ast, token_src}; + +use super::Query; + +/// Tracks a field's type and the location where it was first captured. +#[derive(Clone)] +struct FieldEntry<'src> { + type_key: TypeKey<'src>, + /// Range of the capture name token (e.g., `@x`) + capture_range: TextRange, +} + +impl<'a> Query<'a> { + pub(super) fn infer_types(&mut self) { + let mut ctx = InferContext::new(self.source); + + let defs: Vec<_> = self.ast.defs().collect(); + let last_idx = defs.len().saturating_sub(1); + + for (idx, def) in defs.iter().enumerate() { + let is_last = idx == last_idx; + ctx.infer_def(def, is_last); + } + + ctx.mark_cyclic_types(); + + self.type_table = ctx.table; + self.type_diagnostics = ctx.diagnostics; + } +} + +struct InferContext<'src> { + source: &'src str, + table: TypeTable<'src>, + diagnostics: crate::diagnostics::Diagnostics, +} + +impl<'src> InferContext<'src> { + fn new(source: &'src str) -> Self { + Self { + source, + table: TypeTable::new(), + diagnostics: crate::diagnostics::Diagnostics::new(), + } + } + + /// Mark types that contain cyclic references (need Box/Rc/Arc in Rust). + /// Only struct/union types are marked - wrapper types (Optional, List, etc.) + /// shouldn't be wrapped in Box themselves, only their inner references. + fn mark_cyclic_types(&mut self) { + let keys: Vec<_> = self + .table + .types + .keys() + .filter(|k| !k.is_builtin()) + .filter(|k| { + matches!( + self.table.get(k), + Some(TypeValue::Struct(_)) | Some(TypeValue::TaggedUnion(_)) + ) + }) + .cloned() + .collect(); + + for key in keys { + if self.type_references_itself(&key) { + self.table.mark_cyclic(key); + } + } + } + + /// Check if a type contains a reference to itself (directly or indirectly). + fn type_references_itself(&self, key: &TypeKey<'src>) -> bool { + let mut visited = IndexSet::new(); + self.type_reaches(key, key, &mut visited) + } + + /// Check if `current` type can reach `target` type through references. + fn type_reaches( + &self, + current: &TypeKey<'src>, + target: &TypeKey<'src>, + visited: &mut IndexSet>, + ) -> bool { + if !visited.insert(current.clone()) { + return false; + } + + let Some(value) = self.table.get(current) else { + return false; + }; + + match value { + TypeValue::Struct(fields) => { + for field_key in fields.values() { + if field_key == target { + return true; + } + if self.type_reaches(field_key, target, visited) { + return true; + } + } + false + } + TypeValue::TaggedUnion(variants) => { + for variant_key in variants.values() { + if variant_key == target { + return true; + } + if self.type_reaches(variant_key, target, visited) { + return true; + } + } + false + } + TypeValue::Optional(inner) + | TypeValue::List(inner) + | TypeValue::NonEmptyList(inner) => { + if inner == target { + return true; + } + self.type_reaches(inner, target, visited) + } + TypeValue::Node | TypeValue::String | TypeValue::Unit | TypeValue::Invalid => false, + } + } + + fn infer_def(&mut self, def: &ast::Def, is_last: bool) { + let key = match def.name() { + Some(name_tok) => { + let name = token_src(&name_tok, self.source); + TypeKey::Named(name) + } + None if is_last => TypeKey::DefaultQuery, + None => return, // unnamed non-last def, already reported by earlier pass + }; + + let Some(body) = def.body() else { + return; + }; + + // Special case: tagged alternation at def level produces TaggedUnion directly + if let Expr::AltExpr(alt) = &body + && matches!(alt.kind(), AltKind::Tagged) + { + let type_annotation = match &key { + TypeKey::Named(name) => Some(*name), + _ => None, + }; + self.infer_tagged_alt(alt, &key, type_annotation); + return; + } + + let mut fields = IndexMap::new(); + self.infer_expr(&body, &key, &mut fields); + + let value = if fields.is_empty() { + TypeValue::Unit + } else { + TypeValue::Struct(Self::extract_types(fields)) + }; + + self.table.insert(key, value); + } + + /// Extract just the types from field entries + fn extract_types( + fields: IndexMap<&'src str, FieldEntry<'src>>, + ) -> IndexMap<&'src str, TypeKey<'src>> { + fields.into_iter().map(|(k, v)| (k, v.type_key)).collect() + } + + /// Extract types by reference for merge operations + fn extract_types_ref( + fields: &IndexMap<&'src str, FieldEntry<'src>>, + ) -> IndexMap<&'src str, TypeKey<'src>> { + fields + .iter() + .map(|(k, v)| (*k, v.type_key.clone())) + .collect() + } + + /// Infer type for an expression, collecting captures into `fields`. + /// Returns the TypeKey if this expression produces a referenceable type. + fn infer_expr( + &mut self, + expr: &Expr, + parent: &TypeKey<'src>, + fields: &mut IndexMap<&'src str, FieldEntry<'src>>, + ) -> Option> { + match expr { + Expr::NamedNode(node) => { + for child in node.children() { + self.infer_expr(&child, parent, fields); + } + Some(TypeKey::Node) + } + + Expr::AnonymousNode(_) => Some(TypeKey::Node), + + Expr::Ref(r) => { + let name_tok = r.name()?; + let name = token_src(&name_tok, self.source); + Some(TypeKey::Named(name)) + } + + Expr::SeqExpr(seq) => { + for child in seq.children() { + self.infer_expr(&child, parent, fields); + } + None + } + + Expr::FieldExpr(field) => { + if let Some(value) = field.value() { + self.infer_expr(&value, parent, fields); + } + None + } + + Expr::CapturedExpr(cap) => self.infer_capture(cap, parent, fields), + + Expr::QuantifiedExpr(quant) => self.infer_quantified(quant, parent, fields), + + Expr::AltExpr(alt) => self.infer_alt(alt, parent, fields), + } + } + + fn infer_capture( + &mut self, + cap: &ast::CapturedExpr, + parent: &TypeKey<'src>, + fields: &mut IndexMap<&'src str, FieldEntry<'src>>, + ) -> Option> { + let name_tok = cap.name()?; + let capture_name = token_src(&name_tok, self.source); + let capture_range = name_tok.text_range(); + + let type_annotation = cap.type_annotation().and_then(|t| { + let tok = t.name()?; + Some(token_src(&tok, self.source)) + }); + + let inner = cap.inner(); + + // Flat extraction: collect nested captures from inner expression into outer fields + // Only for NamedNode/AnonymousNode - Seq/Alt create their own scopes when captured + if let Some(ref inner_expr) = inner { + match inner_expr { + Expr::NamedNode(node) => { + for child in node.children() { + self.infer_expr(&child, parent, fields); + } + } + Expr::FieldExpr(field) => { + if let Some(value) = field.value() { + self.infer_expr(&value, parent, fields); + } + } + _ => {} + } + } + + let inner_type = + self.infer_capture_inner(inner.as_ref(), parent, capture_name, type_annotation); + + // Check for duplicate capture in scope + // Unlike alternations (where branches are mutually exclusive), + // in sequences both captures execute - can't have two values for same name + if let Some(existing) = fields.get(capture_name) { + self.diagnostics + .report(DiagnosticKind::DuplicateCaptureInScope, capture_range) + .message(capture_name) + .related_to("first use", existing.capture_range) + .emit(); + fields.insert( + capture_name, + FieldEntry { + type_key: TypeKey::Invalid, + capture_range, + }, + ); + return Some(TypeKey::Invalid); + } + + fields.insert( + capture_name, + FieldEntry { + type_key: inner_type.clone(), + capture_range, + }, + ); + Some(inner_type) + } + + fn infer_capture_inner( + &mut self, + inner: Option<&Expr>, + parent: &TypeKey<'src>, + capture_name: &'src str, + type_annotation: Option<&'src str>, + ) -> TypeKey<'src> { + // Handle quantifier first - it wraps whatever the inner type is + // This ensures `(x)+ @name :: string` becomes Vec, not String + if let Some(Expr::QuantifiedExpr(q)) = inner { + let Some(qinner) = q.inner() else { + return TypeKey::Invalid; + }; + let inner_key = + self.infer_capture_inner(Some(&qinner), parent, capture_name, type_annotation); + return self.wrap_with_quantifier(&inner_key, q, parent, capture_name); + } + + // :: string annotation + if type_annotation == Some("string") { + return TypeKey::String; + } + + let Some(inner) = inner else { + return type_annotation.map(TypeKey::Named).unwrap_or(TypeKey::Node); + }; + + match inner { + Expr::Ref(r) => { + if let Some(name_tok) = r.name() { + let ref_name = token_src(&name_tok, self.source); + TypeKey::Named(ref_name) + } else { + TypeKey::Invalid + } + } + + Expr::SeqExpr(_) => { + self.infer_nested_scope(inner, parent, capture_name, type_annotation, || { + inner.children().into_iter().collect() + }) + } + + Expr::AltExpr(alt) => { + self.infer_nested_scope(inner, parent, capture_name, type_annotation, || { + alt.branches().filter_map(|b| b.body()).collect() + }) + } + + Expr::NamedNode(_) | Expr::AnonymousNode(_) => { + type_annotation.map(TypeKey::Named).unwrap_or(TypeKey::Node) + } + + Expr::QuantifiedExpr(_) => { + unreachable!("quantifier handled at start of function") + } + + Expr::FieldExpr(field) => { + if let Some(value) = field.value() { + self.infer_capture_inner(Some(&value), parent, capture_name, type_annotation) + } else { + type_annotation.map(TypeKey::Named).unwrap_or(TypeKey::Node) + } + } + + Expr::CapturedExpr(_) => type_annotation.map(TypeKey::Named).unwrap_or(TypeKey::Node), + } + } + + fn infer_nested_scope( + &mut self, + inner: &Expr, + parent: &TypeKey<'src>, + capture_name: &'src str, + type_annotation: Option<&'src str>, + get_children: F, + ) -> TypeKey<'src> + where + F: FnOnce() -> Vec, + { + let nested_parent = TypeKey::Synthetic { + parent: Box::new(parent.clone()), + name: capture_name, + }; + + let mut nested_fields = IndexMap::new(); + + match inner { + Expr::AltExpr(alt) => { + let alt_key = self.infer_alt_as_type(alt, &nested_parent, type_annotation); + return alt_key; + } + _ => { + for child in get_children() { + self.infer_expr(&child, &nested_parent, &mut nested_fields); + } + } + } + + if nested_fields.is_empty() { + return type_annotation.map(TypeKey::Named).unwrap_or(TypeKey::Node); + } + + let key = if let Some(name) = type_annotation { + TypeKey::Named(name) + } else { + nested_parent.clone() + }; + + self.table.insert( + key.clone(), + TypeValue::Struct(Self::extract_types(nested_fields)), + ); + key + } + + fn infer_quantified( + &mut self, + quant: &ast::QuantifiedExpr, + parent: &TypeKey<'src>, + fields: &mut IndexMap<&'src str, FieldEntry<'src>>, + ) -> Option> { + let inner = quant.inner()?; + quant.operator()?; + + // If the inner is a capture, we need special handling for the wrapper + if let Expr::CapturedExpr(cap) = &inner { + let name_tok = cap.name()?; + let capture_name = token_src(&name_tok, self.source); + let capture_range = name_tok.text_range(); + + let type_annotation = cap.type_annotation().and_then(|t| { + let tok = t.name()?; + Some(token_src(&tok, self.source)) + }); + + let inner_key = self.infer_capture_inner( + cap.inner().as_ref(), + parent, + capture_name, + type_annotation, + ); + let wrapped_key = self.wrap_with_quantifier(&inner_key, quant, parent, capture_name); + + fields.insert( + capture_name, + FieldEntry { + type_key: wrapped_key.clone(), + capture_range, + }, + ); + return Some(wrapped_key); + } + + // Non-capture quantified expression: track fields added by inner expression + // and wrap them with the quantifier + let fields_before: Vec<_> = fields.keys().copied().collect(); + + let inner_key = self.infer_expr(&inner, parent, fields)?; + + // Wrap all newly added fields with the quantifier + let field_names: Vec<_> = fields.keys().copied().collect(); + for name in field_names { + if fields_before.contains(&name) { + continue; + } + if let Some(entry) = fields.get_mut(name) { + entry.type_key = self.wrap_with_quantifier(&entry.type_key, quant, parent, name); + } + } + + // Return wrapped inner key (though typically unused when wrapping field captures) + Some(inner_key) + } + + fn wrap_with_quantifier( + &mut self, + inner: &TypeKey<'src>, + quant: &ast::QuantifiedExpr, + parent: &TypeKey<'src>, + capture_name: &'src str, + ) -> TypeKey<'src> { + if matches!(inner, TypeKey::Invalid) { + return TypeKey::Invalid; + } + + // Check list/non-empty-list before optional since * matches both is_list() and is_optional() + let wrapper = if quant.is_list() { + TypeValue::List(inner.clone()) + } else if quant.is_non_empty_list() { + TypeValue::NonEmptyList(inner.clone()) + } else if quant.is_optional() { + TypeValue::Optional(inner.clone()) + } else { + return inner.clone(); + }; + + // Synthetic key: Parent + capture_name → e.g., QueryResultItems + let wrapper_key = TypeKey::Synthetic { + parent: Box::new(parent.clone()), + name: capture_name, + }; + + self.table.insert(wrapper_key.clone(), wrapper); + wrapper_key + } + + fn infer_alt( + &mut self, + alt: &ast::AltExpr, + parent: &TypeKey<'src>, + fields: &mut IndexMap<&'src str, FieldEntry<'src>>, + ) -> Option> { + // Alt without capture: just collect fields from all branches into current scope + match alt.kind() { + AltKind::Tagged => { + // Tagged alt without capture: unusual, but collect fields + for branch in alt.branches() { + if let Some(body) = branch.body() { + self.infer_expr(&body, parent, fields); + } + } + } + AltKind::Untagged | AltKind::Mixed => { + // Untagged alt: merge fields from branches + let branch_fields = self.collect_branch_fields(alt, parent); + let branch_types: Vec<_> = + branch_fields.iter().map(Self::extract_types_ref).collect(); + let merged = self.table.merge_fields(&branch_types); + self.apply_merged_fields(merged, fields, alt, parent); + } + } + None + } + + fn infer_alt_as_type( + &mut self, + alt: &ast::AltExpr, + parent: &TypeKey<'src>, + type_annotation: Option<&'src str>, + ) -> TypeKey<'src> { + match alt.kind() { + AltKind::Tagged => self.infer_tagged_alt(alt, parent, type_annotation), + AltKind::Untagged | AltKind::Mixed => { + self.infer_untagged_alt(alt, parent, type_annotation) + } + } + } + + fn infer_tagged_alt( + &mut self, + alt: &ast::AltExpr, + parent: &TypeKey<'src>, + type_annotation: Option<&'src str>, + ) -> TypeKey<'src> { + let mut variants = IndexMap::new(); + + for branch in alt.branches() { + let Some(label_tok) = branch.label() else { + continue; + }; + let label = token_src(&label_tok, self.source); + + let variant_key = TypeKey::Synthetic { + parent: Box::new(parent.clone()), + name: label, + }; + + let mut variant_fields = IndexMap::new(); + let body_type = if let Some(body) = branch.body() { + self.infer_expr(&body, &variant_key, &mut variant_fields) + } else { + None + }; + + let variant_value = if variant_fields.is_empty() { + // No captures: check if the body produced a meaningful type + match body_type { + Some(key) if !key.is_builtin() => { + // Branch body has a non-builtin type (e.g., Ref or wrapped type) + // Create a struct with a "value" field + let mut fields = IndexMap::new(); + fields.insert("value", key); + TypeValue::Struct(fields) + } + _ => TypeValue::Unit, + } + } else { + TypeValue::Struct(Self::extract_types(variant_fields)) + }; + + // Variant types shouldn't conflict - they have unique paths including the label + self.table.insert(variant_key.clone(), variant_value); + variants.insert(label, variant_key); + } + + let union_key = if let Some(name) = type_annotation { + TypeKey::Named(name) + } else { + parent.clone() + }; + + // Detect conflict: same key with incompatible TaggedUnion + let current_span = alt.text_range(); + if let Err(existing_span) = self.table.try_insert( + union_key.clone(), + TypeValue::TaggedUnion(variants), + current_span, + ) { + self.diagnostics + .report( + DiagnosticKind::IncompatibleTaggedAlternations, + existing_span, + ) + .related_to("incompatible", current_span) + .emit(); + return union_key; + } + union_key + } + + fn infer_untagged_alt( + &mut self, + alt: &ast::AltExpr, + parent: &TypeKey<'src>, + type_annotation: Option<&'src str>, + ) -> TypeKey<'src> { + let branch_fields = self.collect_branch_fields(alt, parent); + let branch_types: Vec<_> = branch_fields.iter().map(Self::extract_types_ref).collect(); + let merged = self.table.merge_fields(&branch_types); + + if merged.is_empty() { + return type_annotation.map(TypeKey::Named).unwrap_or(TypeKey::Node); + } + + let mut result_fields = IndexMap::new(); + self.apply_merged_fields(merged, &mut result_fields, alt, parent); + + let key = if let Some(name) = type_annotation { + TypeKey::Named(name) + } else { + parent.clone() + }; + + self.table.insert( + key.clone(), + TypeValue::Struct(Self::extract_types(result_fields)), + ); + key + } + + fn collect_branch_fields( + &mut self, + alt: &ast::AltExpr, + parent: &TypeKey<'src>, + ) -> Vec>> { + let mut branch_fields = Vec::new(); + + for branch in alt.branches() { + let mut fields = IndexMap::new(); + if let Some(body) = branch.body() { + self.infer_expr(&body, parent, &mut fields); + } + branch_fields.push(fields); + } + + branch_fields + } + + fn apply_merged_fields( + &mut self, + merged: IndexMap<&'src str, MergedField<'src>>, + result_fields: &mut IndexMap<&'src str, FieldEntry<'src>>, + alt: &ast::AltExpr, + parent: &TypeKey<'src>, + ) { + for (name, merge_result) in merged { + let key = match merge_result { + MergedField::Same(k) => k, + MergedField::Optional(k) => { + let wrapper_key = TypeKey::Synthetic { + parent: Box::new(TypeKey::Synthetic { + parent: Box::new(parent.clone()), + name, + }), + name: "opt", + }; + self.table + .insert(wrapper_key.clone(), TypeValue::Optional(k)); + wrapper_key + } + MergedField::Conflict => { + self.diagnostics + .report(DiagnosticKind::TypeConflictInMerge, alt.text_range()) + .message(name) + .emit(); + TypeKey::Invalid + } + }; + result_fields.insert( + name, + FieldEntry { + type_key: key, + // Use the alt's range as a fallback since we don't have individual capture ranges here + capture_range: alt.text_range(), + }, + ); + } + } +} diff --git a/crates/plotnik-lib/src/query/types_tests.rs b/crates/plotnik-lib/src/query/types_tests.rs new file mode 100644 index 00000000..e095c124 --- /dev/null +++ b/crates/plotnik-lib/src/query/types_tests.rs @@ -0,0 +1,591 @@ +//! Type inference tests. + +use crate::Query; +use indoc::indoc; + +#[test] +fn capture_node_produces_node_field() { + let query = Query::try_from("(identifier) @id").unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @"type QueryResult = { id: SyntaxNode };"); +} + +#[test] +fn multiple_captures_produce_multiple_fields() { + let query = Query::try_from("(binary left: (_) @left right: (_) @right)").unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @"type QueryResult = { left: SyntaxNode; right: SyntaxNode };"); +} + +#[test] +fn no_captures_produces_unit() { + let query = Query::try_from("(identifier)").unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @""); +} + +#[test] +fn nested_capture_flattens() { + let query = Query::try_from("(function name: (identifier) @name)").unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @"type QueryResult = { name: SyntaxNode };"); +} + +#[test] +fn string_annotation() { + let query = Query::try_from("(identifier) @name :: string").unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @"type QueryResult = { name: string };"); +} + +#[test] +fn named_type_annotation() { + let query = Query::try_from("(identifier) @value :: MyType").unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @"type QueryResult = { value: MyType };"); +} + +#[test] +fn annotation_on_quantified_wraps_inner() { + let query = Query::try_from("(identifier)+ @names :: string").unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r" + type QueryResultNames = [string, ...string[]]; + + type QueryResult = { names: [string, ...string[]] }; + "); +} + +#[test] +fn capture_ref_produces_ref_type() { + let input = indoc! {r#" + Inner = (identifier) @name + (wrapper (Inner) @inner) + "#}; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r" + type Inner = { name: SyntaxNode }; + + type QueryResult = { inner: Inner }; + "); +} + +#[test] +fn ref_without_capture_contributes_nothing() { + let input = indoc! {r#" + Inner = (identifier) @name + (wrapper (Inner)) + "#}; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @"type Inner = { name: SyntaxNode };"); +} + +#[test] +fn optional_node() { + let query = Query::try_from("(identifier)? @maybe").unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r" + type QueryResultMaybe = SyntaxNode; + + type QueryResult = { maybe?: SyntaxNode }; + "); +} + +#[test] +fn list_of_nodes() { + let query = Query::try_from("(identifier)* @items").unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r" + type QueryResultItems = SyntaxNode[]; + + type QueryResult = { items: SyntaxNode[] }; + "); +} + +#[test] +fn nonempty_list_of_nodes() { + let query = Query::try_from("(identifier)+ @items").unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r" + type QueryResultItems = [SyntaxNode, ...SyntaxNode[]]; + + type QueryResult = { items: [SyntaxNode, ...SyntaxNode[]] }; + "); +} + +#[test] +fn quantified_ref() { + let input = indoc! {r#" + Item = (item) @value + (container (Item)+ @items) + "#}; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r" + type Item = { value: SyntaxNode }; + + type QueryResultItems = [Item, ...Item[]]; + + type QueryResult = { items: [Item, ...Item[]] }; + "); +} + +#[test] +fn quantifier_outside_capture() { + let query = Query::try_from("((identifier) @id)*").unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r" + type QueryResultId = SyntaxNode[]; + + type QueryResult = { id: SyntaxNode[] }; + "); +} + +#[test] +fn captured_seq_creates_nested_struct() { + let query = Query::try_from("{(a) @x (b) @y} @pair").unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r" + type QueryResultPair = { x: SyntaxNode; y: SyntaxNode }; + + type QueryResult = { pair: QueryResultPair }; + "); +} + +#[test] +fn captured_seq_in_tree() { + let input = indoc! {r#" + (function + {(param) @p} @params + (body) @body) + "#}; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r" + type QueryResultParams = { p: SyntaxNode }; + + type QueryResult = { params: QueryResultParams; body: SyntaxNode }; + "); +} + +#[test] +fn empty_captured_seq_is_node() { + let query = Query::try_from("{} @empty").unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @"type QueryResult = { empty: SyntaxNode };"); +} + +#[test] +fn tagged_alt_produces_union() { + let input = "[A: (a) @x B: (b) @y]"; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r#" + type QueryResultA = { x: SyntaxNode }; + + type QueryResultB = { y: SyntaxNode }; + + type QueryResult = + | { tag: "A"; x: SyntaxNode } + | { tag: "B"; y: SyntaxNode }; + "#); +} + +#[test] +fn tagged_alt_as_definition() { + let input = indoc! {r#" + Expr = [ + Binary: (binary left: (_) @left right: (_) @right) + Literal: (number) @value + ] + "#}; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r#" + type ExprBinary = { left: SyntaxNode; right: SyntaxNode }; + + type ExprLiteral = { value: SyntaxNode }; + + type Expr = + | { tag: "Binary"; left: SyntaxNode; right: SyntaxNode } + | { tag: "Literal"; value: SyntaxNode }; + "#); +} + +#[test] +fn tagged_branch_without_captures_is_unit() { + let input = "[A: (a) B: (b)]"; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r#" + type QueryResult = + | { tag: "A" } + | { tag: "B" }; + "#); +} + +#[test] +fn tagged_branch_with_ref() { + let input = indoc! {r#" + Rec = [Base: (a) Nested: (Rec)?] @value + (Rec) + "#}; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r#" + type RecValue = + | { tag: "Base" } + | { tag: "Nested"; value: Rec }; + + type Rec = { value: RecValue }; + + type RecValueNested = { value: Rec }; + "#); +} + +#[test] +fn captured_tagged_alt() { + let input = "(container [A: (a) B: (b)] @choice)"; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r#" + type QueryResultChoice = + | { tag: "A" } + | { tag: "B" }; + + type QueryResult = { choice: QueryResultChoice }; + "#); +} + +#[test] +fn untagged_alt_same_capture_merges() { + let input = "[(a) @x (b) @x]"; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @"type QueryResult = { x: SyntaxNode };"); +} + +#[test] +fn untagged_alt_different_captures_becomes_optional() { + let input = "[(a) @x (b) @y]"; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r" + type QueryResultXOpt = SyntaxNode; + + type QueryResultYOpt = SyntaxNode; + + type QueryResult = { x?: SyntaxNode; y?: SyntaxNode }; + "); +} + +#[test] +fn untagged_alt_nested_alt_merges() { + let input = "[(a) @x (b) @y [(c) @x (d) @y]]"; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r" + type QueryResultXOpt = SyntaxNode; + + type QueryResultYOpt = SyntaxNode; + + type QueryResult = { x?: SyntaxNode; y?: SyntaxNode }; + "); +} + +#[test] +fn captured_untagged_alt_with_nested_fields() { + let input = "[{(a) @x} {(b) @y}] @choice"; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r" + type QueryResultChoiceXOpt = SyntaxNode; + + type QueryResultChoiceYOpt = SyntaxNode; + + type QueryResultChoice = { x?: SyntaxNode; y?: SyntaxNode }; + + type QueryResult = { choice: QueryResultChoice }; + "); +} + +#[test] +fn merge_same_type_unchanged() { + let input = "[(identifier) @x (identifier) @x]"; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @"type QueryResult = { x: SyntaxNode };"); +} + +#[test] +fn merge_absent_field_becomes_optional() { + let input = "[(identifier) @x (number)]"; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r" + type QueryResultXOpt = SyntaxNode; + + type QueryResult = { x?: SyntaxNode }; + "); +} + +#[test] +fn merge_list_and_nonempty_list_to_list() { + let input = "[(a)* @x (b)+ @x]"; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r" + type QueryResultX = [SyntaxNode, ...SyntaxNode[]]; + + type QueryResult = { x: [SyntaxNode, ...SyntaxNode[]] }; + "); +} + +#[test] +fn merge_optional_and_required_to_optional() { + let input = "[(a)? @x (b) @x]"; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r" + type QueryResultX = SyntaxNode; + + type QueryResult = { x?: SyntaxNode }; + "); +} + +#[test] +fn self_recursive_type_marked_cyclic() { + let input = "Expr = [(identifier) (call (Expr) @callee)]"; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r" + type Expr = { callee?: Expr }; + + type ExprCalleeOpt = Expr; + "); +} + +#[test] +fn recursive_through_optional() { + let input = indoc! {r#" + Rec = (call function: (Rec)? @inner) + (Rec) + "#}; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r" + type Rec = { inner?: Rec }; + + type RecInner = Rec; + "); +} + +#[test] +fn recursive_in_tagged_alt() { + let input = indoc! {r#" + Expr = [ + Ident: (identifier) @name + Call: (call function: (Expr) @func) + ] + "#}; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r#" + type ExprIdent = { name: SyntaxNode }; + + type Expr = + | { tag: "Ident"; name: SyntaxNode } + | { tag: "Call"; func: Expr }; + + type ExprCall = { func: Expr }; + "#); +} + +#[test] +fn unnamed_last_def_is_default_query() { + let input = "(program (identifier)* @items)"; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r" + type QueryResultItems = SyntaxNode[]; + + type QueryResult = { items: SyntaxNode[] }; + "); +} + +#[test] +fn named_defs_plus_entry_point() { + let input = indoc! {r#" + Item = (item) @value + (container (Item)* @items) + "#}; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r" + type Item = { value: SyntaxNode }; + + type QueryResultItems = Item[]; + + type QueryResult = { items: Item[] }; + "); +} + +#[test] +fn tagged_alt_at_entry_point() { + let input = "[A: (a) @x B: (b) @y]"; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r#" + type QueryResultA = { x: SyntaxNode }; + + type QueryResultB = { y: SyntaxNode }; + + type QueryResult = + | { tag: "A"; x: SyntaxNode } + | { tag: "B"; y: SyntaxNode }; + "#); +} + +#[test] +fn type_conflict_in_untagged_alt() { + let input = "[(identifier) @x :: string (number) @x]"; + let query = Query::try_from(input).unwrap(); + assert!(!query.is_valid()); + insta::assert_snapshot!(query.dump_diagnostics(), @r" + error: capture `x` has conflicting types across branches + | + 1 | [(identifier) @x :: string (number) @x] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + "); +} + +#[test] +fn incompatible_tagged_alts_in_merge() { + let input = "[[A: (a) @x] @y [B: (b) @z] @y]"; + let query = Query::try_from(input).unwrap(); + assert!(!query.is_valid()); + insta::assert_snapshot!(query.dump_diagnostics(), @r" + error: tagged alternations with different variants cannot be merged + | + 1 | [[A: (a) @x] @y [B: (b) @z] @y] + | ^^^^^^^^^^^ ----------- incompatible + "); +} + +#[test] +fn duplicate_capture_in_sequence() { + let input = "{(a) @x (b) @x}"; + let query = Query::try_from(input).unwrap(); + assert!(!query.is_valid()); + insta::assert_snapshot!(query.dump_diagnostics(), @r" + error: capture `@x` already used in this scope + | + 1 | {(a) @x (b) @x} + | - ^ + | | + | first use + "); +} + +#[test] +fn duplicate_capture_nested() { + let input = "(foo (a) @x (bar (b) @x))"; + let query = Query::try_from(input).unwrap(); + assert!(!query.is_valid()); + insta::assert_snapshot!(query.dump_diagnostics(), @r" + error: capture `@x` already used in this scope + | + 1 | (foo (a) @x (bar (b) @x)) + | - first use ^ + "); +} + +#[test] +fn wildcard_capture() { + let query = Query::try_from("(_) @node").unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @"type QueryResult = { node: SyntaxNode };"); +} + +#[test] +fn anonymous_node_capture() { + let query = Query::try_from("_ @anon").unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @"type QueryResult = { anon: SyntaxNode };"); +} + +#[test] +fn string_literal_capture() { + let query = Query::try_from(r#""if" @kw"#).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @"type QueryResult = { kw: SyntaxNode };"); +} + +#[test] +fn field_value_capture() { + let query = Query::try_from("(call name: (identifier) @name)").unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @"type QueryResult = { name: SyntaxNode };"); +} + +#[test] +fn deeply_nested_seq() { + let query = Query::try_from("{{{(identifier) @x}}} @outer").unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r" + type QueryResultOuter = { x: SyntaxNode }; + + type QueryResult = { outer: QueryResultOuter }; + "); +} + +#[test] +fn same_tag_in_branches_merges() { + let input = "[[A: (a)] @x [A: (b)] @x]"; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r#" + type QueryResultX = + | { tag: "A" }; + + type QueryResult = { x: QueryResultX }; + "#); +} + +#[test] +fn annotation_on_captured_ref() { + let input = indoc! {r#" + Inner = (identifier) @name + (wrapper (Inner) @inner :: CustomType) + "#}; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r" + type Inner = { name: SyntaxNode }; + + type QueryResult = { inner: Inner }; + "); +} + +#[test] +fn multiple_defs_with_refs() { + let input = indoc! {r#" + A = (a) @x + B = (b (A) @a) + C = (c (B) @b) + (root (C) @c) + "#}; + let query = Query::try_from(input).unwrap(); + assert!(query.is_valid()); + insta::assert_snapshot!(query.dump_types(), @r" + type A = { x: SyntaxNode }; + + type B = { a: A }; + + type C = { b: B }; + + type QueryResult = { c: C }; + "); +}