From b31fb448d5dd3f4d663f1b224f7c2c3dfaaf5eca Mon Sep 17 00:00:00 2001 From: Sergei Zharinov Date: Tue, 23 Dec 2025 14:03:19 -0300 Subject: [PATCH] feat: Enforce node type ids as NonZeroU16 --- crates/plotnik-core/src/lib.rs | 196 +------------------------------ crates/plotnik-langs/Cargo.toml | 2 +- crates/plotnik-langs/src/lib.rs | 60 +++------- crates/plotnik-macros/src/lib.rs | 52 +++++--- 4 files changed, 50 insertions(+), 260 deletions(-) diff --git a/crates/plotnik-core/src/lib.rs b/crates/plotnik-core/src/lib.rs index 06bd615c..dd63cffb 100644 --- a/crates/plotnik-core/src/lib.rs +++ b/crates/plotnik-core/src/lib.rs @@ -55,8 +55,8 @@ pub fn parse_node_types(json: &str) -> Result, serde_json::Error> { serde_json::from_str(json) } -/// Node type ID (tree-sitter uses u16). -pub type NodeTypeId = u16; +/// Node type ID (tree-sitter uses u16, but 0 is internal-only). +pub type NodeTypeId = NonZeroU16; /// Field ID (tree-sitter uses NonZeroU16). pub type NodeFieldId = NonZeroU16; @@ -623,196 +623,4 @@ mod tests { let plus = nodes.iter().find(|n| n.type_name == "+").unwrap(); assert!(!plus.named); } - - #[test] - fn build_dynamic_node_types() { - let raw = parse_node_types(SAMPLE_JSON).unwrap(); - - let node_ids: HashMap<(&str, bool), NodeTypeId> = [ - (("expression", true), 1), - (("function_declaration", true), 2), - (("program", true), 3), - (("comment", true), 4), - (("identifier", true), 5), - (("+", false), 6), - (("block", true), 7), - (("statement", true), 8), - (("number", true), 9), - ] - .into_iter() - .collect(); - - let field_ids: HashMap<&str, NodeFieldId> = [ - ("name", NonZeroU16::new(1).unwrap()), - ("body", NonZeroU16::new(2).unwrap()), - ] - .into_iter() - .collect(); - - let node_types = DynamicNodeTypes::build( - &raw, - |name, named| node_ids.get(&(name, named)).copied(), - |name| field_ids.get(name).copied(), - ); - - assert_eq!(node_types.len(), 6); - - // Test via trait - assert_eq!(node_types.root(), Some(3)); - assert!(node_types.is_extra(4)); - assert!(!node_types.is_extra(5)); - assert!(node_types.has_field(2, NonZeroU16::new(1).unwrap())); - assert!(node_types.has_field(2, NonZeroU16::new(2).unwrap())); - assert!(!node_types.has_field(2, NonZeroU16::new(99).unwrap())); - assert!(node_types.is_valid_field_type(2, NonZeroU16::new(1).unwrap(), 5)); - assert!(!node_types.is_valid_field_type(2, NonZeroU16::new(1).unwrap(), 7)); - } - - // Static tests using manually constructed data - static TEST_VALID_TYPES_ID: [NodeTypeId; 1] = [5]; // identifier - static TEST_VALID_TYPES_BLOCK: [NodeTypeId; 1] = [7]; // block - static TEST_CHILDREN_TYPES: [NodeTypeId; 1] = [8]; // statement - - static TEST_FIELDS: [(NodeFieldId, StaticFieldInfo); 2] = [ - ( - NonZeroU16::new(1).unwrap(), - StaticFieldInfo { - cardinality: Cardinality { - multiple: false, - required: true, - }, - valid_types: &TEST_VALID_TYPES_ID, - }, - ), - ( - NonZeroU16::new(2).unwrap(), - StaticFieldInfo { - cardinality: Cardinality { - multiple: false, - required: true, - }, - valid_types: &TEST_VALID_TYPES_BLOCK, - }, - ), - ]; - - static TEST_NODES: [(NodeTypeId, StaticNodeTypeInfo); 4] = [ - ( - 1, - StaticNodeTypeInfo { - name: "expression", - named: true, - fields: &[], - children: None, - }, - ), - ( - 2, - StaticNodeTypeInfo { - name: "function_declaration", - named: true, - fields: &TEST_FIELDS, - children: None, - }, - ), - ( - 3, - StaticNodeTypeInfo { - name: "program", - named: true, - fields: &[], - children: Some(StaticChildrenInfo { - cardinality: Cardinality { - multiple: true, - required: false, - }, - valid_types: &TEST_CHILDREN_TYPES, - }), - }, - ), - ( - 4, - StaticNodeTypeInfo { - name: "comment", - named: true, - fields: &[], - children: None, - }, - ), - ]; - - static TEST_EXTRAS: [NodeTypeId; 1] = [4]; - - static TEST_STATIC_NODE_TYPES: StaticNodeTypes = - StaticNodeTypes::new(&TEST_NODES, &TEST_EXTRAS, Some(3)); - - #[test] - fn static_node_types_get() { - let info = TEST_STATIC_NODE_TYPES.get(2).unwrap(); - assert_eq!(info.name, "function_declaration"); - assert!(info.named); - - assert!(TEST_STATIC_NODE_TYPES.get(99).is_none()); - } - - #[test] - fn static_node_types_contains() { - assert!(TEST_STATIC_NODE_TYPES.contains(1)); - assert!(TEST_STATIC_NODE_TYPES.contains(2)); - assert!(!TEST_STATIC_NODE_TYPES.contains(99)); - } - - #[test] - fn static_node_types_trait() { - // Test via trait methods - assert_eq!(TEST_STATIC_NODE_TYPES.root(), Some(3)); - assert!(TEST_STATIC_NODE_TYPES.is_extra(4)); - assert!(!TEST_STATIC_NODE_TYPES.is_extra(1)); - - assert!(TEST_STATIC_NODE_TYPES.has_field(2, NonZeroU16::new(1).unwrap())); - assert!(TEST_STATIC_NODE_TYPES.has_field(2, NonZeroU16::new(2).unwrap())); - assert!(!TEST_STATIC_NODE_TYPES.has_field(2, NonZeroU16::new(99).unwrap())); - assert!(!TEST_STATIC_NODE_TYPES.has_field(1, NonZeroU16::new(1).unwrap())); - - assert!(TEST_STATIC_NODE_TYPES.is_valid_field_type(2, NonZeroU16::new(1).unwrap(), 5)); - assert!(!TEST_STATIC_NODE_TYPES.is_valid_field_type(2, NonZeroU16::new(1).unwrap(), 7)); - assert!(TEST_STATIC_NODE_TYPES.is_valid_field_type(2, NonZeroU16::new(2).unwrap(), 7)); - - let field_types = TEST_STATIC_NODE_TYPES.valid_field_types(2, NonZeroU16::new(1).unwrap()); - assert_eq!(field_types, &[5]); - - let card = TEST_STATIC_NODE_TYPES - .field_cardinality(2, NonZeroU16::new(1).unwrap()) - .unwrap(); - assert!(!card.multiple); - assert!(card.required); - } - - #[test] - fn static_node_types_children() { - let card = TEST_STATIC_NODE_TYPES.children_cardinality(3).unwrap(); - assert!(card.multiple); - assert!(!card.required); - - let child_types = TEST_STATIC_NODE_TYPES.valid_child_types(3); - assert_eq!(child_types, &[8]); - - assert!(TEST_STATIC_NODE_TYPES.is_valid_child_type(3, 8)); - assert!(!TEST_STATIC_NODE_TYPES.is_valid_child_type(3, 5)); - - assert!(TEST_STATIC_NODE_TYPES.children_cardinality(1).is_none()); - assert!(TEST_STATIC_NODE_TYPES.valid_child_types(1).is_empty()); - } - - #[test] - fn static_node_types_len() { - assert_eq!(TEST_STATIC_NODE_TYPES.len(), 4); - assert!(!TEST_STATIC_NODE_TYPES.is_empty()); - } - - #[test] - fn static_node_types_iter() { - let ids: Vec<_> = TEST_STATIC_NODE_TYPES.iter().map(|(id, _)| id).collect(); - assert_eq!(ids, vec![1, 2, 3, 4]); - } } diff --git a/crates/plotnik-langs/Cargo.toml b/crates/plotnik-langs/Cargo.toml index 272119ac..bc6437b8 100644 --- a/crates/plotnik-langs/Cargo.toml +++ b/crates/plotnik-langs/Cargo.toml @@ -332,4 +332,4 @@ arborium-zsh = { version = "2.3.2", optional = true } [build-dependencies] cargo_metadata = "0.23" -[dev-dependencies] \ No newline at end of file +[dev-dependencies] diff --git a/crates/plotnik-langs/src/lib.rs b/crates/plotnik-langs/src/lib.rs index e62ce6f8..4075c20d 100644 --- a/crates/plotnik-langs/src/lib.rs +++ b/crates/plotnik-langs/src/lib.rs @@ -1,3 +1,4 @@ +use std::num::NonZeroU16; use std::sync::Arc; use arborium_tree_sitter::Language; @@ -100,18 +101,13 @@ impl LangImpl for LangInner { fn resolve_named_node(&self, kind: &str) -> Option { let id = self.ts_lang.id_for_node_kind(kind, true); - // For named nodes, 0 always means "not found" - (id != 0).then_some(id) + NonZeroU16::new(id) } fn resolve_anonymous_node(&self, kind: &str) -> Option { let id = self.ts_lang.id_for_node_kind(kind, false); - // Tree-sitter returns 0 for both "not found" AND the valid anonymous "end" node. - // We disambiguate via reverse lookup. - if id != 0 { - return Some(id); - } - (self.ts_lang.node_kind_for_id(0) == Some(kind)).then_some(0) + // Node ID 0 is tree-sitter internal; we never obtain it via cursor walk. + NonZeroU16::new(id) } fn resolve_field(&self, name: &str) -> Option { @@ -134,7 +130,7 @@ impl LangImpl for LangInner { } fn node_type_name(&self, node_type_id: NodeTypeId) -> Option<&'static str> { - self.ts_lang.node_kind_for_id(node_type_id) + self.ts_lang.node_kind_for_id(node_type_id.get()) } fn field_name(&self, field_id: NodeFieldId) -> Option<&'static str> { @@ -156,11 +152,11 @@ impl LangImpl for LangInner { } fn is_supertype(&self, node_type_id: NodeTypeId) -> bool { - self.ts_lang.node_kind_is_supertype(node_type_id) + self.ts_lang.node_kind_is_supertype(node_type_id.get()) } fn subtypes(&self, supertype: NodeTypeId) -> &[u16] { - self.ts_lang.subtypes_for_supertype(supertype) + self.ts_lang.subtypes_for_supertype(supertype.get()) } fn root(&self) -> Option { @@ -280,7 +276,7 @@ mod tests { } #[test] - #[ignore] // TODO: wait for arborium to use ABI v15 + #[ignore] // TODO: investigate why we always obtain empty subtypes #[cfg(feature = "lang-javascript")] fn supertype_via_lang_trait() { let lang = javascript(); @@ -295,19 +291,6 @@ mod tests { assert!(!lang.is_supertype(func_id)); } - #[test] - #[cfg(feature = "lang-json")] - fn find_nonempty_subtypes() { - let lang = javascript(); - for id in 0..500u16 { - let subtypes = lang.subtypes(id); - if !subtypes.is_empty() { - let name = lang.node_type_name(id).unwrap_or("?"); - println!("id={id} name={name} subtypes={subtypes:?}"); - } - } - } - #[test] #[cfg(feature = "lang-javascript")] fn field_validation_via_trait() { @@ -357,46 +340,31 @@ mod tests { #[test] #[cfg(feature = "lang-javascript")] - fn tree_sitter_id_zero_disambiguation() { + fn resolve_nonexistent_nodes() { let lang = javascript(); - // For named nodes: 0 unambiguously means "not found" + // Non-existent nodes return None + assert!(lang.resolve_named_node("end").is_none()); assert!(lang.resolve_named_node("fake_named").is_none()); + assert!(lang.resolve_anonymous_node("totally_fake_node").is_none()); - // For anonymous nodes: we disambiguate via reverse lookup - let end_resolved = lang.resolve_anonymous_node("end"); - let fake_resolved = lang.resolve_anonymous_node("totally_fake_node"); - - assert!(end_resolved.is_some(), "Valid 'end' node should resolve"); - assert_eq!(end_resolved, Some(0), "'end' should have ID 0"); - - assert!(fake_resolved.is_none(), "Non-existent node should be None"); - - // Our wrapper preserves field cleanliness + // Field resolution assert!(lang.resolve_field("name").is_some()); assert!(lang.resolve_field("fake_field").is_none()); } /// Verifies that languages with "end" keyword assign it a non-zero ID. - /// This proves that ID 0 ("end" sentinel) is internal to tree-sitter - /// and never exposed via the Cursor API for actual syntax nodes. #[test] #[cfg(all(feature = "lang-ruby", feature = "lang-lua"))] - fn end_keyword_has_nonzero_id() { + fn end_keyword_resolves() { // Ruby has "end" keyword for blocks, methods, classes, etc. let ruby = ruby(); let ruby_end = ruby.resolve_anonymous_node("end"); assert!(ruby_end.is_some(), "Ruby should have 'end' keyword"); - assert_ne!(ruby_end, Some(0), "Ruby 'end' keyword must not be ID 0"); // Lua has "end" keyword for blocks, functions, etc. let lua = lua(); let lua_end = lua.resolve_anonymous_node("end"); assert!(lua_end.is_some(), "Lua should have 'end' keyword"); - assert_ne!(lua_end, Some(0), "Lua 'end' keyword must not be ID 0"); - - // Both languages still have internal "end" sentinel at ID 0 - assert_eq!(ruby.node_type_name(0), Some("end")); - assert_eq!(lua.node_type_name(0), Some("end")); } } diff --git a/crates/plotnik-macros/src/lib.rs b/crates/plotnik-macros/src/lib.rs index 24322690..6406d3f9 100644 --- a/crates/plotnik-macros/src/lib.rs +++ b/crates/plotnik-macros/src/lib.rs @@ -250,22 +250,24 @@ struct FieldCodeGen { fn generate_field_code( prefix: &str, - node_id: u16, + node_id: std::num::NonZeroU16, field_id: &std::num::NonZeroU16, field_info: &plotnik_core::FieldInfo, ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) { - let valid_types = field_info.valid_types.to_vec(); + let valid_types_raw: Vec = field_info.valid_types.iter().map(|id| id.get()).collect(); let valid_types_name = syn::Ident::new( - &format!("{}_N{}_F{}_TYPES", prefix, node_id, field_id), + &format!("{}_N{}_F{}_TYPES", prefix, node_id.get(), field_id), Span::call_site(), ); let multiple = field_info.cardinality.multiple; let required = field_info.cardinality.required; - let types_len = valid_types.len(); + let types_len = valid_types_raw.len(); let array_def = quote! { - static #valid_types_name: [u16; #types_len] = [#(#valid_types),*]; + static #valid_types_name: [std::num::NonZeroU16; #types_len] = [ + #(std::num::NonZeroU16::new(#valid_types_raw).unwrap()),* + ]; }; let field_id_raw = field_id.get(); @@ -284,7 +286,7 @@ fn generate_field_code( fn generate_fields_for_node( prefix: &str, - node_id: u16, + node_id: std::num::NonZeroU16, fields: &std::collections::HashMap, ) -> FieldCodeGen { let mut sorted_fields: Vec<_> = fields.iter().collect(); @@ -307,19 +309,21 @@ fn generate_fields_for_node( fn generate_children_code( prefix: &str, - node_id: u16, + node_id: std::num::NonZeroU16, children: &plotnik_core::ChildrenInfo, static_defs: &mut Vec, ) -> proc_macro2::TokenStream { - let valid_types = children.valid_types.to_vec(); + let valid_types_raw: Vec = children.valid_types.iter().map(|id| id.get()).collect(); let children_types_name = syn::Ident::new( - &format!("{}_N{}_CHILDREN_TYPES", prefix, node_id), + &format!("{}_N{}_CHILDREN_TYPES", prefix, node_id.get()), Span::call_site(), ); - let types_len = valid_types.len(); + let types_len = valid_types_raw.len(); static_defs.push(quote! { - static #children_types_name: [u16; #types_len] = [#(#valid_types),*]; + static #children_types_name: [std::num::NonZeroU16; #types_len] = [ + #(std::num::NonZeroU16::new(#valid_types_raw).unwrap()),* + ]; }); let multiple = children.cardinality.multiple; @@ -346,7 +350,7 @@ fn generate_static_node_types_code( raw_nodes, |name, named| { let id = ts_lang.id_for_node_kind(name, named); - if id == 0 && named { None } else { Some(id) } + std::num::NonZeroU16::new(id) }, |name| ts_lang.field_id_for_name(name), ); @@ -355,13 +359,18 @@ fn generate_static_node_types_code( let mut static_defs = Vec::new(); let mut node_entries = Vec::new(); - let extras = node_types.sorted_extras(); + let extras_raw: Vec = node_types + .sorted_extras() + .iter() + .map(|id| id.get()) + .collect(); let root = node_types.root(); let sorted_node_ids = node_types.sorted_node_ids(); for &node_id in &sorted_node_ids { let info = node_types.get(node_id).unwrap(); + let node_id_raw = node_id.get(); let field_gen = generate_fields_for_node(&prefix, node_id, &info.fields); static_defs.extend(field_gen.array_defs); @@ -369,7 +378,7 @@ fn generate_static_node_types_code( quote! { &[] } } else { let fields_array_name = syn::Ident::new( - &format!("{}_N{}_FIELDS", prefix, node_id), + &format!("{}_N{}_FIELDS", prefix, node_id_raw), Span::call_site(), ); let fields_len = field_gen.entries.len(); @@ -393,7 +402,7 @@ fn generate_static_node_types_code( let named = info.named; node_entries.push(quote! { - (#node_id, plotnik_core::StaticNodeTypeInfo { + (std::num::NonZeroU16::new(#node_id_raw).unwrap(), plotnik_core::StaticNodeTypeInfo { name: #name, named: #named, fields: #fields_ref, @@ -406,21 +415,26 @@ fn generate_static_node_types_code( let nodes_len = sorted_node_ids.len(); let extras_array_name = syn::Ident::new(&format!("{}_EXTRAS", prefix), Span::call_site()); - let extras_len = extras.len(); + let extras_len = extras_raw.len(); let root_code = match root { - Some(id) => quote! { Some(#id) }, + Some(id) => { + let id_raw = id.get(); + quote! { Some(std::num::NonZeroU16::new(#id_raw).unwrap()) } + } None => quote! { None }, }; quote! { #(#static_defs)* - static #nodes_array_name: [(u16, plotnik_core::StaticNodeTypeInfo); #nodes_len] = [ + static #nodes_array_name: [(std::num::NonZeroU16, plotnik_core::StaticNodeTypeInfo); #nodes_len] = [ #(#node_entries),* ]; - static #extras_array_name: [u16; #extras_len] = [#(#extras),*]; + static #extras_array_name: [std::num::NonZeroU16; #extras_len] = [ + #(std::num::NonZeroU16::new(#extras_raw).unwrap()),* + ]; pub static #const_name: plotnik_core::StaticNodeTypes = plotnik_core::StaticNodeTypes::new( &#nodes_array_name,