diff --git a/.gitignore b/.gitignore index c8482dc..66145e8 100644 --- a/.gitignore +++ b/.gitignore @@ -39,3 +39,6 @@ external/ !CHANGELOG.md *.vsix editors/vscode/package-lock.json + +# Local build artifacts +/refcell_eq diff --git a/core/ast/src/builder.rs b/core/ast/src/builder.rs index 426b005..55a1e7f 100644 --- a/core/ast/src/builder.rs +++ b/core/ast/src/builder.rs @@ -84,10 +84,18 @@ use crate::{ }; use tree_sitter::Node; +#[derive(Clone, Copy, Default)] +pub(crate) struct LocationBase { + pub(crate) offset: u32, + pub(crate) line: u32, + pub(crate) column: u32, +} + pub struct Builder<'a> { arena: Arena, source_code: Vec<(Node<'a>, &'a [u8])>, errors: Vec, + location_base: LocationBase, } impl Default for Builder<'_> { @@ -103,9 +111,34 @@ impl<'a> Builder<'a> { arena: Arena::default(), source_code: Vec::new(), errors: Vec::new(), + location_base: LocationBase::default(), } } + pub(crate) fn set_location_base(&mut self, base: LocationBase) { + self.location_base = base; + } + + pub(crate) fn reset_location_base(&mut self) { + self.location_base = LocationBase::default(); + } + + pub(crate) fn take_errors(&mut self) -> Vec { + std::mem::take(&mut self.errors) + } + + pub(crate) fn next_node_id() -> u32 { + Self::get_node_id() + } + + pub(crate) fn add_node(&mut self, node: AstNode, parent_id: u32) { + self.arena.add_node(node, parent_id); + } + + pub(crate) fn into_arena(self) -> Arena { + self.arena + } + /// Adds a source code and CST to the builder. /// /// # Panics @@ -132,7 +165,7 @@ impl<'a> Builder<'a> { pub fn build_ast(&'_ mut self) -> anyhow::Result { for (root, code) in &self.source_code.clone() { let id = Self::get_node_id(); - let location = Self::get_location(root, code); + let location = self.get_location(root, code); let source = String::from_utf8_lossy(code); debug_assert!( !source.contains('\u{FFFD}'), @@ -169,7 +202,7 @@ impl<'a> Builder<'a> { Ok(self.arena.clone()) } - fn build_use_directive( + pub(crate) fn build_use_directive( &mut self, parent_id: u32, node: &Node, @@ -177,7 +210,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let mut segments = None; let mut imported_types = None; let mut from = None; @@ -228,7 +261,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let name = self.build_identifier(id, &node.child_by_field_name("name").unwrap(), code); let mut definitions = Vec::new(); @@ -261,7 +294,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let name = self.build_identifier(id, &node.child_by_field_name("name").unwrap(), code); let mut variants = Vec::new(); @@ -288,7 +321,12 @@ impl<'a> Builder<'a> { node } - fn build_definition(&mut self, parent_id: u32, node: &Node, code: &[u8]) -> Definition { + pub(crate) fn build_definition( + &mut self, + parent_id: u32, + node: &Node, + code: &[u8], + ) -> Definition { let kind = node.kind(); match kind { "spec_definition" => { @@ -316,7 +354,7 @@ impl<'a> Builder<'a> { _ => panic!( "Unexpected definition kind: {}, {}", node.kind(), - Self::get_location(node, code) + self.get_location(node, code) ), } } @@ -329,7 +367,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let name = self.build_identifier(id, &node.child_by_field_name("name").unwrap(), code); let mut fields = Vec::new(); let mut cursor = node.walk(); @@ -365,7 +403,7 @@ impl<'a> Builder<'a> { fn build_struct_field(&mut self, parent_id: u32, node: &Node, code: &[u8]) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let ty = self.build_type(id, &node.child_by_field_name("type").unwrap(), code); let name = self.build_identifier(id, &node.child_by_field_name("name").unwrap(), code); @@ -383,7 +421,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let ty = self.build_type(id, &node.child_by_field_name("type").unwrap(), code); let name = self.build_identifier(id, &node.child_by_field_name("name").unwrap(), code); let value = self.build_literal(id, &node.child_by_field_name("value").unwrap(), code); @@ -411,7 +449,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let mut arguments = None; let mut returns = None; let mut type_parameters = None; @@ -469,7 +507,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let name = self.build_identifier(id, &node.child_by_field_name("name").unwrap(), code); let mut arguments = None; let mut returns = None; @@ -511,7 +549,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let ty = self.build_type(id, &node.child_by_field_name("type").unwrap(), code); let name = self.build_identifier(id, &node.child_by_field_name("name").unwrap(), code); let node = Rc::new(TypeDefinition::new( @@ -574,7 +612,7 @@ impl<'a> Builder<'a> { fn build_argument(&mut self, parent_id: u32, node: &Node, code: &[u8]) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let name_node = node.child_by_field_name("name").unwrap(); let type_node = node.child_by_field_name("type").unwrap(); let ty = self.build_type(id, &type_node, code); @@ -598,7 +636,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let is_mut = node .child_by_field_name("mut") .is_some_and(|n| n.kind() == "true"); @@ -618,7 +656,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let ty = self.build_type(id, &node.child_by_field_name("type").unwrap(), code); let node = Rc::new(IgnoreArgument::new(id, location, ty)); self.arena.add_node( @@ -631,7 +669,7 @@ impl<'a> Builder<'a> { fn build_block(&mut self, parent_id: u32, node: &Node, code: &[u8]) -> BlockType { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); match node.kind() { "assume_block" => { let statements = self.build_block_statements( @@ -697,7 +735,7 @@ impl<'a> Builder<'a> { _ => panic!( "Unexpected block type: {}, {}", node.kind(), - Self::get_location(node, code) + self.get_location(node, code) ), } } @@ -756,7 +794,7 @@ impl<'a> Builder<'a> { _ => panic!( "Unexpected statement type: {}, {}", node.kind(), - Self::get_location(node, code) + self.get_location(node, code) ), } } @@ -769,14 +807,14 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let expr_node = &node.child_by_field_name("expression"); let expression = if let Some(expr) = expr_node { self.build_expression(id, expr, code) } else { Expression::Literal(Literal::Unit(Rc::new(UnitLiteral::new( Self::get_node_id(), - Self::get_location(node, code), + self.get_location(node, code), )))) }; let node = Rc::new(ReturnStatement::new(id, location, expression)); @@ -795,7 +833,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let condition = node .child_by_field_name("condition") .map(|n| self.build_expression(id, &n, code)); @@ -810,7 +848,7 @@ impl<'a> Builder<'a> { fn build_if_statement(&mut self, parent_id: u32, node: &Node, code: &[u8]) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let condition_node = node.child_by_field_name("condition").unwrap(); let condition = self.build_expression(id, &condition_node, code); let if_arm_node = node.child_by_field_name("if_arm").unwrap(); @@ -832,7 +870,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let ty = self.build_type(id, &node.child_by_field_name("type").unwrap(), code); let name = self.build_identifier(id, &node.child_by_field_name("name").unwrap(), code); let value = node @@ -858,7 +896,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let ty = self.build_type(id, &node.child_by_field_name("type").unwrap(), code); let name = self.build_identifier(id, &node.child_by_field_name("name").unwrap(), code); @@ -908,7 +946,7 @@ impl<'a> Builder<'a> { "identifier" => Expression::Identifier(self.build_identifier(parent_id, node, code)), _ => panic!( "Unexpected expression node kind: {node_kind} at {}", - Self::get_location(node, code) + self.get_location(node, code) ), } } @@ -921,7 +959,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let left = self.build_expression(id, &node.child_by_field_name("left").unwrap(), code); let right = self.build_expression(id, &node.child_by_field_name("right").unwrap(), code); @@ -941,7 +979,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let array = self.build_expression(id, &node.named_child(0).unwrap(), code); let index = self.build_expression(id, &node.named_child(1).unwrap(), code); @@ -961,7 +999,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let expression = self.build_expression(id, &node.child_by_field_name("expression").unwrap(), code); let name = self.build_identifier(id, &node.child_by_field_name("name").unwrap(), code); @@ -981,7 +1019,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let expression = self.build_expression(id, &node.child_by_field_name("expression").unwrap(), code); let name = self.build_identifier(id, &node.child_by_field_name("name").unwrap(), code); @@ -1003,7 +1041,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let function = self.build_expression(id, &node.child_by_field_name("function").unwrap(), code); let mut argument_name_expression_map: Vec<(Option>, Expression)> = @@ -1075,7 +1113,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let name = self.build_identifier(id, &node.child_by_field_name("name").unwrap(), code); let mut field_name_expression_map: Vec<(Rc, Expression)> = Vec::new(); let mut pending_name: Option> = None; @@ -1129,7 +1167,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let expression = self.build_expression(id, &node.child(1).unwrap(), code); let operator_node = node.child_by_field_name("operator").unwrap(); @@ -1158,7 +1196,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let expression = self.build_expression(id, &node.child(1).unwrap(), code); let node = Rc::new(AssertStatement::new(id, location, expression)); self.arena.add_node( @@ -1176,7 +1214,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let node = Rc::new(BreakStatement::new(id, location)); self.arena.add_node( AstNode::Statement(Statement::Break(node.clone())), @@ -1193,7 +1231,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let expression = self.build_expression(id, &node.child(1).unwrap(), code); let node = Rc::new(ParenthesizedExpression::new(id, location, expression)); @@ -1212,7 +1250,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let left = self.build_expression(id, &node.child_by_field_name("left").unwrap(), code); let operator_node = node.child_by_field_name("operator").unwrap(); let operator_kind = operator_node.kind(); @@ -1268,7 +1306,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let mut elements = Vec::new(); let mut cursor = node.walk(); for child in node.named_children(&mut cursor) { @@ -1291,7 +1329,7 @@ impl<'a> Builder<'a> { fn build_bool_literal(&mut self, parent_id: u32, node: &Node, code: &[u8]) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let value = match node.utf8_text(code).unwrap() { "true" => true, "false" => false, @@ -1314,7 +1352,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let value = node.utf8_text(code).unwrap().to_string(); let node = Rc::new(StringLiteral::new(id, location, value)); self.arena.add_node( @@ -1332,7 +1370,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let value = node.utf8_text(code).unwrap().to_string(); let node = Rc::new(NumberLiteral::new(id, location, value)); self.arena.add_node( @@ -1345,7 +1383,7 @@ impl<'a> Builder<'a> { fn build_unit_literal(&mut self, parent_id: u32, node: &Node, code: &[u8]) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let node = Rc::new(UnitLiteral::new(id, location)); self.arena.add_node( AstNode::Expression(Expression::Literal(Literal::Unit(node.clone()))), @@ -1383,7 +1421,7 @@ impl<'a> Builder<'a> { Type::Custom(name) } _ => { - let location = Self::get_location(node, code); + let location = self.get_location(node, code); panic!("Unexpected type: {node_kind}, {location}") } } @@ -1392,7 +1430,7 @@ impl<'a> Builder<'a> { fn build_type_array(&mut self, parent_id: u32, node: &Node, code: &[u8]) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let element_type = self.build_type(id, &node.child_by_field_name("type").unwrap(), code); let length_node = node.child_by_field_name("length").unwrap(); let size = self.build_expression(id, &length_node, code); @@ -1408,7 +1446,7 @@ impl<'a> Builder<'a> { fn build_generic_type(&mut self, parent_id: u32, node: &Node, code: &[u8]) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let base = self.build_identifier(id, &node.child_by_field_name("base_type").unwrap(), code); let args = node.child(1).unwrap(); @@ -1436,7 +1474,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let mut arguments = None; let mut cursor = node.walk(); let mut returns = None; @@ -1467,7 +1505,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let alias = self.build_identifier(id, &node.child_by_field_name("alias").unwrap(), code); let name = self.build_identifier(id, &node.child_by_field_name("name").unwrap(), code); @@ -1487,7 +1525,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let qualifier = self.build_identifier(id, &node.child_by_field_name("qualifier").unwrap(), code); let name = self.build_identifier(id, &node.child_by_field_name("name").unwrap(), code); @@ -1508,7 +1546,7 @@ impl<'a> Builder<'a> { ) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let node = Rc::new(UzumakiExpression::new(id, location)); self.arena.add_node( AstNode::Expression(Expression::Uzumaki(node.clone())), @@ -1520,7 +1558,7 @@ impl<'a> Builder<'a> { fn build_identifier(&mut self, parent_id: u32, node: &Node, code: &[u8]) -> Rc { self.collect_errors(node, code); let id = Self::get_node_id(); - let location = Self::get_location(node, code); + let location = self.get_location(node, code); let name = node.utf8_text(code).unwrap().to_string(); let node = Rc::new(Identifier::new(id, name, location)); self.arena.add_node( @@ -1540,15 +1578,31 @@ impl<'a> Builder<'a> { } #[allow(clippy::cast_possible_truncation)] - fn get_location(node: &Node, _code: &[u8]) -> Location { - let offset_start = node.start_byte() as u32; - let offset_end = node.end_byte() as u32; + pub(crate) fn get_location(&self, node: &Node, _code: &[u8]) -> Location { + let mut offset_start = node.start_byte() as u32; + let mut offset_end = node.end_byte() as u32; let start_position = node.start_position(); let end_position = node.end_position(); - let start_line = start_position.row as u32 + 1; - let start_column = start_position.column as u32 + 1; - let end_line = end_position.row as u32 + 1; - let end_column = end_position.column as u32 + 1; + let mut start_line = start_position.row as u32 + 1; + let mut start_column = start_position.column as u32 + 1; + let mut end_line = end_position.row as u32 + 1; + let mut end_column = end_position.column as u32 + 1; + + if self.location_base.offset != 0 + || self.location_base.line != 0 + || self.location_base.column != 0 + { + offset_start += self.location_base.offset; + offset_end += self.location_base.offset; + start_line += self.location_base.line; + end_line += self.location_base.line; + if start_position.row == 0 { + start_column += self.location_base.column; + } + if end_position.row == 0 { + end_column += self.location_base.column; + } + } Location { offset_start, @@ -1564,7 +1618,7 @@ impl<'a> Builder<'a> { let mut cursor = node.walk(); for child in node.children(&mut cursor) { if child.is_error() { - let location = Self::get_location(&child, code); + let location = self.get_location(&child, code); let source_snippet = String::from_utf8_lossy( &code[location.offset_start as usize..location.offset_end as usize], ); diff --git a/core/ast/src/lib.rs b/core/ast/src/lib.rs index 4902ce7..9f97ea2 100644 --- a/core/ast/src/lib.rs +++ b/core/ast/src/lib.rs @@ -31,7 +31,7 @@ //! - [`arena::Arena`] - Central storage for all AST nodes with O(1) lookups //! - [`builder::Builder`] - Builds AST from tree-sitter concrete syntax tree //! - [`nodes`] - AST node type definitions (`SourceFile`, `FunctionDefinition`, etc.) -//! - [`parser_context::ParserContext`] - Multi-file parsing context (WIP) +//! - [`parser_context::ParserContext`] - Multi-file parsing context //! //! # Key Features //! diff --git a/core/ast/src/nodes.rs b/core/ast/src/nodes.rs index aed18b5..7ab43b7 100644 --- a/core/ast/src/nodes.rs +++ b/core/ast/src/nodes.rs @@ -451,7 +451,7 @@ ast_nodes! { pub struct ModuleDefinition { pub visibility: Visibility, pub name: Rc, - pub body: Option>, + pub body: RefCell>>, } pub struct Argument { diff --git a/core/ast/src/nodes_impl.rs b/core/ast/src/nodes_impl.rs index 1cd6785..5edc758 100644 --- a/core/ast/src/nodes_impl.rs +++ b/core/ast/src/nodes_impl.rs @@ -443,7 +443,7 @@ impl ModuleDefinition { location, visibility, name, - body, + body: RefCell::new(body), } } diff --git a/core/ast/src/parser_context.rs b/core/ast/src/parser_context.rs index 89c4ea1..a58286a 100644 --- a/core/ast/src/parser_context.rs +++ b/core/ast/src/parser_context.rs @@ -5,8 +5,8 @@ //! //! # Status //! -//! **Work in Progress** - This module provides the skeleton for multi-file support -//! but is not yet functional. See CLAUDE.md: "Multi-file support not yet implemented." +//! Basic multi-file support is implemented by scanning for `mod` declarations and +//! parsing referenced files into a unified AST arena. //! //! # Planned Implementation //! @@ -22,7 +22,12 @@ use std::path::PathBuf; use std::rc::Rc; use crate::arena::Arena; -use crate::nodes::ModuleDefinition; +use crate::builder::{Builder, LocationBase}; +use crate::nodes::{ + Ast, AstNode, Definition, Directive, Expression, Identifier, Location, ModuleDefinition, + SourceFile, Visibility, +}; +use tree_sitter::Parser; /// Queue entry for pending file parsing. #[allow(dead_code)] @@ -31,6 +36,8 @@ struct ParseQueueEntry { scope_id: u32, /// Path to the source file. file_path: PathBuf, + /// Module declaration to populate when parsing external modules. + module: Option>, } /// Context for parsing multiple source files. @@ -58,6 +65,7 @@ impl ParserContext { queue: vec![ParseQueueEntry { scope_id: 0, file_path: root_path, + module: None, }], arena: Arena::default(), } @@ -69,9 +77,14 @@ impl ParserContext { /// /// Will add the file to the queue with its parent scope ID, enabling /// proper scope relationships when the file is parsed. - #[allow(clippy::unused_self)] - pub fn push_file(&mut self, _scope_id: u32, _file_path: PathBuf) { - // Not yet implemented - see module documentation + pub fn push_file( + &mut self, + scope_id: u32, + file_path: PathBuf, + module: Option>, + ) { + self.queue + .push(ParseQueueEntry { scope_id, file_path, module }); } /// Parses all queued files and builds the unified AST. @@ -90,9 +103,38 @@ impl ParserContext { /// } /// } /// ``` - #[must_use] - pub fn parse_all(&mut self) -> Arena { - std::mem::take(&mut self.arena) + pub fn parse_all(&mut self) -> anyhow::Result { + while let Some(entry) = self.queue.pop() { + let store_definitions = entry.module.is_none(); + let (file_arena, definitions) = + Self::parse_file(&entry.file_path, store_definitions)?; + + if let Some(module) = entry.module { + *module.body.borrow_mut() = Some(definitions.clone()); + } + + for definition in &definitions { + if let Definition::Module(module_definition) = definition { + self.process_module(module_definition, entry.scope_id, &entry.file_path); + } + } + + let Arena { + nodes, + parent_map, + children_map, + } = file_arena; + self.arena.nodes.extend(nodes); + self.arena.parent_map.extend(parent_map); + for (parent_id, children) in children_map { + self.arena + .children_map + .entry(parent_id) + .or_default() + .extend(children); + } + } + Ok(std::mem::take(&mut self.arena)) } /// Resolves and processes a module definition. @@ -118,11 +160,26 @@ impl ParserContext { #[allow(dead_code, clippy::unused_self)] fn process_module( &mut self, - _module: &Rc, + module: &Rc, _parent_scope_id: u32, - _current_file_path: &PathBuf, + current_file_path: &PathBuf, ) { - // Not yet implemented - see module documentation + let module_scope_id = self.next_node_id(); + + if module.body.borrow().is_none() { + if let Some(mod_path) = find_submodule_path(current_file_path, &module.name()) { + self.push_file(module_scope_id, mod_path, Some(Rc::clone(module))); + } + return; + } + + if let Some(body) = module.body.borrow().as_ref() { + for definition in body { + if let Definition::Module(child_module) = definition { + self.process_module(child_module, module_scope_id, current_file_path); + } + } + } } /// Generates a new unique node ID. @@ -132,6 +189,463 @@ impl ParserContext { self.next_id += 1; id } + + fn parse_file( + file_path: &PathBuf, + store_definitions: bool, + ) -> anyhow::Result<(Arena, Vec)> { + let source = std::fs::read_to_string(file_path)?; + let line_index = LineIndex::new(&source); + let mut parser = Parser::new(); + parser + .set_language(&tree_sitter_inference::language()) + .map_err(|_| anyhow::anyhow!("Error loading Inference grammar"))?; + let mut builder = Builder::new(); + let source_file_id = Builder::next_node_id(); + let location = location_from_offsets(&line_index, 0, source.len()); + let (definitions, directives) = parse_block_definitions( + &mut builder, + &mut parser, + &line_index, + &source, + 0, + source_file_id, + store_definitions, + )?; + + let mut source_file = SourceFile::new(source_file_id, location, source); + if store_definitions { + source_file.directives = directives; + source_file.definitions = definitions.clone(); + } + + builder.add_node( + AstNode::Ast(Ast::SourceFile(Rc::new(source_file))), + u32::MAX, + ); + + let errors = builder.take_errors(); + if !errors.is_empty() { + for err in errors { + eprintln!("AST Builder Error: {err}"); + } + return Err(anyhow::anyhow!("AST building failed due to errors")); + } + + Ok((builder.into_arena(), definitions)) + } +} + +#[derive(Clone, Copy)] +struct Span { + start: usize, + end: usize, +} + +struct ModuleDecl { + name: String, + visibility: Visibility, + span: Span, + name_span: Span, + body: Option, +} + +struct LineIndex { + starts: Vec, +} + +impl LineIndex { + fn new(source: &str) -> Self { + let mut starts = vec![0]; + for (idx, byte) in source.bytes().enumerate() { + if byte == b'\n' { + starts.push(idx + 1); + } + } + Self { starts } + } + + fn line_col(&self, offset: usize) -> (u32, u32) { + let line_idx = match self.starts.binary_search(&offset) { + Ok(idx) => idx, + Err(idx) => idx.saturating_sub(1), + }; + let line_start = self.starts.get(line_idx).copied().unwrap_or(0); + let line = line_idx as u32 + 1; + let column = (offset - line_start) as u32 + 1; + (line, column) + } +} + +fn location_from_offsets(line_index: &LineIndex, start: usize, end: usize) -> Location { + let (start_line, start_column) = line_index.line_col(start); + let (end_line, end_column) = line_index.line_col(end); + Location::new( + start as u32, + end as u32, + start_line, + start_column, + end_line, + end_column, + ) +} + +fn location_base(line_index: &LineIndex, offset: usize) -> LocationBase { + let (line, column) = line_index.line_col(offset); + LocationBase { + offset: offset as u32, + line: line.saturating_sub(1), + column: column.saturating_sub(1), + } +} + +fn parse_block_definitions( + builder: &mut Builder, + parser: &mut Parser, + line_index: &LineIndex, + source: &str, + base_offset: usize, + parent_id: u32, + include_directives: bool, +) -> anyhow::Result<(Vec, Vec)> { + let (modules, sanitized_source) = scan_modules(source); + let tree = parser + .parse(&sanitized_source, None) + .ok_or_else(|| anyhow::anyhow!("Parse error"))?; + let root = tree.root_node(); + + let base = location_base(line_index, base_offset); + builder.set_location_base(base); + + let mut directives: Vec = Vec::new(); + let mut definitions = Vec::new(); + let mut cursor = root.walk(); + for child in root.children(&mut cursor) { + match child.kind() { + "use_directive" if include_directives => { + directives.push(Directive::Use(builder.build_use_directive( + parent_id, + &child, + sanitized_source.as_bytes(), + ))); + } + "use_directive" => {} + _ => { + let definition = builder.build_definition( + parent_id, + &child, + sanitized_source.as_bytes(), + ); + definitions.push((definition.location().offset_start, definition)); + } + } + } + + builder.reset_location_base(); + + for module in modules { + let module_def = build_module_definition( + builder, + parser, + line_index, + source, + base_offset, + parent_id, + module, + )?; + let offset = module_def.location.offset_start; + definitions.push((offset, Definition::Module(module_def))); + } + + definitions.sort_by_key(|(offset, _)| *offset); + let definitions = definitions + .into_iter() + .map(|(_, definition)| definition) + .collect(); + + Ok((definitions, directives)) +} + +fn build_module_definition( + builder: &mut Builder, + parser: &mut Parser, + line_index: &LineIndex, + source: &str, + base_offset: usize, + parent_id: u32, + module: ModuleDecl, +) -> anyhow::Result> { + let ModuleDecl { + name, + visibility, + span, + name_span, + body, + } = module; + + let module_id = Builder::next_node_id(); + let name_id = Builder::next_node_id(); + + let name_start = base_offset + name_span.start; + let name_end = base_offset + name_span.end; + let name_location = location_from_offsets(line_index, name_start, name_end); + let name_node = Rc::new(Identifier::new(name_id, name, name_location)); + builder.add_node( + AstNode::Expression(Expression::Identifier(name_node.clone())), + module_id, + ); + + let module_start = base_offset + span.start; + let module_end = base_offset + span.end; + let module_location = location_from_offsets(line_index, module_start, module_end); + + let body = if let Some(body_span) = body { + let body_source = &source[body_span.start..body_span.end]; + let (body_defs, _) = parse_block_definitions( + builder, + parser, + line_index, + body_source, + base_offset + body_span.start, + module_id, + false, + )?; + Some(body_defs) + } else { + None + }; + + let module_def = Rc::new(ModuleDefinition::new( + module_id, + visibility, + name_node, + body, + module_location, + )); + + builder.add_node( + AstNode::Definition(Definition::Module(module_def.clone())), + parent_id, + ); + + Ok(module_def) +} + +fn scan_modules(source: &str) -> (Vec, String) { + let bytes = source.as_bytes(); + let len = bytes.len(); + let mut modules = Vec::new(); + let mut i = 0; + let mut depth = 0u32; + + while i < len { + if bytes[i] == b'/' && i + 1 < len && bytes[i + 1] == b'/' { + i = skip_line_comment(bytes, i + 2); + continue; + } + if bytes[i] == b'"' { + i = skip_string(bytes, i + 1); + continue; + } + match bytes[i] { + b'{' => { + depth += 1; + i += 1; + continue; + } + b'}' => { + depth = depth.saturating_sub(1); + i += 1; + continue; + } + _ => {} + } + + if depth == 0 && is_ident_start(bytes[i]) { + let (ident, ident_start, ident_end) = parse_ident(bytes, i); + if ident == "pub" { + let j = skip_ws_and_comments(bytes, ident_end); + if j < len && is_ident_start(bytes[j]) { + let (next_ident, _mod_start, mod_end) = parse_ident(bytes, j); + if next_ident == "mod" { + if let Some((module, next_idx)) = + parse_module_decl(bytes, ident_start, mod_end, Visibility::Public) + { + modules.push(module); + i = next_idx; + continue; + } + } + } + } else if ident == "mod" { + if let Some((module, next_idx)) = + parse_module_decl(bytes, ident_start, ident_end, Visibility::Private) + { + modules.push(module); + i = next_idx; + continue; + } + } + i = ident_end; + continue; + } + + i += 1; + } + + let mut sanitized = bytes.to_vec(); + for module in &modules { + for idx in module.span.start..module.span.end { + let byte = sanitized[idx]; + if byte != b'\n' && byte != b'\r' { + sanitized[idx] = b' '; + } + } + } + + let sanitized = String::from_utf8_lossy(&sanitized).into_owned(); + (modules, sanitized) +} + +fn parse_module_decl( + bytes: &[u8], + decl_start: usize, + mod_end: usize, + visibility: Visibility, +) -> Option<(ModuleDecl, usize)> { + let len = bytes.len(); + let mut i = skip_ws_and_comments(bytes, mod_end); + if i >= len || !is_ident_start(bytes[i]) { + return None; + } + let (name, name_start, name_end) = parse_ident(bytes, i); + i = skip_ws_and_comments(bytes, name_end); + if i >= len { + return None; + } + if bytes[i] == b';' { + let span = Span { + start: decl_start, + end: i + 1, + }; + let module = ModuleDecl { + name, + visibility, + span, + name_span: Span { + start: name_start, + end: name_end, + }, + body: None, + }; + return Some((module, i + 1)); + } + if bytes[i] == b'{' { + let body_start = i + 1; + let body_end = find_matching_brace(bytes, body_start)?; + let span = Span { + start: decl_start, + end: body_end + 1, + }; + let module = ModuleDecl { + name, + visibility, + span, + name_span: Span { + start: name_start, + end: name_end, + }, + body: Some(Span { + start: body_start, + end: body_end, + }), + }; + return Some((module, body_end + 1)); + } + None +} + +fn find_matching_brace(bytes: &[u8], mut i: usize) -> Option { + let len = bytes.len(); + let mut depth = 1u32; + while i < len { + if bytes[i] == b'/' && i + 1 < len && bytes[i + 1] == b'/' { + i = skip_line_comment(bytes, i + 2); + continue; + } + if bytes[i] == b'"' { + i = skip_string(bytes, i + 1); + continue; + } + match bytes[i] { + b'{' => depth += 1, + b'}' => { + depth = depth.saturating_sub(1); + if depth == 0 { + return Some(i); + } + } + _ => {} + } + i += 1; + } + None +} + +fn skip_line_comment(bytes: &[u8], mut i: usize) -> usize { + let len = bytes.len(); + while i < len && bytes[i] != b'\n' { + i += 1; + } + i +} + +fn skip_string(bytes: &[u8], mut i: usize) -> usize { + let len = bytes.len(); + while i < len { + match bytes[i] { + b'\\' if i + 1 < len => { + i += 2; + } + b'"' => { + return i + 1; + } + _ => i += 1, + } + } + i +} + +fn skip_ws_and_comments(bytes: &[u8], mut i: usize) -> usize { + let len = bytes.len(); + while i < len { + match bytes[i] { + b' ' | b'\t' | b'\n' | b'\r' => i += 1, + b'/' if i + 1 < len && bytes[i + 1] == b'/' => { + i = skip_line_comment(bytes, i + 2); + } + _ => break, + } + } + i +} + +fn parse_ident(bytes: &[u8], start: usize) -> (String, usize, usize) { + let len = bytes.len(); + let mut i = start; + while i < len && is_ident_continue(bytes[i]) { + i += 1; + } + let name = String::from_utf8_lossy(&bytes[start..i]).into_owned(); + (name, start, i) +} + +fn is_ident_start(byte: u8) -> bool { + byte == b'_' || (byte as char).is_ascii_alphabetic() +} + +fn is_ident_continue(byte: u8) -> bool { + is_ident_start(byte) || (byte as char).is_ascii_digit() } /// Finds the path to a submodule file. @@ -142,8 +656,17 @@ impl ParserContext { /// 1. `{current_dir}/{module_name}.inf` /// 2. `{current_dir}/{module_name}/mod.inf` /// -/// Returns `None` until multi-file support is implemented. +/// Returns the first path that exists, or `None` if no candidate is found. #[must_use] -pub fn find_submodule_path(_current_file: &PathBuf, _module_name: &str) -> Option { +pub fn find_submodule_path(current_file: &PathBuf, module_name: &str) -> Option { + let current_dir = current_file.parent()?; + let file_candidate = current_dir.join(format!("{module_name}.inf")); + if file_candidate.exists() { + return Some(file_candidate); + } + let mod_candidate = current_dir.join(module_name).join("mod.inf"); + if mod_candidate.exists() { + return Some(mod_candidate); + } None } diff --git a/core/cli/src/main.rs b/core/cli/src/main.rs index 3d36c10..9947553 100644 --- a/core/cli/src/main.rs +++ b/core/cli/src/main.rs @@ -95,7 +95,7 @@ //! //! ## Current Limitations //! -//! - Single-file compilation only (multi-file projects not yet supported) +//! - Module path resolution for nested modules is best-effort //! - Output directory is relative to CWD, not source file location //! - Analysis phase is work-in-progress //! @@ -109,7 +109,7 @@ mod parser; use clap::Parser; -use inference::{analyze, codegen, parse, type_check, wasm_to_v}; +use inference::{analyze, codegen, parse_file, type_check, wasm_to_v}; use parser::Cli; use std::{ fs, @@ -182,10 +182,9 @@ fn main() { process::exit(1); } - let source_code = fs::read_to_string(&args.path).expect("Error reading source file"); let mut t_ast = None; if need_codegen || need_analyze || need_parse { - match parse(source_code.as_str()) { + match parse_file(&args.path) { Ok(ast) => { println!("Parsed: {}", args.path.display()); t_ast = Some(ast); diff --git a/core/cli/src/parser.rs b/core/cli/src/parser.rs index a9c0f61..e42916b 100644 --- a/core/cli/src/parser.rs +++ b/core/cli/src/parser.rs @@ -53,8 +53,8 @@ Parse builds the typed AST; analyze performs semantic/type inference; codegen em pub(crate) struct Cli { /// Path to the source file to compile. /// - /// Currently only single-file compilation is supported. Multi-file projects - /// and project file (`.infp`) support is planned for future releases. + /// Multi-file projects are supported via `mod name;` declarations that + /// resolve submodule files relative to this path. pub(crate) path: std::path::PathBuf, /// Run the parse phase to build the typed AST. diff --git a/core/inference/src/lib.rs b/core/inference/src/lib.rs index 9527eb2..0028a44 100644 --- a/core/inference/src/lib.rs +++ b/core/inference/src/lib.rs @@ -140,8 +140,7 @@ //! //! ## Current Limitations //! -//! - **Single-file support**: Multi-file compilation is not yet implemented. -//! The AST expects a single source file as input. +//! - **Multi-file support**: Submodules are resolved via `mod name;` declarations. //! - **Analyze phase**: The semantic analysis phase is work-in-progress and //! currently returns `Ok(())` without performing any checks. //! @@ -152,7 +151,8 @@ //! - [`inference_type_checker::TypeCheckerBuilder`] - Type checking entry point //! - [`inference_type_checker::typed_context::TypedContext`] - Type information storage -use inference_ast::{arena::Arena, builder::Builder}; +use inference_ast::{arena::Arena, builder::Builder, parser_context::ParserContext}; +use std::path::Path; use inference_type_checker::typed_context::TypedContext; /// Parses source code and builds an arena-based Abstract Syntax Tree. @@ -214,6 +214,15 @@ pub fn parse(source_code: &str) -> anyhow::Result { Ok(arena) } +/// Parses a source file and any `mod`-referenced submodules into a unified AST arena. +/// +/// This entry point enables multi-file projects by following `mod name;` declarations +/// and parsing submodule files on disk. +pub fn parse_file(path: &Path) -> anyhow::Result { + let mut context = ParserContext::new(path.to_path_buf()); + context.parse_all() +} + /// Performs bidirectional type checking and inference on the AST. /// /// This function analyzes the AST to build a complete type mapping for all diff --git a/core/type-checker/src/symbol_table.rs b/core/type-checker/src/symbol_table.rs index 132508f..09658ba 100644 --- a/core/type-checker/src/symbol_table.rs +++ b/core/type-checker/src/symbol_table.rs @@ -745,6 +745,15 @@ impl SymbolTable { self.current_scope.as_ref().map(|s| s.borrow().id) } + pub(crate) fn enter_scope(&mut self, scope_id: u32) -> bool { + if let Some(scope) = self.scopes.get(&scope_id) { + self.current_scope = Some(Rc::clone(scope)); + true + } else { + false + } + } + #[must_use = "this is a pure lookup with no side effects"] pub(crate) fn get_scope(&self, scope_id: u32) -> Option { self.scopes.get(&scope_id).cloned() @@ -859,8 +868,7 @@ impl SymbolTable { /// Register a definition from an external module into the current scope. /// - /// Currently handles: Struct, Enum, Spec, Function, Type. - /// Skips: Constant, ExternalFunction, Module (deferred to future phases). + /// Currently handles: Struct, Enum, Spec, Function, Type, Constant, ExternalFunction, Module. #[allow(dead_code)] fn register_definition_from_external(&mut self, definition: &Definition) -> anyhow::Result<()> { match definition { @@ -920,7 +928,50 @@ impl SymbolTable { Definition::Type(t) => { self.register_type(&t.name(), Some(&t.ty))?; } - Definition::Constant(_) | Definition::ExternalFunction(_) | Definition::Module(_) => {} + Definition::Constant(constant_definition) => { + self.push_variable_to_scope( + &constant_definition.name(), + TypeInfo::new(&constant_definition.ty), + )?; + } + Definition::ExternalFunction(external_function_definition) => { + let param_types: Vec<_> = external_function_definition + .arguments + .as_ref() + .unwrap_or(&vec![]) + .iter() + .filter_map(|param| match param { + ArgumentType::SelfReference(_) => None, + ArgumentType::IgnoreArgument(ignore_argument) => { + Some(ignore_argument.ty.clone()) + } + ArgumentType::Argument(argument) => Some(argument.ty.clone()), + ArgumentType::Type(ty) => Some(ty.clone()), + }) + .collect(); + let return_type = external_function_definition + .returns + .clone() + .unwrap_or(Type::Simple(SimpleTypeKind::Unit)); + + self.register_function_with_visibility( + &external_function_definition.name(), + vec![], + ¶m_types, + &return_type, + external_function_definition.visibility.clone(), + ) + .map_err(|e| anyhow::anyhow!(e))?; + } + Definition::Module(module_definition) => { + let _scope_id = self.enter_module(module_definition); + if let Some(body) = module_definition.body.borrow().as_ref() { + for definition in body { + self.register_definition_from_external(definition)?; + } + } + self.pop_scope(); + } } Ok(()) } diff --git a/core/type-checker/src/type_checker.rs b/core/type-checker/src/type_checker.rs index 82de1b2..fb62ba6 100644 --- a/core/type-checker/src/type_checker.rs +++ b/core/type-checker/src/type_checker.rs @@ -76,23 +76,7 @@ impl TypeChecker { // Continue to inference phase even if registration had errors // to collect all errors before returning for source_file in ctx.source_files() { - for def in &source_file.definitions { - match def { - Definition::Function(function_definition) => { - self.infer_variables(function_definition.clone(), ctx); - } - Definition::Struct(struct_definition) => { - let struct_type = TypeInfo { - kind: TypeInfoKind::Struct(struct_definition.name()), - type_params: vec![], - }; - for method in &struct_definition.methods { - self.infer_method_variables(method.clone(), struct_type.clone(), ctx); - } - } - _ => {} - } - } + self.infer_definitions(&source_file.definitions, ctx, &[]); } if !self.errors.is_empty() { let error_messages: Vec = std::mem::take(&mut self.errors) @@ -107,150 +91,159 @@ impl TypeChecker { /// Registers `Definition::Type`, `Definition::Struct`, `Definition::Enum`, and `Definition::Spec` fn register_types(&mut self, ctx: &mut TypedContext) { for source_file in ctx.source_files() { - for definition in &source_file.definitions { - match definition { - Definition::Type(type_definition) => { - self.symbol_table - .register_type(&type_definition.name(), Some(&type_definition.ty)) - .unwrap_or_else(|_| { - self.errors.push(TypeCheckError::RegistrationFailed { - kind: RegistrationKind::Type, - name: type_definition.name(), - reason: None, - location: type_definition.location, - }); + self.register_types_in_definitions(&source_file.definitions); + } + } + + fn register_types_in_definitions(&mut self, definitions: &[Definition]) { + for definition in definitions { + match definition { + Definition::Type(type_definition) => { + self.symbol_table + .register_type(&type_definition.name(), Some(&type_definition.ty)) + .unwrap_or_else(|_| { + self.errors.push(TypeCheckError::RegistrationFailed { + kind: RegistrationKind::Type, + name: type_definition.name(), + reason: None, + location: type_definition.location, }); - } - Definition::Struct(struct_definition) => { - let fields: Vec<(String, TypeInfo, Visibility)> = struct_definition - .fields - .iter() - .map(|f| { - ( - f.name.name.clone(), - TypeInfo::new(&f.type_), - Visibility::Private, - ) - }) - .collect(); - self.symbol_table - .register_struct( - &struct_definition.name(), - &fields, - vec![], - struct_definition.visibility.clone(), + }); + } + Definition::Struct(struct_definition) => { + let fields: Vec<(String, TypeInfo, Visibility)> = struct_definition + .fields + .iter() + .map(|f| { + ( + f.name.name.clone(), + TypeInfo::new(&f.type_), + Visibility::Private, ) - .unwrap_or_else(|_| { - self.errors.push(TypeCheckError::RegistrationFailed { - kind: RegistrationKind::Struct, - name: struct_definition.name(), - reason: None, - location: struct_definition.location, - }); - }); - - let struct_name = struct_definition.name(); - for method in &struct_definition.methods { - let has_self = method.arguments.as_ref().is_some_and(|args| { - args.iter() - .any(|arg| matches!(arg, ArgumentType::SelfReference(_))) + }) + .collect(); + self.symbol_table + .register_struct( + &struct_definition.name(), + &fields, + vec![], + struct_definition.visibility.clone(), + ) + .unwrap_or_else(|_| { + self.errors.push(TypeCheckError::RegistrationFailed { + kind: RegistrationKind::Struct, + name: struct_definition.name(), + reason: None, + location: struct_definition.location, }); + }); - let param_types: Vec = method - .arguments - .as_ref() - .unwrap_or(&vec![]) - .iter() - .filter_map(|param| match param { - ArgumentType::SelfReference(_) => None, - ArgumentType::IgnoreArgument(ignore_arg) => { - Some(TypeInfo::new(&ignore_arg.ty)) - } - ArgumentType::Argument(arg) => Some(TypeInfo::new(&arg.ty)), - ArgumentType::Type(ty) => Some(TypeInfo::new(ty)), - }) - .collect(); + let struct_name = struct_definition.name(); + for method in &struct_definition.methods { + let has_self = method.arguments.as_ref().is_some_and(|args| { + args.iter() + .any(|arg| matches!(arg, ArgumentType::SelfReference(_))) + }); - let return_type = method - .returns - .as_ref() - .map(TypeInfo::new) - .unwrap_or_default(); + let param_types: Vec = method + .arguments + .as_ref() + .unwrap_or(&vec![]) + .iter() + .filter_map(|param| match param { + ArgumentType::SelfReference(_) => None, + ArgumentType::IgnoreArgument(ignore_arg) => { + Some(TypeInfo::new(&ignore_arg.ty)) + } + ArgumentType::Argument(arg) => Some(TypeInfo::new(&arg.ty)), + ArgumentType::Type(ty) => Some(TypeInfo::new(ty)), + }) + .collect(); - let type_params: Vec = method - .type_parameters - .as_ref() - .unwrap_or(&vec![]) - .iter() - .map(|p| p.name()) - .collect(); - - let definition_scope_id = - self.symbol_table.current_scope_id().unwrap_or(0); - let signature = FuncInfo { - name: method.name(), - type_params, - param_types, - return_type, - visibility: method.visibility.clone(), - definition_scope_id, - }; + let return_type = method + .returns + .as_ref() + .map(TypeInfo::new) + .unwrap_or_default(); - self.symbol_table - .register_method( - &struct_name, - signature, - method.visibility.clone(), - has_self, - ) - .unwrap_or_else(|err| { - self.errors.push(TypeCheckError::RegistrationFailed { - kind: RegistrationKind::Method, - name: format!("{struct_name}::{}", method.name()), - reason: Some(err.to_string()), - location: method.location, - }); - }); - } - } - Definition::Enum(enum_definition) => { - let variants: Vec<&str> = enum_definition - .variants + let type_params: Vec = method + .type_parameters + .as_ref() + .unwrap_or(&vec![]) .iter() - .map(|v| v.name.as_str()) + .map(|p| p.name()) .collect(); + + let definition_scope_id = self.symbol_table.current_scope_id().unwrap_or(0); + let signature = FuncInfo { + name: method.name(), + type_params, + param_types, + return_type, + visibility: method.visibility.clone(), + definition_scope_id, + }; + self.symbol_table - .register_enum( - &enum_definition.name(), - &variants, - enum_definition.visibility.clone(), + .register_method( + &struct_name, + signature, + method.visibility.clone(), + has_self, ) - .unwrap_or_else(|_| { + .unwrap_or_else(|err| { self.errors.push(TypeCheckError::RegistrationFailed { - kind: RegistrationKind::Enum, - name: enum_definition.name(), - reason: None, - location: enum_definition.location, + kind: RegistrationKind::Method, + name: format!("{struct_name}::{}", method.name()), + reason: Some(err.to_string()), + location: method.location, }); }); } - Definition::Spec(spec_definition) => { - self.symbol_table - .register_spec(&spec_definition.name()) - .unwrap_or_else(|_| { - self.errors.push(TypeCheckError::RegistrationFailed { - kind: RegistrationKind::Spec, - name: spec_definition.name(), - reason: None, - location: spec_definition.location, - }); + } + Definition::Enum(enum_definition) => { + let variants: Vec<&str> = enum_definition + .variants + .iter() + .map(|v| v.name.as_str()) + .collect(); + self.symbol_table + .register_enum( + &enum_definition.name(), + &variants, + enum_definition.visibility.clone(), + ) + .unwrap_or_else(|_| { + self.errors.push(TypeCheckError::RegistrationFailed { + kind: RegistrationKind::Enum, + name: enum_definition.name(), + reason: None, + location: enum_definition.location, + }); + }); + } + Definition::Spec(spec_definition) => { + self.symbol_table + .register_spec(&spec_definition.name()) + .unwrap_or_else(|_| { + self.errors.push(TypeCheckError::RegistrationFailed { + kind: RegistrationKind::Spec, + name: spec_definition.name(), + reason: None, + location: spec_definition.location, }); + }); + } + Definition::Module(module_definition) => { + let _scope_id = self.symbol_table.enter_module(module_definition); + if let Some(body) = module_definition.body.borrow().as_ref() { + self.register_types_in_definitions(body); } - Definition::Constant(_) - | Definition::Function(_) - | Definition::ExternalFunction(_) - | Definition::Module(_) => {} + self.symbol_table.pop_scope(); } + Definition::Constant(_) + | Definition::Function(_) + | Definition::ExternalFunction(_) => {} } } } @@ -259,153 +252,208 @@ impl TypeChecker { #[allow(clippy::too_many_lines)] fn collect_function_and_constant_definitions(&mut self, ctx: &mut TypedContext) { for sf in ctx.source_files() { - for definition in &sf.definitions { - match definition { - Definition::Constant(constant_definition) => { - let const_type = TypeInfo::new(&constant_definition.ty); - if let Err(err) = self - .symbol_table - .push_variable_to_scope(&constant_definition.name(), const_type.clone()) - { - self.errors.push(TypeCheckError::RegistrationFailed { - kind: RegistrationKind::Variable, - name: constant_definition.name(), - reason: Some(err.to_string()), - location: constant_definition.location, - }); + self.collect_function_and_constant_definitions_in_definitions( + &sf.definitions, + ctx, + &[], + ); + } + } + + #[allow(clippy::too_many_lines)] + fn collect_function_and_constant_definitions_in_definitions( + &mut self, + definitions: &[Definition], + ctx: &mut TypedContext, + path: &[String], + ) { + for definition in definitions { + match definition { + Definition::Constant(constant_definition) => { + let const_type = TypeInfo::new(&constant_definition.ty); + if let Err(err) = self + .symbol_table + .push_variable_to_scope(&constant_definition.name(), const_type.clone()) + { + self.errors.push(TypeCheckError::RegistrationFailed { + kind: RegistrationKind::Variable, + name: constant_definition.name(), + reason: Some(err.to_string()), + location: constant_definition.location, + }); + } + ctx.set_node_typeinfo(constant_definition.value.id(), const_type); + } + Definition::Function(function_definition) => { + for param in function_definition.arguments.as_ref().unwrap_or(&vec![]) { + match param { + ArgumentType::SelfReference(self_ref) => { + self.errors.push(TypeCheckError::SelfReferenceInFunction { + function_name: function_definition.name(), + location: self_ref.location, + }); + } + ArgumentType::IgnoreArgument(ignore_argument) => { + self.validate_type( + &ignore_argument.ty, + function_definition.type_parameters.as_ref(), + ); + ctx.set_node_typeinfo( + ignore_argument.id, + TypeInfo::new(&ignore_argument.ty), + ); + } + ArgumentType::Argument(arg) => { + self.validate_type( + &arg.ty, + function_definition.type_parameters.as_ref(), + ); + let type_info = TypeInfo::new(&arg.ty); + ctx.set_node_typeinfo(arg.id, type_info.clone()); + ctx.set_node_typeinfo(arg.name.id, type_info); + } + ArgumentType::Type(ty) => { + self.validate_type(ty, function_definition.type_parameters.as_ref()); + } } - ctx.set_node_typeinfo(constant_definition.value.id(), const_type); } - Definition::Function(function_definition) => { - for param in function_definition.arguments.as_ref().unwrap_or(&vec![]) { - match param { - ArgumentType::SelfReference(self_ref) => { - self.errors.push(TypeCheckError::SelfReferenceInFunction { - function_name: function_definition.name(), - location: self_ref.location, - }); - } + ctx.set_node_typeinfo( + function_definition.name.id, + TypeInfo { + kind: TypeInfoKind::Function(function_definition.name()), + type_params: function_definition + .type_parameters + .as_ref() + .map_or(vec![], |p| p.iter().map(|i| i.name.clone()).collect()), + }, + ); + if let Some(return_type) = &function_definition.returns { + self.validate_type(return_type, function_definition.type_parameters.as_ref()); + ctx.set_node_typeinfo(return_type.id(), TypeInfo::new(return_type)); + } + if let Err(err) = self.symbol_table.register_function( + &function_definition.name(), + function_definition + .type_parameters + .as_ref() + .unwrap_or(&vec![]) + .iter() + .map(|param| param.name()) + .collect::>(), + &function_definition + .arguments + .as_ref() + .unwrap_or(&vec![]) + .iter() + .filter_map(|param| match param { + ArgumentType::SelfReference(_) => None, ArgumentType::IgnoreArgument(ignore_argument) => { - self.validate_type( - &ignore_argument.ty, - function_definition.type_parameters.as_ref(), - ); - ctx.set_node_typeinfo( - ignore_argument.id, - TypeInfo::new(&ignore_argument.ty), - ); - } - ArgumentType::Argument(arg) => { - self.validate_type( - &arg.ty, - function_definition.type_parameters.as_ref(), - ); - let type_info = TypeInfo::new(&arg.ty); - ctx.set_node_typeinfo(arg.id, type_info.clone()); - ctx.set_node_typeinfo(arg.name.id, type_info); + Some(ignore_argument.ty.clone()) } - ArgumentType::Type(ty) => { - self.validate_type( - ty, - function_definition.type_parameters.as_ref(), - ); + ArgumentType::Argument(argument) => Some(argument.ty.clone()), + ArgumentType::Type(ty) => Some(ty.clone()), + }) + .collect::>(), + &function_definition + .returns + .as_ref() + .unwrap_or(&Type::Simple(SimpleTypeKind::Unit)) + .clone(), + ) { + self.errors.push(TypeCheckError::RegistrationFailed { + kind: RegistrationKind::Function, + name: function_definition.name(), + reason: Some(err), + location: function_definition.location, + }); + } + } + Definition::ExternalFunction(external_function_definition) => { + if let Err(err) = self.symbol_table.register_function( + &external_function_definition.name(), + vec![], + &external_function_definition + .arguments + .as_ref() + .unwrap_or(&vec![]) + .iter() + .filter_map(|param| match param { + ArgumentType::SelfReference(_) => None, + ArgumentType::IgnoreArgument(ignore_argument) => { + Some(ignore_argument.ty.clone()) } - } - } - ctx.set_node_typeinfo( - function_definition.name.id, - TypeInfo { - kind: TypeInfoKind::Function(function_definition.name()), - type_params: function_definition - .type_parameters - .as_ref() - .map_or(vec![], |p| p.iter().map(|i| i.name.clone()).collect()), - }, + ArgumentType::Argument(argument) => Some(argument.ty.clone()), + ArgumentType::Type(ty) => Some(ty.clone()), + }) + .collect::>(), + &external_function_definition + .returns + .as_ref() + .unwrap_or(&Type::Simple(SimpleTypeKind::Unit)) + .clone(), + ) { + self.errors.push(TypeCheckError::RegistrationFailed { + kind: RegistrationKind::Function, + name: external_function_definition.name(), + reason: Some(err), + location: external_function_definition.location, + }); + } + } + Definition::Module(module_definition) => { + let mut next_path = path.to_vec(); + next_path.push(module_definition.name()); + if let Some(scope_id) = self.symbol_table.find_module_scope(&next_path) { + self.symbol_table.enter_scope(scope_id); + } else { + let _ = self.symbol_table.enter_module(module_definition); + } + if let Some(body) = module_definition.body.borrow().as_ref() { + self.collect_function_and_constant_definitions_in_definitions( + body, + ctx, + &next_path, ); - if let Some(return_type) = &function_definition.returns { - self.validate_type( - return_type, - function_definition.type_parameters.as_ref(), - ); - ctx.set_node_typeinfo(return_type.id(), TypeInfo::new(return_type)); - } - // Register function even if parameter validation had errors - // to allow error recovery and prevent spurious UndefinedFunction errors - if let Err(err) = self.symbol_table.register_function( - &function_definition.name(), - function_definition - .type_parameters - .as_ref() - .unwrap_or(&vec![]) - .iter() - .map(|param| param.name()) - .collect::>(), - &function_definition - .arguments - .as_ref() - .unwrap_or(&vec![]) - .iter() - .filter_map(|param| match param { - ArgumentType::SelfReference(_) => None, - ArgumentType::IgnoreArgument(ignore_argument) => { - Some(ignore_argument.ty.clone()) - } - ArgumentType::Argument(argument) => Some(argument.ty.clone()), - ArgumentType::Type(ty) => Some(ty.clone()), - }) - .collect::>(), - &function_definition - .returns - .as_ref() - .unwrap_or(&Type::Simple(SimpleTypeKind::Unit)) - .clone(), - ) { - self.errors.push(TypeCheckError::RegistrationFailed { - kind: RegistrationKind::Function, - name: function_definition.name(), - reason: Some(err), - location: function_definition.location, - }); - } } - Definition::ExternalFunction(external_function_definition) => { - if let Err(err) = self.symbol_table.register_function( - &external_function_definition.name(), - vec![], - &external_function_definition - .arguments - .as_ref() - .unwrap_or(&vec![]) - .iter() - .filter_map(|param| match param { - ArgumentType::SelfReference(_) => None, - ArgumentType::IgnoreArgument(ignore_argument) => { - Some(ignore_argument.ty.clone()) - } - ArgumentType::Argument(argument) => Some(argument.ty.clone()), - ArgumentType::Type(ty) => Some(ty.clone()), - }) - .collect::>(), - &external_function_definition - .returns - .as_ref() - .unwrap_or(&Type::Simple(SimpleTypeKind::Unit)) - .clone(), - ) { - self.errors.push(TypeCheckError::RegistrationFailed { - kind: RegistrationKind::Function, - name: external_function_definition.name(), - reason: Some(err), - location: external_function_definition.location, - }); - } + self.symbol_table.pop_scope(); + } + Definition::Spec(_) + | Definition::Struct(_) + | Definition::Enum(_) + | Definition::Type(_) => {} + } + } + } + + fn infer_definitions(&mut self, definitions: &[Definition], ctx: &mut TypedContext, path: &[String]) { + for definition in definitions { + match definition { + Definition::Function(function_definition) => { + self.infer_variables(function_definition.clone(), ctx); + } + Definition::Struct(struct_definition) => { + let struct_type = TypeInfo { + kind: TypeInfoKind::Struct(struct_definition.name()), + type_params: vec![], + }; + for method in &struct_definition.methods { + self.infer_method_variables(method.clone(), struct_type.clone(), ctx); + } + } + Definition::Module(module_definition) => { + let mut next_path = path.to_vec(); + next_path.push(module_definition.name()); + if let Some(scope_id) = self.symbol_table.find_module_scope(&next_path) { + self.symbol_table.enter_scope(scope_id); + } else { + let _ = self.symbol_table.enter_module(module_definition); } - Definition::Spec(_) - | Definition::Struct(_) - | Definition::Enum(_) - | Definition::Type(_) - | Definition::Module(_) => {} + if let Some(body) = module_definition.body.borrow().as_ref() { + self.infer_definitions(body, ctx, &next_path); + } + self.symbol_table.pop_scope(); } + _ => {} } } } @@ -1607,7 +1655,7 @@ impl TypeChecker { ) -> anyhow::Result<()> { let _scope_id = self.symbol_table.enter_module(module); - if let Some(body) = &module.body { + if let Some(body) = module.body.borrow().as_ref() { for definition in body { match definition { Definition::Type(type_definition) => { diff --git a/core/wasm-codegen/src/lib.rs b/core/wasm-codegen/src/lib.rs index b296e07..7745389 100644 --- a/core/wasm-codegen/src/lib.rs +++ b/core/wasm-codegen/src/lib.rs @@ -75,6 +75,7 @@ #![warn(clippy::pedantic)] +use inference_ast::nodes::Definition; use inference_type_checker::typed_context::TypedContext; use inkwell::{ context::Context, @@ -90,8 +91,7 @@ mod utils; /// /// # Errors /// -/// Returns an error if more than one source file is present in the AST, as multi-file -/// support is not yet implemented. +/// Supports multiple source files by traversing all parsed modules. /// /// Returns an error if code generation fails. pub fn codegen(typed_context: &TypedContext) -> anyhow::Result> { @@ -102,10 +102,6 @@ pub fn codegen(typed_context: &TypedContext) -> anyhow::Result> { if typed_context.source_files().is_empty() { return compiler.compile_to_wasm("output.wasm", 3); } - if typed_context.source_files().len() > 1 { - todo!("Multi-file support not yet implemented"); - } - traverse_t_ast_with_compiler(typed_context, &compiler); let wasm_bytes = compiler.compile_to_wasm("output.wasm", 3)?; Ok(wasm_bytes) @@ -127,11 +123,29 @@ pub fn codegen(typed_context: &TypedContext) -> anyhow::Result> { /// /// - Only function definitions are compiled /// - Type definitions, constants, and other top-level items are ignored -/// - Multi-file compilation is not fully tested (see `codegen` function) +/// - Module name mangling is not yet implemented for nested functions fn traverse_t_ast_with_compiler(typed_context: &TypedContext, compiler: &Compiler) { for source_file in &typed_context.source_files() { - for func_def in source_file.function_definitions() { - compiler.visit_function_definition(&func_def, typed_context); + compile_definitions(&source_file.definitions, typed_context, compiler); + } +} + +fn compile_definitions( + definitions: &[Definition], + typed_context: &TypedContext, + compiler: &Compiler, +) { + for definition in definitions { + match definition { + Definition::Function(func_def) => { + compiler.visit_function_definition(func_def, typed_context); + } + Definition::Module(module_def) => { + if let Some(body) = module_def.body.borrow().as_ref() { + compile_definitions(body, typed_context, compiler); + } + } + _ => {} } } }