From 178a4ada8b0851359d8f47860e2b7252671238d8 Mon Sep 17 00:00:00 2001 From: zzzdong Date: Fri, 30 May 2025 18:38:27 +0800 Subject: [PATCH 1/2] wip: stage --- src/compiler.rs | 62 ++- src/compiler/ast.rs | 1 + src/compiler/ast/grammar.pest | 5 +- src/compiler/ast/syntax.rs | 122 +---- src/compiler/ast/walker.rs | 298 ++++++++++++ src/compiler/lowering.rs | 28 +- src/compiler/parser.rs | 94 ++-- src/compiler/semantic.rs | 583 +++++------------------ src/compiler/symbol.rs | 52 +++ src/compiler/typing.rs | 843 +++++++++++++++++++++++++++++----- src/runtime/environment.rs | 8 + 11 files changed, 1348 insertions(+), 748 deletions(-) create mode 100644 src/compiler/ast/walker.rs create mode 100644 src/compiler/symbol.rs diff --git a/src/compiler.rs b/src/compiler.rs index 868d6d6..6311925 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -5,19 +5,20 @@ mod lowering; mod parser; mod regalloc; mod semantic; +mod symbol; mod typing; use std::collections::HashMap; +use std::path::Path; use std::sync::Arc; use ir::builder::{InstBuilder, IrBuilder}; use ir::instruction::IrUnit; use log::debug; -use typing::TypeContext; +use typing::{Type, TypeChecker, TypeContext, TypeError}; use crate::Environment; use crate::bytecode::{Module, Register}; -use ast::syntax::{Span, Type}; use parser::ParseError; use codegen::Codegen; @@ -30,7 +31,9 @@ pub fn compile(script: &str, env: &crate::Environment) -> Result for CompileError { + fn from(error: std::io::Error) -> Self { + CompileError::Io(error) + } +} + impl From for CompileError { fn from(error: ParseError) -> Self { CompileError::Parse(error) } } +impl From for CompileError { + fn from(error: TypeError) -> Self { + CompileError::Type(error) + } +} + +impl From for CompileError { + fn from(error: SemanticsError) -> Self { + CompileError::Semantics(error) + } +} + impl std::fmt::Display for CompileError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { + CompileError::Io(error) => write!(f, "IO error: {error}"), CompileError::Parse(error) => write!(f, "Parse error: {error}"), + CompileError::Type(error) => write!(f, "Type error: {error:?}"), CompileError::Semantics(message) => write!(f, "Semantics error: {message}"), CompileError::UndefinedVariable { name } => { write!(f, "Undefined variable `{name}`") @@ -129,6 +152,35 @@ impl std::fmt::Display for CompileError { impl std::error::Error for CompileError {} +pub struct FileId(usize); + +pub struct Context { + sources: Vec, +} + +impl Context { + pub fn new() -> Self { + Self { sources: vec![] } + } + + pub fn add_source(&mut self, source: String) -> FileId { + let id = FileId(self.sources.len()); + self.sources.push(source); + id + } + + pub fn add_file(&mut self, file: impl AsRef) -> Result { + let id = FileId(self.sources.len()); + let content = std::fs::read_to_string(file.as_ref())?; + self.sources.push(content); + Ok(id) + } + + pub fn get_source(&self, file: FileId) -> Option<&str> { + self.sources.get(file.0).map(|s| s.as_str()) + } +} + pub struct Compiler {} impl Default for Compiler { @@ -149,11 +201,11 @@ impl Compiler { debug!("AST: {ast:?}"); let mut type_cx = TypeContext::new(); - type_cx.process_env(env); + type_cx.analyze_type_def(&ast.stmts)?; // 语义分析 - let mut analyzer = SemanticAnalyzer::new(&mut type_cx); - analyzer.analyze_program(&mut ast, env)?; + let mut checker = SemanticChecker::new(&mut type_cx); + checker.check_program(&mut ast, env)?; // IR生成, AST -> IR let mut unit = IrUnit::new(); diff --git a/src/compiler/ast.rs b/src/compiler/ast.rs index 4a39d2c..41fa5ba 100644 --- a/src/compiler/ast.rs +++ b/src/compiler/ast.rs @@ -1 +1,2 @@ pub mod syntax; +pub mod walker; \ No newline at end of file diff --git a/src/compiler/ast/grammar.pest b/src/compiler/ast/grammar.pest index 758ae72..166a501 100644 --- a/src/compiler/ast/grammar.pest +++ b/src/compiler/ast/grammar.pest @@ -47,11 +47,8 @@ enum_item = { "enum" ~ identifier ~ "{" ~ (enum_item_list ~ ","?)? ~ "}" } enum_item_list = { enum_field ~ ("," ~ enum_field)* } -enum_field = { simple_enum_field | tuple_enum_field } +enum_field = { identifier ~ ("(" ~ type_expression ~ ")")?} -simple_enum_field = { identifier ~ !"(" } - -tuple_enum_field = { identifier ~ "(" ~ type_expression ~ ("," ~ type_expression)* ~ ")" } /// Expression Statment expression_statement = { expression ~ ";" } diff --git a/src/compiler/ast/syntax.rs b/src/compiler/ast/syntax.rs index 4ca6355..9cb5238 100644 --- a/src/compiler/ast/syntax.rs +++ b/src/compiler/ast/syntax.rs @@ -11,12 +11,11 @@ use pest::{RuleType, iterators::Pair}; pub struct AstNode { pub node: T, pub span: Span, - pub ty: Type, } impl AstNode { - pub fn new(node: T, span: Span, ty: Type) -> AstNode { - AstNode { span, node, ty } + pub fn new(node: T, span: Span) -> AstNode { + AstNode { span, node } } pub fn span(&self) -> Span { @@ -26,10 +25,6 @@ impl AstNode { pub fn node(&self) -> &T { &self.node } - - pub fn ty(&self) -> &Type { - &self.ty - } } impl AsRef for AstNode { @@ -55,78 +50,15 @@ impl Span { Span { start, end } } - pub fn from_pair(pair: &Pair) -> Span { - Span { - start: pair.as_span().start(), - end: pair.as_span().end(), - } - } - pub fn merge(&self, other: &Span) -> Span { Span { start: self.start.min(other.start), end: self.end.max(other.end), } } -} - -#[derive(Debug, Clone, PartialEq)] -pub enum Type { - Any, - Boolean, - Byte, - Integer, - Float, - Char, - String, - Range, - Tuple(Vec), - Array(Box), - Map(Box), - Unknown, - UserDefined(String), - Decl(Declaration), -} - -impl Type { - pub fn is_boolean(&self) -> bool { - matches!(self, Type::Boolean) - } - - pub fn is_numeric(&self) -> bool { - matches!(self, Type::Byte | Type::Integer | Type::Float) - } - - pub fn is_string(&self) -> bool { - matches!(self, Type::String) - } - - pub fn is_unknown(&self) -> bool { - matches!(self, Type::Unknown) - } - - pub fn is_collection(&self) -> bool { - matches!(self, Type::Array(_) | Type::Map(_)) - } - - pub fn is_any(&self) -> bool { - matches!(self, Type::Any) - } - - pub fn get_array_element_type(&self) -> Option<&Type> { - if let Type::Array(ty) = self { - Some(ty.as_ref()) - } else { - None - } - } - pub fn get_map_value_type(&self) -> Option<&Type> { - if let Type::Map(ty) = self { - Some(ty.as_ref()) - } else { - None - } + pub fn is_empty(&self) -> bool { + self.start == self.end } } @@ -167,7 +99,7 @@ pub struct BlockStatement(pub Vec); pub enum ItemStatement { Enum(EnumItem), Struct(StructItem), - Fn(FunctionItem), + Function(FunctionItem), } #[derive(Debug, Clone, PartialEq)] @@ -177,9 +109,9 @@ pub struct EnumItem { } #[derive(Debug, Clone, PartialEq)] -pub enum EnumVariant { - Simple(String), - Tuple(String, Vec), +pub struct EnumVariant { + pub name: String, + pub variant: Option, } #[derive(Debug, Clone, PartialEq)] @@ -256,8 +188,8 @@ pub enum TypeExpression { String, Array(Box), Tuple(Vec), - Generic(String, Vec), UserDefined(String), + Generic(String, Vec), Impl(Box), } @@ -533,44 +465,12 @@ pub struct CallMethodExpression { #[derive(Debug, Clone, PartialEq)] pub struct StructExpression { - pub name: String, + pub name: AstNode, pub fields: Vec, } #[derive(Debug, Clone, PartialEq)] pub struct StructExprField { - pub name: String, + pub name: AstNode, pub value: ExpressionNode, } - -#[derive(Debug, Clone, PartialEq)] -pub enum Declaration { - Function(FunctionDeclaration), - Struct(StructDeclaration), - Enum(EnumDeclaration), -} - -impl Declaration { - pub fn is_function(&self) -> bool { - matches!(self, Declaration::Function(_)) - } -} - -#[derive(Debug, Clone, PartialEq)] -pub struct FunctionDeclaration { - pub name: String, - pub params: Vec<(String, Option)>, - pub return_type: Option>, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct StructDeclaration { - pub name: String, - pub fields: HashMap, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct EnumDeclaration { - pub name: String, - pub variants: Vec<(String, Option)>, -} diff --git a/src/compiler/ast/walker.rs b/src/compiler/ast/walker.rs new file mode 100644 index 0000000..40aa49d --- /dev/null +++ b/src/compiler/ast/walker.rs @@ -0,0 +1,298 @@ +use super::syntax::*; + +pub trait Walker { + type Error; + type StatementResult: Default; + type ExpressionResult: Default; + + fn walk_program(&mut self, program: &Program) -> Result<(), Self::Error> { + for stmt in &program.stmts { + self.walk_statement(stmt)?; + } + Ok(()) + } + + fn walk_statement( + &mut self, + stmt: &StatementNode, + ) -> Result { + match &stmt.node { + Statement::Empty => self.walk_empty(), + Statement::Break => self.walk_break(), + Statement::Continue => self.walk_continue(), + Statement::Block(stmt) => self.walk_block_statement(stmt), + Statement::Item(stmt) => self.walk_item_statement(stmt), + Statement::Let(stmt) => self.walk_let_statement(stmt), + Statement::For(stmt) => self.walk_for_statement(stmt), + Statement::While(stmt) => self.walk_while_statement(stmt), + Statement::Loop(stmt) => self.walk_loop_statement(stmt), + Statement::If(stmt) => self.walk_if_statement(stmt), + Statement::Return(stmt) => self.walk_return_statement(stmt), + Statement::Expression(expr) => self.walk_expression(expr).map(|_| Default::default()), + } + } + + fn walk_expression( + &mut self, + expr: &ExpressionNode, + ) -> Result { + match &expr.node { + Expression::Literal(expr) => self.walk_literal_expression(expr), + Expression::Identifier(expr) => self.walk_identifier_expression(expr), + Expression::Environment(expr) => self.walk_environment_expression(expr), + Expression::Path(expr) => self.walk_path_expression(expr), + Expression::Tuple(expr) => self.walk_tuple_expression(expr), + Expression::Array(expr) => self.walk_array_expression(expr), + Expression::Map(expr) => self.walk_map_expression(expr), + Expression::Closure(expr) => self.walk_closure_expression(expr), + Expression::Range(expr) => self.walk_range_expression(expr), + Expression::Slice(expr) => self.walk_slice_expression(expr), + Expression::Assign(expr) => self.walk_assign_expression(expr), + Expression::Call(expr) => self.walk_call_expression(expr), + Expression::Try(expr) => self.walk_expression(expr), + Expression::Await(expr) => self.walk_expression(expr), + Expression::Prefix(expr) => self.walk_prefix_expression(expr), + Expression::Binary(expr) => self.walk_binary_expression(expr), + Expression::IndexGet(expr) => self.walk_index_get_expression(expr), + Expression::IndexSet(expr) => self.walk_index_set_expression(expr), + Expression::PropertyGet(expr) => self.walk_property_get_expression(expr), + Expression::PropertySet(expr) => self.walk_property_set_expression(expr), + Expression::CallMethod(expr) => self.walk_call_method_expression(expr), + Expression::StructExpr(expr) => self.walk_struct_expression(expr), + } + } + + fn walk_empty(&mut self) -> Result { + Ok(Default::default()) + } + + fn walk_break(&mut self) -> Result { + Ok(Default::default()) + } + + fn walk_continue(&mut self) -> Result { + Ok(Default::default()) + } + + fn walk_block_statement( + &mut self, + stmt: &BlockStatement, + ) -> Result { + Ok(Default::default()) + } + + fn walk_item_statement( + &mut self, + stmt: &ItemStatement, + ) -> Result { + match stmt { + ItemStatement::Function(item) => { + self.walk_function_item(item)?; + Ok(Default::default()) + } + ItemStatement::Struct(item) => { + self.walk_struct_item(item)?; + Ok(Default::default()) + } + ItemStatement::Enum(item) => { + self.walk_enum_item(item)?; + Ok(Default::default()) + } + } + } + + fn walk_function_item(&mut self, item: &FunctionItem) -> Result<(), Self::Error> { + Ok(Default::default()) + } + + fn walk_struct_item(&mut self, item: &StructItem) -> Result<(), Self::Error> { + Ok(Default::default()) + } + + fn walk_enum_item(&mut self, item: &EnumItem) -> Result<(), Self::Error> { + Ok(Default::default()) + } + + fn walk_let_statement( + &mut self, + stmt: &LetStatement, + ) -> Result { + Ok(Default::default()) + } + + fn walk_for_statement( + &mut self, + stmt: &ForStatement, + ) -> Result { + Ok(Default::default()) + } + + fn walk_while_statement( + &mut self, + stmt: &WhileStatement, + ) -> Result { + Ok(Default::default()) + } + + fn walk_loop_statement( + &mut self, + stmt: &LoopStatement, + ) -> Result { + Ok(Default::default()) + } + + fn walk_if_statement( + &mut self, + stmt: &IfStatement, + ) -> Result { + Ok(Default::default()) + } + + fn walk_return_statement( + &mut self, + stmt: &ReturnStatement, + ) -> Result { + Ok(Default::default()) + } + + // 表达式遍历方法 + fn walk_literal_expression( + &mut self, + _expr: &LiteralExpression, + ) -> Result { + Ok(Default::default()) + } + + fn walk_identifier_expression( + &mut self, + _expr: &IdentifierExpression, + ) -> Result { + Ok(Default::default()) + } + + fn walk_environment_expression( + &mut self, + _expr: &EnvironmentExpression, + ) -> Result { + Ok(Default::default()) + } + + fn walk_path_expression( + &mut self, + _expr: &PathExpression, + ) -> Result { + Ok(Default::default()) + } + + fn walk_tuple_expression( + &mut self, + _expr: &TupleExpression, + ) -> Result { + Ok(Default::default()) + } + + fn walk_array_expression( + &mut self, + _expr: &ArrayExpression, + ) -> Result { + Ok(Default::default()) + } + + fn walk_map_expression( + &mut self, + _expr: &MapExpression, + ) -> Result { + Ok(Default::default()) + } + + fn walk_closure_expression( + &mut self, + _expr: &ClosureExpression, + ) -> Result { + Ok(Default::default()) + } + + fn walk_range_expression( + &mut self, + _expr: &RangeExpression, + ) -> Result { + Ok(Default::default()) + } + + fn walk_slice_expression( + &mut self, + _expr: &SliceExpression, + ) -> Result { + Ok(Default::default()) + } + + fn walk_assign_expression( + &mut self, + _expr: &AssignExpression, + ) -> Result { + Ok(Default::default()) + } + + fn walk_call_expression( + &mut self, + _expr: &CallExpression, + ) -> Result { + Ok(Default::default()) + } + + fn walk_prefix_expression( + &mut self, + _expr: &PrefixExpression, + ) -> Result { + Ok(Default::default()) + } + + fn walk_binary_expression( + &mut self, + _expr: &BinaryExpression, + ) -> Result { + Ok(Default::default()) + } + + fn walk_index_get_expression( + &mut self, + _expr: &IndexGetExpression, + ) -> Result { + Ok(Default::default()) + } + + fn walk_index_set_expression( + &mut self, + _expr: &IndexSetExpression, + ) -> Result { + Ok(Default::default()) + } + + fn walk_property_get_expression( + &mut self, + _expr: &PropertyGetExpression, + ) -> Result { + Ok(Default::default()) + } + + fn walk_property_set_expression( + &mut self, + _expr: &PropertySetExpression, + ) -> Result { + Ok(Default::default()) + } + + fn walk_call_method_expression( + &mut self, + _expr: &CallMethodExpression, + ) -> Result { + Ok(Default::default()) + } + + fn walk_struct_expression( + &mut self, + _expr: &StructExpression, + ) -> Result { + Ok(Default::default()) + } +} diff --git a/src/compiler/lowering.rs b/src/compiler/lowering.rs index a4f514b..cee1ef4 100644 --- a/src/compiler/lowering.rs +++ b/src/compiler/lowering.rs @@ -4,7 +4,8 @@ use std::{cell::RefCell, collections::BTreeMap, rc::Rc}; use super::CompileError; use super::ast::syntax::*; use super::ir::{builder::*, instruction::*}; -use super::typing::TypeContext; +use super::typing::{TypeContext, TypeDef}; +use crate::compiler::typing::{FunctionDef, StructDef}; use crate::{ Environment, bytecode::{FunctionId, Opcode, Primitive}, @@ -58,10 +59,8 @@ impl<'a> ASTLower<'a> { builder.switch_to_block(entry); // declare functions - for decl in self.type_cx.function_decls() { - if let Declaration::Function(func) = decl { - self.declare_function(func); - } + for func in self.type_cx.functions() { + self.declare_function(func); } let entry = self.create_block("main"); @@ -71,7 +70,7 @@ impl<'a> ASTLower<'a> { // split program into statements and items for stmt in prog.stmts { match stmt.node { - Statement::Item(ItemStatement::Fn(func)) => { + Statement::Item(ItemStatement::Function(func)) => { self.lower_function_item(func); } Statement::Item(_) => { @@ -145,7 +144,7 @@ impl<'a> ASTLower<'a> { fn lower_item_stmt(&mut self, item: ItemStatement) { match item { - ItemStatement::Fn(fn_item) => { + ItemStatement::Function(fn_item) => { self.lower_function_item(fn_item); } _ => unimplemented!("statement {:?}", item), @@ -458,11 +457,16 @@ impl<'a> ASTLower<'a> { let mut field_map = fields .into_iter() - .map(|StructExprField { name, value }| (name, self.lower_expression(value))) + .map(|StructExprField { name, value }| { + (name.node().clone(), self.lower_expression(value)) + }) .collect::>(); - let decl = self.type_cx.get_type_decl(&name).expect("struct not found"); - if let Declaration::Struct(StructDeclaration { + let decl = self + .type_cx + .get_type_def(&name.node()) + .expect("struct not found"); + if let TypeDef::Struct(StructDef { fields: decl_fields, .. }) = decl @@ -677,8 +681,8 @@ impl<'a> ASTLower<'a> { } } - fn declare_function(&mut self, func: &FunctionDeclaration) -> FunctionId { - let FunctionDeclaration { name, params, .. } = func; + fn declare_function(&mut self, func: &FunctionDef) -> FunctionId { + let FunctionDef { name, params, .. } = func; let func_sig = FuncSignature::new( name.clone(), diff --git a/src/compiler/parser.rs b/src/compiler/parser.rs index 0b7cac3..d0b3321 100644 --- a/src/compiler/parser.rs +++ b/src/compiler/parser.rs @@ -38,6 +38,12 @@ impl std::error::Error for ParseError {} type Result = std::result::Result; +impl Span { + fn from_pair(pair: &Pair) -> Self { + Span::new(pair.as_span().start(), pair.as_span().end()) + } +} + #[derive(pest_derive::Parser)] #[grammar = "compiler/ast/grammar.pest"] struct PestParser; @@ -164,7 +170,7 @@ fn parse_statement(pair: Pair) -> Result { _ => unreachable!("unknown statement: {pair:?}"), }; - Ok(AstNode::new(stmt, span, Type::Unknown)) + Ok(AstNode::new(stmt, span)) } fn parse_item_statement(pair: Pair) -> Result { @@ -173,7 +179,7 @@ fn parse_item_statement(pair: Pair) -> Result { match stat.as_rule() { Rule::enum_item => Ok(ItemStatement::Enum(parse_enum_item(stat))), Rule::struct_item => Ok(ItemStatement::Struct(parse_struct_item(stat))), - Rule::fn_item => parse_function_item(stat).map(ItemStatement::Fn), + Rule::fn_item => parse_function_item(stat).map(ItemStatement::Function), _ => unreachable!("unknown item statement: {stat:?}"), } } @@ -191,21 +197,15 @@ fn parse_enum_item(pair: Pair) -> EnumItem { } fn parse_enum_field(pair: Pair) -> EnumVariant { - let field = pair.into_inner().next().unwrap(); - match field.as_rule() { - Rule::simple_enum_field => EnumVariant::Simple(field.as_str().to_string()), - Rule::tuple_enum_field => parse_tuple_enum_field(field), - _ => unreachable!("{field:?}"), - } -} - -fn parse_tuple_enum_field(pair: Pair) -> EnumVariant { let mut pairs = pair.into_inner(); - let name = pairs.next().unwrap().as_str().to_string(); - let tuple = pairs.map(|item| parse_type_expression(item)).collect(); + let name = pairs.next().unwrap(); + let ty = pairs.next().map(|item| parse_type_expression(item)); - EnumVariant::Tuple(name, tuple) + EnumVariant { + name: name.as_str().to_string(), + variant: ty, + } } fn parse_struct_item(pair: Pair) -> StructItem { @@ -327,9 +327,8 @@ fn parse_if_statement(pair: Pair) -> Result { let span = Span::from_pair(&pair); match pair.as_rule() { Rule::block => parse_block(pair), - Rule::if_statement => parse_if_statement(pair).map(|item| { - BlockStatement(vec![AstNode::new(Statement::If(item), span, Type::Unknown)]) - }), + Rule::if_statement => parse_if_statement(pair) + .map(|item| BlockStatement(vec![AstNode::new(Statement::If(item), span)])), _ => unreachable!("unknown else_branch: {:?}", pair), } }) @@ -451,7 +450,7 @@ fn parse_prefix(op: Pair, rhs: Result) -> Result, op: Pair) -> Result, op: Pair) -> Result, op: Pair) -> Result, op: Pair) -> Result unreachable!("unknown postfix: {:?}", op), }; - Ok(AstNode::new(expr, span, Type::Unknown)) + Ok(AstNode::new(expr, span)) } fn parse_struct_expression(pair: Pair) -> Result { @@ -693,6 +692,7 @@ fn parse_struct_expression(pair: Pair) -> Result { let mut pairs = pair.into_inner(); let name = pairs.next().unwrap(); + let name_span = Span::from_pair(&name); assert_eq!(name.as_rule(), Rule::identifier); let name = name.as_str().to_string(); @@ -703,15 +703,11 @@ fn parse_struct_expression(pair: Pair) -> Result { .collect::>>()?; let expr = StructExpression { - name: name.clone(), + name: AstNode::new(name, name_span), fields, }; - Ok(AstNode::new( - Expression::StructExpr(expr), - span, - Type::Unknown, - )) + Ok(AstNode::new(Expression::StructExpr(expr), span)) } fn parse_struct_expression_field(pair: Pair) -> Result { @@ -721,8 +717,7 @@ fn parse_struct_expression_field(pair: Pair) -> Result { let name = pairs.next().unwrap(); assert_eq!(name.as_rule(), Rule::identifier); - let name = name.as_str().to_string(); - + let name = AstNode::new(name.as_str().to_string(), Span::from_pair(&name)); let value = parse_expression(pairs.next().unwrap())?; Ok(StructExprField { name, value }) @@ -744,7 +739,7 @@ fn parse_atom(pair: Pair) -> Result { _ => unreachable!("unknown atom: {:?}", pair), }?; - Ok(AstNode::new(expr, span, Type::Unknown)) + Ok(AstNode::new(expr, span)) } fn parse_closure(pair: Pair) -> Result { @@ -1525,7 +1520,7 @@ mod test { #[test] fn test_enum_item() { - let input = r#"enum A { AA, BB(int, float) }"#; + let input = r#"enum A { AA, BB(int) }"#; let item_statement = parse_statement_input(input).unwrap(); if let Statement::Item(ItemStatement::Enum(enum_item)) = item_statement.node { // 检查枚举名称 @@ -1535,17 +1530,22 @@ mod test { assert_eq!(enum_item.variants.len(), 2); // 检查第一个变体(简单变体) - assert_eq!(enum_item.variants[0], EnumVariant::Simple("AA".to_string())); - - // 检查第二个变体(元组变体) - if let EnumVariant::Tuple(name, types) = &enum_item.variants[1] { - assert_eq!(name, "BB"); - assert_eq!(types.len(), 2); - assert_eq!(types[0], TypeExpression::Integer); - assert_eq!(types[1], TypeExpression::Float); - } else { - panic!("Expected tuple variant"); - } + assert_eq!( + enum_item.variants[0], + EnumVariant { + name: "AA".to_string(), + variant: None + } + ); + + // 检查第二个变体(值变体) + assert_eq!( + enum_item.variants[0], + EnumVariant { + name: "BB".to_string(), + variant: Some(TypeExpression::Integer) + } + ); } else { panic!("Expected enum item statement"); } @@ -1578,7 +1578,7 @@ mod test { fn test_function_item() { let input = r#"fn A(a: int, b: float) -> int { return a + b; }"#; let item_statement = parse_statement_input(input).unwrap(); - if let Statement::Item(ItemStatement::Fn(function)) = item_statement.node { + if let Statement::Item(ItemStatement::Function(function)) = item_statement.node { // 检查函数名和参数 assert_eq!(function.name, "A"); assert_eq!(function.params.len(), 2); diff --git a/src/compiler/semantic.rs b/src/compiler/semantic.rs index 8d69be5..4594b38 100644 --- a/src/compiler/semantic.rs +++ b/src/compiler/semantic.rs @@ -1,22 +1,21 @@ use super::CompileError; use super::ast::syntax::*; +use super::symbol::SymbolTable; use super::typing::TypeContext; use crate::Environment; pub struct SemanticAnalyzer<'a> { - type_cx: &'a mut TypeContext, - // 当前函数的返回类型,用于检查return语句 - current_function_return_type: Option, - // 循环嵌套计数,用于检查break和continue语句 + type_cx: &'a TypeContext, loop_depth: usize, + symbol_table: SymbolTable<()>, } impl<'a> SemanticAnalyzer<'a> { - pub fn new(type_cx: &'a mut TypeContext) -> Self { + pub fn new(type_cx: &'a TypeContext) -> Self { SemanticAnalyzer { type_cx, - current_function_return_type: None, loop_depth: 0, + symbol_table: SymbolTable::new(), } } @@ -28,13 +27,10 @@ impl<'a> SemanticAnalyzer<'a> { ) -> Result<(), CompileError> { // 第一阶段:收集环境变量 for name in env.symbols.keys() { - self.type_cx.set_type(name.to_string(), Type::Any); + self.symbol_table.insert(name.clone(), ()); } - // 第二阶段:收集声明 - self.type_cx.analyze_type_decl(&program.stmts); - - // 第三阶段:分析所有语句 + // 第二阶段:分析所有语句 for stmt in &mut program.stmts { self.analyze_statement(stmt)?; } @@ -42,10 +38,10 @@ impl<'a> SemanticAnalyzer<'a> { } /// 分析语句并推断类型 - fn analyze_statement(&mut self, stmt: &mut StatementNode) -> Result<(), CompileError> { + fn analyze_statement(&mut self, stmt: &StatementNode) -> Result<(), CompileError> { let span = stmt.span; - match &mut stmt.node { + match &stmt.node { Statement::Expression(expr) => self.analyze_expression(expr).map(|_| ()), Statement::Let(let_stmt) => self.analyze_let_statement(let_stmt), Statement::If(if_stmt) => self.analyze_if_statement(if_stmt), @@ -57,7 +53,7 @@ impl<'a> SemanticAnalyzer<'a> { Statement::Break => self.analyze_break_statement(span), Statement::Continue => self.analyze_continue_statement(span), Statement::Empty => Ok(()), - Statement::Item(ItemStatement::Fn(func)) => self.analyze_function_item(func), + Statement::Item(ItemStatement::Function(func)) => self.analyze_function_item(func), Statement::Item(ItemStatement::Struct(item)) => self.analyze_struct_item(item), Statement::Item(ItemStatement::Enum(EnumItem { .. })) => { unimplemented!("EnumItem not implemented") @@ -65,84 +61,36 @@ impl<'a> SemanticAnalyzer<'a> { } } - fn analyze_let_statement(&mut self, let_stmt: &mut LetStatement) -> Result<(), CompileError> { + fn analyze_let_statement(&mut self, let_stmt: &LetStatement) -> Result<(), CompileError> { let LetStatement { name, ty, value } = let_stmt; - let decl_ty = match ty { - Some(type_expr) => { - let decl_type = self.type_from_type_expr(type_expr)?; - Some(decl_type) - } - None => None, - }; - - let value_ty = match value { - Some(value) => { - let value_ty = self.analyze_expression(value)?; - Some(value_ty) - } - None => None, - }; - - match (decl_ty, value_ty) { - (Some(decl_type), Some(value_ty)) => { - if decl_type != value_ty { - return Err(CompileError::type_mismatch( - decl_type, - value_ty, - value.clone().unwrap().span(), - )); - } - self.type_cx.set_type(name.clone(), decl_type); - } - (Some(decl_type), None) => { - self.type_cx.set_type(name.clone(), decl_type); - } - (None, Some(value_ty)) => { - self.type_cx.set_type(name.clone(), value_ty); - } - (None, None) => { - self.type_cx.set_type(name.clone(), Type::Any); - } - } + self.symbol_table.insert(name, ()); Ok(()) } /// 分析If语句 - fn analyze_if_statement(&mut self, if_stmt: &mut IfStatement) -> Result<(), CompileError> { + fn analyze_if_statement(&mut self, if_stmt: &IfStatement) -> Result<(), CompileError> { // 分析条件表达式 - let condition_ty = self.analyze_expression(&mut if_stmt.condition)?; - - // TODO: 条件必须是布尔类型 - if condition_ty != Type::Boolean && condition_ty != Type::Any { - return Err(CompileError::type_mismatch( - Type::Boolean, - condition_ty.clone(), - if_stmt.condition.span, - )); - } + self.analyze_expression(&if_stmt.condition)?; // 分析then分支 - self.analyze_block(&mut if_stmt.then_branch)?; + self.analyze_block(&if_stmt.then_branch)?; // 分析else分支(如果有) - if let Some(else_branch) = &mut if_stmt.else_branch { + if let Some(else_branch) = &if_stmt.else_branch { self.analyze_block(else_branch)?; } Ok(()) } - fn analyze_loop_statement( - &mut self, - loop_stmt: &mut LoopStatement, - ) -> Result<(), CompileError> { + fn analyze_loop_statement(&mut self, loop_stmt: &LoopStatement) -> Result<(), CompileError> { // 增加循环深度 self.loop_depth += 1; // 分析循环体 - self.analyze_block(&mut loop_stmt.body)?; + self.analyze_block(&loop_stmt.body)?; // 减少循环深度 self.loop_depth -= 1; @@ -151,27 +99,15 @@ impl<'a> SemanticAnalyzer<'a> { } /// 分析While语句 - fn analyze_while_statement( - &mut self, - while_stmt: &mut WhileStatement, - ) -> Result<(), CompileError> { + fn analyze_while_statement(&mut self, while_stmt: &WhileStatement) -> Result<(), CompileError> { // 分析条件表达式 - let condition_ty = self.analyze_expression(&mut while_stmt.condition)?; - - // 条件必须是布尔类型 - if !condition_ty.is_boolean() { - return Err(CompileError::type_mismatch( - Type::Boolean, - condition_ty.clone(), - while_stmt.condition.span, - )); - } + self.analyze_expression(&while_stmt.condition)?; // 增加循环深度 self.loop_depth += 1; // 分析循环体 - self.analyze_block(&mut while_stmt.body)?; + self.analyze_block(&while_stmt.body)?; // 减少循环深度 self.loop_depth -= 1; @@ -180,51 +116,37 @@ impl<'a> SemanticAnalyzer<'a> { } /// 分析For语句 - fn analyze_for_statement(&mut self, for_stmt: &mut ForStatement) -> Result<(), CompileError> { + fn analyze_for_statement(&mut self, for_stmt: &ForStatement) -> Result<(), CompileError> { // 创建新的作用域 - let old_env = std::mem::replace(self.type_cx, self.type_cx.clone()); + self.symbol_table.enter_scope(); // 分析迭代器表达式 // self.analyze_expression(&mut for_stmt.iterable)?; - self.analyze_pattern(&mut for_stmt.pat, &mut for_stmt.iterable)?; + self.analyze_pattern(&for_stmt.pat, &for_stmt.iterable)?; // 增加循环深度 self.loop_depth += 1; // 分析循环体 - self.analyze_block(&mut for_stmt.body)?; + self.analyze_block(&for_stmt.body)?; // 减少循环深度 self.loop_depth -= 1; // 恢复作用域 - let _ = std::mem::replace(self.type_cx, old_env); + self.symbol_table.leave_scope(); Ok(()) } fn analyze_pattern( &mut self, - pattern: &mut Pattern, - expr: &mut ExpressionNode, + pattern: &Pattern, + expr: &ExpressionNode, ) -> Result<(), CompileError> { self.analyze_expression(expr)?; - // 检查模式是否匹配表达式 - match pattern { - Pattern::Wildcard => {} - Pattern::Identifier(id) => { - self.type_cx.set_type(id.clone(), Type::Any); - } - Pattern::Tuple(tuple) => { - for pat in tuple.iter_mut() { - self.analyze_pattern(pat, expr)?; - } - } - Pattern::Literal(_literal) => {} - } - Ok(()) } @@ -247,74 +169,57 @@ impl<'a> SemanticAnalyzer<'a> { /// 分析Return语句 fn analyze_return_statement( &mut self, - return_stmt: &mut ReturnStatement, + return_stmt: &ReturnStatement, _span: Span, ) -> Result<(), CompileError> { - if let Some(expr) = &mut return_stmt.value { + if let Some(expr) = &return_stmt.value { self.analyze_expression(expr)?; } Ok(()) } /// 分析代码块 - fn analyze_block(&mut self, block: &mut BlockStatement) -> Result<(), CompileError> { + fn analyze_block(&mut self, block: &BlockStatement) -> Result<(), CompileError> { // 创建新的作用域 - let old_env = std::mem::replace(self.type_cx, self.type_cx.clone()); + self.symbol_table.enter_scope(); // 分析块中的所有语句 - for stmt in &mut block.0 { + for stmt in &block.0 { self.analyze_statement(stmt)?; } // 恢复作用域 - let _ = std::mem::replace(self.type_cx, old_env); + self.symbol_table.leave_scope(); Ok(()) } /// 分析函数定义 - fn analyze_function_item(&mut self, func: &mut FunctionItem) -> Result<(), CompileError> { + fn analyze_function_item(&mut self, func: &FunctionItem) -> Result<(), CompileError> { // 创建新的作用域 - let old_env = std::mem::replace(self.type_cx, self.type_cx.clone()); - - // 获取函数返回类型 - let return_ty = if let Some(ty_expr) = &func.return_ty { - self.type_from_type_expr(ty_expr)? - } else { - Type::Any - }; + self.symbol_table.enter_scope(); // 添加参数到环境 for param in &func.params { - let param_ty = match ¶m.ty { - Some(ty) => self.type_from_type_expr(ty)?, - None => Type::Any, - }; - - self.type_cx.set_type(param.name.clone(), param_ty); + self.symbol_table.insert(¶m.name, ()); } - // 保存当前函数的返回类型,用于检查return语句 - let old_return_type = self.current_function_return_type.replace(return_ty.clone()); - // 分析函数体 - self.analyze_block(&mut func.body)?; + self.analyze_block(&func.body)?; // 恢复环境 - let _ = std::mem::replace(self.type_cx, old_env); - self.current_function_return_type = old_return_type; + self.symbol_table.leave_scope(); Ok(()) } - fn analyze_struct_item(&mut self, _item: &mut StructItem) -> Result<(), CompileError> { + fn analyze_struct_item(&mut self, _item: &StructItem) -> Result<(), CompileError> { Ok(()) } /// 分析表达式并推断类型 - fn analyze_expression(&mut self, expr: &mut ExpressionNode) -> Result { - let ty = match &mut expr.node { - Expression::Literal(lit) => self.analyze_literal_expression(lit)?, + fn analyze_expression(&mut self, expr: &ExpressionNode) -> Result<(), CompileError> { + match &expr.node { Expression::Identifier(ident) => self.anlyze_identifier_expression(ident)?, Expression::Binary(expr) => self.analyze_binary_expression(expr)?, Expression::Prefix(expr) => self.analyze_prefix_expression(expr)?, @@ -333,403 +238,153 @@ impl<'a> SemanticAnalyzer<'a> { Expression::CallMethod(expr) => self.analyze_call_method_expression(expr)?, _ => { // 处理其他未实现的表达式类型 - Type::Any } }; - expr.ty = ty.clone(); - - Ok(ty) - } - - fn analyze_literal_expression( - &mut self, - lit: &mut LiteralExpression, - ) -> Result { - Ok(match lit { - LiteralExpression::Null => Type::Any, - LiteralExpression::Boolean(_) => Type::Boolean, - LiteralExpression::Integer(_) => Type::Integer, - LiteralExpression::Float(_) => Type::Float, - LiteralExpression::Char(_) => Type::Char, - LiteralExpression::String(_) => Type::String, - }) + Ok(()) } fn anlyze_identifier_expression( &mut self, - ident: &mut IdentifierExpression, - ) -> Result { - if let Some(ty) = self.type_cx.get_type(&ident.0) { - return Ok(ty.clone()); + ident: &IdentifierExpression, + ) -> Result<(), CompileError> { + if !self.type_cx.type_is_defined(&ident.0) { + return Err(CompileError::UndefinedVariable { + name: ident.0.clone(), + }); } - Err(CompileError::UndefinedVariable { - name: ident.0.clone(), - }) + Ok(()) } - fn analyze_binary_expression( - &mut self, - expr: &mut BinaryExpression, - ) -> Result { - let lhs_ty = self.analyze_expression(expr.lhs.as_mut())?; - let rhs_ty = self.analyze_expression(expr.rhs.as_mut())?; - - if lhs_ty == Type::Any || rhs_ty == Type::Any { - return Ok(Type::Any); // Object类型可以和任何类型比较 - } + fn analyze_binary_expression(&mut self, expr: &BinaryExpression) -> Result<(), CompileError> { + let lhs_ty = self.analyze_expression(&expr.lhs)?; + let rhs_ty = self.analyze_expression(&expr.rhs)?; - match expr.op { - BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div | BinOp::Rem => { - if (!lhs_ty.is_numeric() || rhs_ty != lhs_ty) - && lhs_ty != Type::String - && lhs_ty != Type::Char - { - return Err(CompileError::type_mismatch( - Type::Integer, - lhs_ty, - expr.lhs.span(), - )); - } - } - BinOp::LogicAnd | BinOp::LogicOr => { - if lhs_ty != Type::Boolean || rhs_ty != Type::Boolean { - return Err(CompileError::type_mismatch( - Type::Boolean, - lhs_ty, - expr.lhs.span(), - )); - } - } - _ => {} - } - - Ok(match expr.op { - BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div | BinOp::Rem => lhs_ty, - BinOp::LogicAnd | BinOp::LogicOr => Type::Boolean, - BinOp::Less - | BinOp::LessEqual - | BinOp::Greater - | BinOp::GreaterEqual - | BinOp::Equal - | BinOp::NotEqual => Type::Boolean, - _ => Type::Any, - }) + Ok(()) } - fn analyze_prefix_expression( - &mut self, - expr: &mut PrefixExpression, - ) -> Result { - let rhs_ty = self.analyze_expression(expr.rhs.as_mut())?; - - match expr.op { - PrefixOp::Neg => { - if !(rhs_ty.is_numeric() || rhs_ty == Type::Any) { - return Err(CompileError::type_mismatch( - Type::Integer, - rhs_ty, - expr.rhs.span(), - )); - } - } - PrefixOp::Not => { - if !(rhs_ty.is_boolean() || rhs_ty == Type::Any) { - return Err(CompileError::type_mismatch( - Type::Boolean, - rhs_ty, - expr.rhs.span(), - )); - } - } - } + fn analyze_prefix_expression(&mut self, expr: &PrefixExpression) -> Result<(), CompileError> { + let rhs_ty = self.analyze_expression(&expr.rhs)?; - Ok(match expr.op { - PrefixOp::Neg => rhs_ty, - PrefixOp::Not => Type::Boolean, - }) + Ok(()) } - fn analyze_call_expression(&mut self, expr: &mut CallExpression) -> Result { - let func_ty = self.analyze_expression(expr.func.as_mut())?; + fn analyze_call_expression(&mut self, expr: &CallExpression) -> Result<(), CompileError> { + self.analyze_expression(&expr.func)?; - if func_ty == Type::Any { - return Ok(Type::Any); + for arg in expr.args.iter() { + self.analyze_expression(&arg)?; } - if let Type::Decl(Declaration::Function(FunctionDeclaration { - name, - params, - return_type, - })) = &func_ty - { - if params.len() != expr.args.len() { - return Err(CompileError::ArgumentCountMismatch { - expected: params.len(), - actual: expr.args.len(), - }); - } - - for (arg, param_ty) in expr.args.iter_mut().zip(params.iter()) { - let arg_ty = self.analyze_expression(arg)?; - if let Some(ty) = ¶m_ty.1 { - if arg_ty != *ty && arg_ty != Type::Any { - return Err(CompileError::type_mismatch(ty.clone(), arg_ty, arg.span())); - } - } - } - - Ok(return_type - .as_ref() - .map(|t| *t.clone()) - .unwrap_or(Type::Any)) - } else { - Err(CompileError::NotCallable { - ty: func_ty, - span: expr.func.span(), - }) - } + Ok(()) } - fn analyze_array_expression( - &mut self, - expr: &mut ArrayExpression, - ) -> Result { - if expr.0.is_empty() { - return Ok(Type::Array(Box::new(Type::Any))); - } - - let mut elem_types = Vec::with_capacity(expr.0.len()); - + fn analyze_array_expression(&mut self, expr: &ArrayExpression) -> Result<(), CompileError> { // 为每个元素创建临时变量并分析类型 - for elem in expr.0.iter_mut() { - let elem_ty = self.analyze_expression(elem)?; - elem_types.push(elem_ty); - } - - // 检查数组元素是否类型一致 - let first_ty = &elem_types[0]; - for (i, elem) in expr.0.iter().enumerate() { - if &elem_types[i] != first_ty && elem_types[i] != Type::Any { - return Err(CompileError::type_mismatch( - first_ty.clone(), - elem_types[i].clone(), - elem.span(), - )); - } + for elem in expr.0.iter() { + self.analyze_expression(elem)?; } - Ok(Type::Array(Box::new(first_ty.clone()))) + Ok(()) } - fn analyze_map_expression(&mut self, _expr: &mut MapExpression) -> Result { - Ok(Type::Map(Box::new(Type::Any))) + fn analyze_map_expression(&mut self, expr: &MapExpression) -> Result<(), CompileError> { + for (key, value) in expr.0.iter() { + self.analyze_expression(key)?; + self.analyze_expression(value)?; + } + + Ok(()) } fn analyze_index_get_expression( &mut self, - expr: &mut IndexGetExpression, - ) -> Result { - let object_ty = self.analyze_expression(expr.object.as_mut())?; - let _index_ty = self.analyze_expression(expr.index.as_mut())?; - - // 根据集合类型确定返回类型 - match object_ty { - Type::Array(ty) => Ok(*ty.clone()), - Type::Map(value_ty) => Ok(*value_ty.clone()), - Type::Any => Ok(Type::Any), - _ => { - // 其他不支持的集合类型 - Err(CompileError::InvalidOperation { - message: format!("Cannot index non-array and non-map type({object_ty:?})"), - }) - } - } + expr: &IndexGetExpression, + ) -> Result<(), CompileError> { + self.analyze_expression(&expr.object)?; + self.analyze_expression(&expr.index)?; + + Ok(()) } fn analyze_index_set_expression( &mut self, - expr: &mut IndexSetExpression, - ) -> Result { - let object_ty = self.analyze_expression(expr.object.as_mut())?; - let value_ty = self.analyze_expression(expr.value.as_mut())?; - // index 不需要分析 - - match object_ty { - Type::Array(ty) => { - if *ty != value_ty && value_ty != Type::Any { - return Err(CompileError::type_mismatch( - *ty, - value_ty, - expr.value.span(), - )); - } - } - Type::Map(_) | Type::Any => { - // 对于Map,Any对象,允许设置任何属性 - } - _ => { - // 其他不支持的集合类型 - return Err(CompileError::InvalidOperation { - message: "Cannot index non-array and non-map type".to_string(), - }); - } - } + expr: &IndexSetExpression, + ) -> Result<(), CompileError> { + self.analyze_expression(&expr.object)?; + self.analyze_expression(&expr.value)?; - Ok(Type::Any) + Ok(()) } fn analyze_property_get_expression( &mut self, - expr: &mut PropertyGetExpression, - ) -> Result { - let object_ty = self.analyze_expression(expr.object.as_mut())?; - - match object_ty { - Type::Map(value_ty) if *value_ty == Type::String => Ok(*value_ty), - Type::Any => { - // 对于通用对象,允许访问任何属性 - Ok(Type::Any) - } - _ => { - // 其他不支持的集合类型 - Err(CompileError::InvalidOperation { - message: "Cannot access property on non-object type".to_string(), - }) - } - } + expr: &PropertyGetExpression, + ) -> Result<(), CompileError> { + self.analyze_expression(&expr.object)?; + Ok(()) } fn analyze_property_set_expression( &mut self, - expr: &mut PropertySetExpression, - ) -> Result { - let object_ty = self.analyze_expression(expr.object.as_mut())?; - let _property_ty = self.analyze_expression(expr.value.as_mut())?; - - match object_ty { - Type::Map(value_ty) if *value_ty == Type::String => Ok(Type::Any), - Type::Any => Ok(Type::Any), // 对于Any对象,允许设置任何属性 - _ => { - // 其他不支持的集合类型 - Err(CompileError::InvalidOperation { - message: "Cannot access property on non-object type".to_string(), - }) - } - } + expr: &PropertySetExpression, + ) -> Result<(), CompileError> { + self.analyze_expression(&expr.object)?; + self.analyze_expression(&expr.value)?; + + Ok(()) } - fn analyze_assign_expression( - &mut self, - expr: &mut AssignExpression, - ) -> Result { - let object_ty = self.analyze_expression(expr.object.as_mut())?; - let value_ty = self.analyze_expression(expr.value.as_mut())?; - - // 检查赋值左右类型是否兼容 - if object_ty != Type::Any && value_ty != Type::Any && object_ty != value_ty { - return Err(CompileError::type_mismatch( - object_ty, - value_ty, - expr.value.span(), - )); - } + fn analyze_assign_expression(&mut self, expr: &AssignExpression) -> Result<(), CompileError> { + self.analyze_expression(&expr.object)?; + self.analyze_expression(&expr.value)?; - Ok(object_ty) + Ok(()) } - fn analyze_range_expression( - &mut self, - expr: &mut RangeExpression, - ) -> Result { - if let Some(ref mut begin_expr) = expr.begin { - let begin_ty = self.analyze_expression(begin_expr)?; - if begin_ty != Type::Integer && begin_ty != Type::Any { - return Err(CompileError::type_mismatch( - Type::Integer, - begin_ty, - begin_expr.span(), - )); - } + fn analyze_range_expression(&mut self, expr: &RangeExpression) -> Result<(), CompileError> { + if let Some(ref begin_expr) = expr.begin { + self.analyze_expression(begin_expr)?; } - if let Some(ref mut end_expr) = expr.end { - let end_ty = self.analyze_expression(end_expr)?; - if end_ty != Type::Integer && end_ty != Type::Any { - return Err(CompileError::type_mismatch( - Type::Integer, - end_ty, - end_expr.span(), - )); - } + if let Some(ref end_expr) = expr.end { + self.analyze_expression(end_expr)? } - Ok(Type::Range) + Ok(()) } - fn analyze_slice_expression( - &mut self, - expr: &mut SliceExpression, - ) -> Result { - let object_ty = self.analyze_expression(expr.object.as_mut())?; - self.analyze_range_expression(&mut expr.range.node)?; + fn analyze_slice_expression(&mut self, expr: &SliceExpression) -> Result<(), CompileError> { + self.analyze_expression(&expr.object)?; + self.analyze_range_expression(&expr.range.node)?; - // 切片表达式的类型与原始对象类型相同 - Ok(object_ty) + Ok(()) } - fn analyze_try_expression(&mut self, expr: &mut ExpressionNode) -> Result { - let expr_ty = self.analyze_expression(expr)?; - if expr_ty != Type::Any { - return Err(CompileError::InvalidOperation { - message: "Cannot try non-result type".to_string(), - }); - } + fn analyze_try_expression(&mut self, expr: &ExpressionNode) -> Result<(), CompileError> { + self.analyze_expression(expr)?; - // Try表达式的类型与内部表达式类型相同 - Ok(expr_ty) + Ok(()) } - fn analyze_await_expression( - &mut self, - expr: &mut ExpressionNode, - ) -> Result { - let expr_ty = self.analyze_expression(expr)?; - if expr_ty != Type::Any { - return Err(CompileError::InvalidOperation { - message: "Cannot await non-promise type".to_string(), - }); - } + fn analyze_await_expression(&mut self, expr: &ExpressionNode) -> Result<(), CompileError> { + self.analyze_expression(expr)?; - // Await表达式的类型与内部表达式类型相同 - Ok(expr_ty) + Ok(()) } fn analyze_call_method_expression( &mut self, - expr: &mut CallMethodExpression, - ) -> Result { - let object_ty = self.analyze_expression(expr.object.as_mut())?; + expr: &CallMethodExpression, + ) -> Result<(), CompileError> { + self.analyze_expression(&expr.object)?; // 分析方法参数 - for arg in &mut expr.args { + for arg in &expr.args { self.analyze_expression(arg)?; } - // 方法调用的结果类型取决于方法实现,这里简单处理为Any类型 - Ok(Type::Any) - } - - /// 从类型表达式转换为Type - fn type_from_type_expr(&mut self, type_expr: &TypeExpression) -> Result { - let ty = self.type_cx.resolve_type_decl(type_expr); - if ty == Type::Unknown { - Err(CompileError::UnknownType { - name: format!("{type_expr:?}"), - }) - } else { - Ok(ty) - } + Ok(()) } } diff --git a/src/compiler/symbol.rs b/src/compiler/symbol.rs new file mode 100644 index 0000000..9a27f7c --- /dev/null +++ b/src/compiler/symbol.rs @@ -0,0 +1,52 @@ +use std::collections::HashMap; + +#[derive(Debug)] +pub struct SymbolTable { + scopes: Vec>, +} + +impl SymbolTable { + pub fn new() -> Self { + SymbolTable:: { + scopes: vec![Scope::::new()], + } + } + + pub fn lookup(&self, name: impl AsRef) -> Option<&T> { + for scope in self.scopes.iter().rev() { + if let Some(ty) = scope.variables.get(name.as_ref()) { + return Some(ty); + } + } + None + } + + pub fn insert(&mut self, name: impl Into, value: T) { + self.scopes + .last_mut() + .unwrap() + .variables + .insert(name.into(), value); + } + + pub fn enter_scope(&mut self) { + self.scopes.push(Scope::::new()); + } + + pub fn leave_scope(&mut self) { + self.scopes.pop(); + } +} + +#[derive(Debug)] +struct Scope { + variables: HashMap, +} + +impl Scope { + fn new() -> Self { + Scope { + variables: HashMap::new(), + } + } +} diff --git a/src/compiler/typing.rs b/src/compiler/typing.rs index a7ccdd3..5781281 100644 --- a/src/compiler/typing.rs +++ b/src/compiler/typing.rs @@ -1,107 +1,322 @@ -use std::collections::HashMap; +use std::{collections::HashMap, default}; -use crate::Environment; +use crate::{Environment, compiler::symbol::SymbolTable}; -use super::ast::syntax::*; +use super::ast::{syntax::*, walker::Walker}; #[derive(Debug, Clone)] +pub struct TypeError { + pub span: Span, + pub kind: ErrKind, +} + +impl TypeError { + pub fn new(span: Span, kind: ErrKind) -> TypeError { + TypeError { span, kind } + } + + pub fn with_span(mut self, span: Span) -> Self { + self.span = span; + self + } +} + +impl From for TypeError { + fn from(value: ErrKind) -> Self { + TypeError { + span: Span::new(0, 0), + kind: value, + } + } +} + +#[derive(Debug, Clone)] +enum ErrKind { + Message(String), + UnresovledType(String), + DuplicateName(String), + TypeMismatch { expected: Type, actual: Type }, +} + +impl ErrKind { + pub fn with_span(self, span: Span) -> TypeError { + TypeError { span, kind: self } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct TypeId(usize); + +#[derive(Debug, Clone, PartialEq, Default)] +pub enum Type { + Boolean, + Byte, + Integer, + Float, + Char, + String, + Array, + Tuple, + Enum(TypeId), + Struct(TypeId), + Function(Box), + Any, + #[default] + Unknown, +} + +impl Type { + pub fn is_any(&self) -> bool { + matches!(self, Type::Any) + } + pub fn is_boolean(&self) -> bool { + matches!(self, Type::Boolean) + } + + pub fn is_numeric(&self) -> bool { + matches!(self, Type::Byte | Type::Integer | Type::Float) + } + + pub fn is_string(&self) -> bool { + matches!(self, Type::String) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum TypeDef { + Struct(StructDef), + Enum(EnumDef), +} + +#[derive(Debug, Clone, PartialEq)] +pub struct FunctionDef { + pub name: String, + pub params: Vec<(String, Option)>, + pub return_type: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct EnumDef { + pub name: String, + pub variants: Vec<(String, Option)>, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct StructDef { + pub name: String, + pub fields: HashMap, +} + pub struct TypeContext { - type_decls: HashMap, - type_env: HashMap, - // 新增:用于缓存已解析的类型声明 - resolved_types: HashMap, + type_defs: HashMap, + name_to_id: HashMap, + + functions: HashMap>, + + next_id: usize, } impl TypeContext { pub fn new() -> Self { - TypeContext { - type_decls: HashMap::new(), - type_env: HashMap::new(), - resolved_types: HashMap::new(), // 初始化缓存 + Self { + type_defs: HashMap::new(), + name_to_id: HashMap::new(), + functions: HashMap::new(), + next_id: 0, } } - pub fn add_type_decl(&mut self, name: String, decl: Declaration) { - self.type_decls.insert(name.clone(), decl.clone()); + pub fn functions(&self) -> impl Iterator { + self.functions.values().map(|func| &**func) + } - if matches!(&decl, &Declaration::Function(_)) { - self.type_env.insert(name.to_string(), Type::Decl(decl)); - } + pub fn type_is_defined(&self, name: &str) -> bool { + self.name_to_id.contains_key(name) || self.functions.contains_key(name) } - pub fn get_type_decl(&self, name: &str) -> Option<&Declaration> { - self.type_decls.get(name) + pub fn get_type_id(&self, name: &str) -> Option { + self.name_to_id.get(name).copied() } - pub fn get_type(&self, name: &str) -> Option<&Type> { - self.type_env.get(name) + pub fn get_type(&self, name: &str) -> Option { + match self.name_to_id.get(name) { + Some(id) => match self.type_defs.get(id) { + Some(TypeDef::Struct(_)) => Some(Type::Struct(*id)), + Some(TypeDef::Enum(_)) => Some(Type::Enum(*id)), + None => panic!("Type not found"), + }, + None => self + .functions + .get(name) + .map(|func| Type::Function(func.clone())), + } } - pub fn set_type(&mut self, name: String, ty: Type) { - self.type_env.insert(name, ty); + pub fn get_type_def(&self, name: &str) -> Option<&TypeDef> { + match self.name_to_id.get(name) { + Some(id) => self.type_defs.get(id), + None => None, + } } - pub fn function_decls(&self) -> impl Iterator { - self.type_decls.values().filter(|decl| decl.is_function()) + pub fn get_function_def(&self, name: &str) -> Option<&FunctionDef> { + self.functions.get(name).map(|v| &**v) } - pub fn process_env(&mut self, env: &Environment) { - for (name, _value) in env.symbols.iter() { - self.set_type(name.clone(), Type::Any); + pub fn analyze_type_def(&mut self, stmts: &[StatementNode]) -> Result<(), TypeError> { + // round 1, decl types + for stmt in stmts { + match &stmt.node { + Statement::Item(ItemStatement::Struct(StructItem { name, .. })) => { + self.decl_type( + name.clone(), + TypeDef::Struct(StructDef { + name: name.clone(), + fields: HashMap::new(), + }), + ) + .map_err(|err| err.with_span(stmt.span))?; + } + Statement::Item(ItemStatement::Enum(item)) => { + self.decl_type( + item.name.clone(), + TypeDef::Enum(EnumDef { + name: item.name.clone(), + variants: Vec::new(), + }), + ); + } + _ => {} + } } - } - pub fn analyze_type_decl(&mut self, stmts: &[StatementNode]) { + // round 2, resolve types for stmt in stmts { match &stmt.node { - Statement::Item(ItemStatement::Fn(func)) => { - let FunctionItem { - name, - params, - return_ty, - body, - } = func; + Statement::Item(ItemStatement::Function(func)) => {} + Statement::Item(ItemStatement::Struct(item)) => { + self.analyze_struct_item(item)?; + } + Statement::Item(ItemStatement::Enum(item)) => { + self.analyze_enum_item(item)?; + } + _ => { + continue; + } + } + } - let mut param_types = Vec::new(); + Ok(()) + } - for param in params { - let ty = param.ty.clone().map(|t| self.resolve_type_decl(&t)); + fn analyze_function_item(&mut self, item: &FunctionItem) -> Result<(), TypeError> { + let FunctionItem { + name, + params, + return_ty, + .. + } = item; - param_types.push((param.name.clone(), ty)); - } + let mut func = FunctionDef { + name: name.clone(), + params: Vec::new(), + return_type: return_ty + .as_ref() + .map(|ty| self.resolve_type(ty)) + .transpose()?, + }; - let return_ty = return_ty - .as_ref() - .map(|t| Box::new(self.resolve_type_decl(t))); + for param in params { + let param_type = param + .ty + .as_ref() + .map(|ty| self.resolve_type(&ty)) + .transpose()?; + func.params.push((param.name.clone(), param_type)); + } - let func_decl = FunctionDeclaration { - name: name.clone(), - params: param_types, - return_type: return_ty, - }; + self.functions.insert(name.clone(), Box::new(func)); - self.add_type_decl(name.clone(), Declaration::Function(func_decl)); - } - Statement::Item(ItemStatement::Struct(item)) => { - let StructItem { name, fields } = item; - let struct_decl = StructDeclaration { - name: name.clone(), - fields: fields - .iter() - .map(|field| (field.name.clone(), self.resolve_type_decl(&field.ty))) - .collect(), - }; - - self.add_type_decl(name.clone(), Declaration::Struct(struct_decl)); - } - Statement::Item(ItemStatement::Enum(EnumItem { .. })) => {} - _ => {} + Ok(()) + } + + fn analyze_struct_item(&mut self, item: &StructItem) -> Result { + let StructItem { name, fields } = item; + + let type_id = self.name_to_id.get(name).unwrap(); + + let fields: HashMap = fields + .iter() + .map(|field| { + self.resolve_type(&field.ty) + .map(|ty| (field.name.clone(), ty)) + }) + .collect::>()?; + + match self.type_defs.get_mut(type_id) { + Some(TypeDef::Struct(struct_def)) => { + // update the struct definition + struct_def.fields = fields; + } + _ => { + return Err(ErrKind::Message(format!("Type {} is not a struct", name)).into()); + } + } + + Ok(*type_id) + } + + fn analyze_enum_item(&mut self, item: &EnumItem) -> Result { + let EnumItem { name, variants } = item; + + let type_id = self.name_to_id.get(name).unwrap(); + + let mut enum_variants = Vec::new(); + + for variant in variants { + let EnumVariant { name, variant } = variant; + + let variant_type = variant + .as_ref() + .map(|ty| self.resolve_type(&ty)) + .transpose()?; + + enum_variants.push((name.to_string(), variant_type)); + } + + match self.type_defs.get_mut(type_id) { + Some(TypeDef::Enum(enum_def)) => { + // update the enum definition + enum_def.variants = enum_variants; } + _ => { + return Err(ErrKind::Message(format!("Type {} is not an enum", name)).into()); + } + } + + Ok(*type_id) + } + + fn decl_type(&mut self, name: String, ty: TypeDef) -> Result { + if self.name_to_id.contains_key(&name) { + return Err(ErrKind::DuplicateName(name).into()); } + + let id = self.next_id(); + self.type_defs.insert(id, ty); + self.name_to_id.insert(name, id); + + Ok(id) + } + + fn next_id(&mut self) -> TypeId { + let id = self.next_id; + self.next_id += 1; + TypeId(id) } - // 新增:递归解析类型声明 - fn resolve_type_decl_recursive(&mut self, type_expr: &TypeExpression) -> Type { + /// Try to analyze type expressions when possible, otherwise return Type::Unknown. + fn try_resolve_type(&self, type_expr: &TypeExpression) -> Type { match type_expr { TypeExpression::Any => Type::Any, TypeExpression::Boolean => Type::Boolean, @@ -110,55 +325,473 @@ impl TypeContext { TypeExpression::Float => Type::Float, TypeExpression::Char => Type::Char, TypeExpression::String => Type::String, - TypeExpression::Tuple(types) => { - let types = types.iter().map(|ty| self.resolve_type_decl(ty)).collect(); - Type::Tuple(types) - } - TypeExpression::Array(ty) => { - let ty = self.resolve_type_decl(ty); - Type::Array(Box::new(ty)) - } - TypeExpression::UserDefined(ty) => { - // 检查缓存中是否存在已解析的类型 - if let Some(cached_type) = self.resolved_types.get(ty) { - return cached_type.clone(); + TypeExpression::Array(_) => Type::Array, + TypeExpression::Tuple(_) => Type::Tuple, + TypeExpression::UserDefined(name) => match self.name_to_id.get(name).cloned() { + Some(id) => match self.type_defs.get(&id) { + Some(TypeDef::Struct(_)) => Type::Struct(id), + Some(TypeDef::Enum(_)) => Type::Enum(id), + _ => panic!("Invalid type"), + }, + None => Type::Unknown, + }, + _ => Type::Unknown, + } + } + + pub fn resolve_type(&self, type_expr: &TypeExpression) -> Result { + match self.try_resolve_type(type_expr) { + Type::Unknown => Err(ErrKind::UnresovledType(format!("{:?}", type_expr)).into()), + ty => Ok(ty), + } + } +} + +pub struct TypeChecker<'a> { + type_cx: &'a TypeContext, + current_function_return_type: Option, + symbols: SymbolTable, +} + +impl<'a> TypeChecker<'a> { + pub fn new(type_cx: &'a TypeContext) -> Self { + TypeChecker { + type_cx, + current_function_return_type: None, + symbols: SymbolTable::new(), + } + } + + pub fn check_program(&mut self, program: &Program, env: &Environment) -> Result<(), TypeError> { + // all env is any + for name in env.keys() { + self.symbols.insert(name.to_string(), Type::Any); + } + // insert function type + for (name, func) in self.type_cx.functions.iter() { + self.symbols + .insert(func.name.clone(), Type::Function(func.clone())); + } + + for item in &program.stmts { + self.check_statement(&item)?; + } + Ok(()) + } + + fn check_statement(&mut self, stmt: &StatementNode) -> Result<(), TypeError> { + match &stmt.node { + Statement::Let(let_stmt) => self.check_let_statement(let_stmt), + Statement::Block(block) => self.check_block_statement(block), + Statement::If(if_stmt) => self.check_if_statement(if_stmt), + Statement::While(while_stmt) => self.check_while_statement(while_stmt), + Statement::For(for_stmt) => self.check_for_statement(for_stmt), + Statement::Loop(loop_stmt) => self.check_loop_statement(loop_stmt), + Statement::Return(return_stmt) => self.check_return_statement(return_stmt), + Statement::Expression(expr) => self.analyze_expression(expr).map(|_| ()), + Statement::Empty => Ok(()), + Statement::Break => Ok(()), + Statement::Continue => Ok(()), + Statement::Item(item_stmt) => self.check_item_statement(item_stmt), + } + } + + // 新增方法:检查块语句 + fn check_block_statement(&mut self, block: &BlockStatement) -> Result<(), TypeError> { + self.symbols.enter_scope(); + + for stmt in &block.0 { + self.check_statement(stmt)?; + } + + self.symbols.leave_scope(); + Ok(()) + } + + // 新增方法:检查条件语句 + fn check_if_statement(&mut self, if_stmt: &IfStatement) -> Result<(), TypeError> { + let condition_type = self.analyze_expression(&if_stmt.condition)?; + + if condition_type != Type::Boolean && condition_type != Type::Any { + return Err(ErrKind::TypeMismatch { + expected: Type::Boolean, + actual: condition_type, + } + .with_span(if_stmt.condition.span)); + } + + self.check_block_statement(&if_stmt.then_branch)?; + if let Some(else_branch) = &if_stmt.else_branch { + self.check_block_statement(else_branch)?; + } + Ok(()) + } + + // 新增方法:检查循环语句 + fn check_while_statement(&mut self, while_stmt: &WhileStatement) -> Result<(), TypeError> { + let condition_type = self.analyze_expression(&while_stmt.condition)?; + + if condition_type != Type::Boolean && condition_type != Type::Any { + return Err(ErrKind::TypeMismatch { + expected: Type::Boolean, + actual: condition_type, + } + .with_span(while_stmt.condition.span)); + } + + self.check_block_statement(&while_stmt.body)?; + Ok(()) + } + + // 新增方法:检查 for 循环语句 + fn check_for_statement(&mut self, for_stmt: &ForStatement) -> Result<(), TypeError> { + self.analyze_expression(&for_stmt.iterable)?; + self.check_block_statement(&for_stmt.body)?; + Ok(()) + } + + // 新增方法:检查无限循环语句 + fn check_loop_statement(&mut self, loop_stmt: &LoopStatement) -> Result<(), TypeError> { + self.check_block_statement(&loop_stmt.body)?; + Ok(()) + } + + // 新增方法:检查返回语句 + fn check_return_statement(&mut self, return_stmt: &ReturnStatement) -> Result<(), TypeError> { + if let Some(expr) = &return_stmt.value { + let return_ty = self.analyze_expression(expr)?; + + if let Some(expected_ty) = &self.current_function_return_type { + if return_ty != *expected_ty { + return Err(ErrKind::TypeMismatch { + expected: expected_ty.clone(), + actual: return_ty, + } + .with_span(expr.span())); } + } + } + Ok(()) + } + + // 新增方法:检查项语句 + fn check_item_statement(&mut self, item_stmt: &ItemStatement) -> Result<(), TypeError> { + if let ItemStatement::Function(func) = item_stmt { + self.check_function_item(func)?; + } + + Ok(()) + } - // 检查是否已经存在于 type_env 中 - if let Some(ty) = self.get_type(ty) { - return ty.clone(); + fn check_function_item(&mut self, func_item: &FunctionItem) -> Result<(), TypeError> { + // new scope + let old_return_type = self.current_function_return_type.clone(); + self.symbols.enter_scope(); + + for param in &func_item.params { + let param_type = match param.ty.as_ref() { + Some(ty) => self.type_cx.resolve_type(&ty)?, + None => Type::Any, + }; + + self.symbols.insert(param.name.clone(), param_type); + } + + self.current_function_return_type = func_item + .return_ty + .as_ref() + .map(|ty| self.type_cx.resolve_type(ty)) + .transpose()?; + + self.check_block_statement(&func_item.body)?; + + // restore + self.symbols.leave_scope(); + self.current_function_return_type = old_return_type; + + Ok(()) + } + + fn check_let_statement(&mut self, let_stmt: &LetStatement) -> Result<(), TypeError> { + let ty = let_stmt + .ty + .as_ref() + .map(|ty| self.type_cx.resolve_type(&ty)) + .transpose()?; + + if let Some(expr) = let_stmt.value.as_ref() { + let value_type = self.analyze_expression(expr)?; + if ty.is_some() && (ty.as_ref() != Some(&value_type)) { + return Err(ErrKind::TypeMismatch { + expected: ty.unwrap(), + actual: value_type, } + .with_span(expr.span)); + } + } - // 解析类型声明 - if let Some(decl) = self.get_type_decl(ty) { - let resolved_type = match decl { - Declaration::Function(func_decl) => { - Type::Decl(Declaration::Function(func_decl.clone())) - } - Declaration::Struct(struct_decl) => { - Type::Decl(Declaration::Struct(struct_decl.clone())) - } - Declaration::Enum(enum_decl) => { - Type::Decl(Declaration::Enum(enum_decl.clone())) + Ok(()) + } + + fn analyze_expression(&mut self, expr: &ExpressionNode) -> Result { + let ret = match &expr.node { + Expression::Literal(lit) => self.analyze_literal(lit), + Expression::Identifier(id) => self.analyze_identifier(id), + Expression::Binary(bin) => self.analyze_binary(bin), + Expression::Prefix(prefix) => self.analyze_prefix(prefix), + Expression::Call(call) => self.analyze_call(call), + Expression::Environment(env) => Ok(Type::String), + Expression::Path(path) => self.analyze_path(path), + Expression::Tuple(tuple) => self.analyze_tuple(tuple), + Expression::Array(arr) => self.analyze_array(arr), + Expression::Map(map) => Ok(Type::Any), // 暂定Map类型为Any + Expression::Closure(closure) => self.analyze_closure(closure), + Expression::Range(range) => Ok(Type::Any), // 暂定Range类型为Any + Expression::Slice(slice) => self.analyze_slice(slice), + Expression::Assign(assign) => self.analyze_assign(assign), + Expression::IndexGet(index) => self.analyze_index_get(index), + Expression::IndexSet(index) => self.analyze_index_set(index), + Expression::PropertyGet(prop) => self.analyze_property_get(prop), + Expression::PropertySet(prop) => self.analyze_property_set(prop), + Expression::CallMethod(call) => self.analyze_call_method(call), + Expression::StructExpr(struct_) => self.analyze_struct_expr(struct_), + Expression::Await(expr) => self.analyze_expression(expr), + Expression::Try(expr) => self.analyze_expression(expr), + // _ => Err(ErrKind::Message(format!("Unsupported expression: {:?}", expr.node)).into()), + }; + + ret.map_err(|err| { + if !err.span.is_empty() { + return err.with_span(expr.span); + } else { + err + } + }) + } + + fn analyze_path(&self, path: &PathExpression) -> Result { + // 路径表达式类型解析逻辑 + Ok(Type::Any) // 暂定返回Any类型 + } + + fn analyze_tuple(&mut self, tuple: &TupleExpression) -> Result { + // 元组类型解析逻辑 + Ok(Type::Tuple) + } + + fn analyze_array(&mut self, arr: &ArrayExpression) -> Result { + // 数组类型解析逻辑 + Ok(Type::Array) + } + + fn analyze_closure(&mut self, closure: &ClosureExpression) -> Result { + // // 闭包类型解析逻辑 + // Ok(Type::Function(Box::new(FunctionDef { + // name: "".to_string(), + // params: vec![], + // return_type: None, + // }))) + + Ok(Type::Any) // 临时返回 Any 类型 + } + + fn analyze_slice(&mut self, slice: &SliceExpression) -> Result { + // 切片类型解析逻辑 + Ok(Type::Array) + } + + fn analyze_assign(&mut self, assign: &AssignExpression) -> Result { + // 赋值表达式类型解析逻辑 + self.analyze_expression(&assign.value) + } + + fn analyze_index_get(&mut self, index: &IndexGetExpression) -> Result { + // 索引获取类型解析逻辑 + Ok(Type::Any) // 暂定返回Any类型 + } + + fn analyze_index_set(&mut self, index: &IndexSetExpression) -> Result { + // 索引设置类型解析逻辑 + self.analyze_expression(&index.value) + } + + fn analyze_property_get(&mut self, prop: &PropertyGetExpression) -> Result { + // 属性获取类型解析逻辑 + Ok(Type::Any) // 暂定返回Any类型 + } + + fn analyze_property_set(&mut self, prop: &PropertySetExpression) -> Result { + // 属性设置类型解析逻辑 + self.analyze_expression(&prop.value) + } + + fn analyze_call_method(&mut self, call: &CallMethodExpression) -> Result { + // 调用方法类型解析逻辑 + Ok(Type::Any) // 暂定返回Any类型 + } + + fn analyze_struct_expr(&mut self, struct_expr: &StructExpression) -> Result { + match self.type_cx.get_type_def(&struct_expr.name.node()) { + Some(TypeDef::Struct(struct_def)) => { + for field in &struct_expr.fields { + let field_type = self.analyze_expression(&field.value)?; + match struct_def.fields.get(&field.name.node) { + Some(expected_type) => { + if field_type != *expected_type { + return Err(TypeError::new( + field.value.span(), + ErrKind::Message(format!( + "Expected type {:?} for field {:?}, found {:?}", + expected_type, field.name, field_type + )), + )); + } } - }; + None => {} + } + } + + Ok(self.type_cx.get_type(&struct_expr.name.node()).unwrap()) + } + Some(ty) => Err(TypeError::new( + struct_expr.name.span(), + ErrKind::Message(format!("Expected struct type, found {:?}", ty)), + )), + None => Err(ErrKind::UnresovledType(struct_expr.name.node().clone()).into()), + } + } + + fn analyze_literal(&self, lit: &LiteralExpression) -> Result { + match lit { + LiteralExpression::Null => Ok(Type::Any), + LiteralExpression::Boolean(_) => Ok(Type::Boolean), + LiteralExpression::Integer(_) => Ok(Type::Integer), + LiteralExpression::Float(_) => Ok(Type::Float), + LiteralExpression::Char(_) => Ok(Type::Char), + LiteralExpression::String(_) => Ok(Type::String), + } + } + + fn analyze_identifier(&self, id: &IdentifierExpression) -> Result { + match self.symbols.lookup(&id.0) { + Some(ty) => Ok(ty.clone()), + None => Err(ErrKind::Message(format!("Undefined identifier: {}", id.0)).into()), + } + } + + fn analyze_binary(&mut self, bin: &BinaryExpression) -> Result { + let lhs_type = self.analyze_expression(&bin.lhs)?; + let rhs_type = self.analyze_expression(&bin.rhs)?; - // 更新缓存 - self.resolved_types - .insert(ty.clone(), resolved_type.clone()); - self.set_type(ty.clone(), resolved_type.clone()); - resolved_type + // Handle Type::Any as compatible with any type + if lhs_type == Type::Any || rhs_type == Type::Any { + return Ok(Type::Any); + } + + match bin.op { + BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div | BinOp::Rem => { + if lhs_type.is_numeric() && lhs_type == rhs_type { + Ok(lhs_type) // Simplified for demonstration + } else { + Err(ErrKind::Message(format!( + "Type mismatch in binary operation: {:?} and {:?}", + lhs_type, rhs_type + )) + .into()) + } + } + BinOp::Equal | BinOp::NotEqual => { + if lhs_type == rhs_type { + Ok(Type::Boolean) + } else { + Err(ErrKind::Message(format!( + "Type mismatch in comparison: {:?} and {:?}", + lhs_type, rhs_type + )) + .into()) + } + } + _ => Err(ErrKind::Message(format!("Unsupported binary operator: {:?}", bin.op)).into()), + } + } + + fn analyze_prefix(&mut self, prefix: &PrefixExpression) -> Result { + let rhs_type = self.analyze_expression(&prefix.rhs)?; + + // Handle Type::Any as compatible with any type + if rhs_type == Type::Any { + return Ok(Type::Any); + } + + match prefix.op { + PrefixOp::Neg => { + if rhs_type.is_numeric() { + Ok(rhs_type) } else { - // 如果无法解析,则返回 UserDefined - Type::UserDefined(ty.clone()) + Err(ErrKind::Message(format!( + "Cannot apply negation to non-numeric type: {:?}", + rhs_type + )) + .into()) + } + } + PrefixOp::Not => { + if rhs_type == Type::Boolean { + Ok(Type::Boolean) + } else { + Err(ErrKind::Message(format!( + "Cannot apply logical NOT to non-boolean type: {:?}", + rhs_type + )) + .into()) } } - _ => Type::Any, } } - // 新增:对外暴露的解析函数 - pub fn resolve_type_decl(&mut self, type_expr: &TypeExpression) -> Type { - self.resolve_type_decl_recursive(type_expr) + fn analyze_call(&mut self, call: &CallExpression) -> Result { + let func_type = self.analyze_expression(&call.func)?; + + // Handle Type::Any as compatible with any type + if func_type == Type::Any { + return Ok(Type::Any); + } + + match func_type { + Type::Function(func_def) => { + // Check the number of arguments + if func_def.params.len() != call.args.len() { + return Err(ErrKind::Message(format!( + "Expected {} arguments, but got {}", + func_def.params.len(), + call.args.len() + )) + .into()); + } + + // Check each argument type + for (i, (param_name, param_type)) in func_def.params.iter().enumerate() { + let arg_type = self.analyze_expression(&call.args[i])?; + if let Some(expected_type) = param_type { + // Handle Type::Any as compatible with any type + if *expected_type != Type::Any && arg_type != *expected_type { + return Err(ErrKind::Message(format!( + "Argument {} has type {:?}, but expected {:?}", + i + 1, + arg_type, + expected_type + )) + .into()); + } + } + } + + // Return the function's return type + Ok(func_def.return_type.unwrap_or(Type::Any)) + } + _ => Err( + ErrKind::Message(format!("Cannot call non-function type: {:?}", func_type)).into(), + ), + } } } diff --git a/src/runtime/environment.rs b/src/runtime/environment.rs index e133e51..675dad1 100644 --- a/src/runtime/environment.rs +++ b/src/runtime/environment.rs @@ -26,6 +26,14 @@ impl Environment { } } + pub fn keys(&self) -> impl Iterator { + self.symbols.keys() + } + + pub fn values(&self) -> impl Iterator { + self.symbols.values() + } + pub fn with_variable(mut self, name: impl ToString, value: T) -> Self { self.insert(name, value); self From 2817d92796ee3ea3f427d39374aec4524458ad54 Mon Sep 17 00:00:00 2001 From: zzzdong Date: Sun, 1 Jun 2025 13:31:54 +0800 Subject: [PATCH 2/2] feat: split type check --- examples/scripting.rs | 28 ++-- src/compiler.rs | 231 ++++++++++++++-------------- src/compiler/ast.rs | 1 - src/compiler/ast/syntax.rs | 43 +++++- src/compiler/ast/walker.rs | 298 ------------------------------------- src/compiler/lowering.rs | 150 +++++++++---------- src/compiler/parser.rs | 211 +++++++++----------------- src/compiler/semantic.rs | 190 +++++++++++++++-------- src/compiler/symbol.rs | 16 ++ src/compiler/typing.rs | 260 +++++++++++++++++++------------- src/error.rs | 10 +- src/runtime/object/map.rs | 4 +- tests/test_embed.rs | 4 +- 13 files changed, 620 insertions(+), 826 deletions(-) delete mode 100644 src/compiler/ast/walker.rs diff --git a/examples/scripting.rs b/examples/scripting.rs index 02610e5..56bef81 100644 --- a/examples/scripting.rs +++ b/examples/scripting.rs @@ -136,11 +136,9 @@ mod scripting { Ok(None) } - None => { - Err(evalit::RuntimeError::invalid_argument::( - 0, &args[0], - )) - } + None => Err(evalit::RuntimeError::invalid_argument::( + 0, &args[0], + )), } } @@ -161,11 +159,9 @@ mod scripting { Ok(None) } - _ => { - Err(evalit::RuntimeError::invalid_argument::( - 0, &args[0], - )) - } + _ => Err(evalit::RuntimeError::invalid_argument::( + 0, &args[0], + )), } } @@ -180,11 +176,9 @@ mod scripting { Ok(None) } - None => { - Err(evalit::RuntimeError::invalid_argument::( - 0, &args[0], - )) - } + None => Err(evalit::RuntimeError::invalid_argument::( + 0, &args[0], + )), } } @@ -195,9 +189,7 @@ mod scripting { Ok(Some(ValueRef::new(self.remote_addr.clone()))) } - _ => { - Err(evalit::RuntimeError::missing_method::(method)) - } + _ => Err(evalit::RuntimeError::missing_method::(method)), } } } diff --git a/src/compiler.rs b/src/compiler.rs index 6311925..381c930 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -9,149 +9,135 @@ mod symbol; mod typing; use std::collections::HashMap; -use std::path::Path; use std::sync::Arc; use ir::builder::{InstBuilder, IrBuilder}; use ir::instruction::IrUnit; use log::debug; -use typing::{Type, TypeChecker, TypeContext, TypeError}; +use typing::{TypeChecker, TypeContext, TypeError}; use crate::Environment; use crate::bytecode::{Module, Register}; +use crate::compiler::ast::syntax::Span; +use crate::compiler::symbol::SymbolTable; use parser::ParseError; use codegen::Codegen; -use lowering::{ASTLower, SymbolTable}; -use semantic::SemanticAnalyzer; +use lowering::ASTLower; +use semantic::{SemanticAnalyzer, SemanticError}; -pub fn compile(script: &str, env: &crate::Environment) -> Result, CompileError> { +pub fn compile<'i>( + script: &'i str, + env: &crate::Environment, +) -> Result, CompileError<'i>> { Compiler::new().compile(script, env) } -#[derive(Debug)] -pub enum CompileError { - Io(std::io::Error), - Parse(ParseError), - Type(TypeError), - Semantics(String), - UndefinedVariable { - name: String, - }, - UnknownType { - name: String, - }, - TypeMismatch { - expected: Box, - actual: Box, - span: Span, - }, - TypeInference(String), - TypeCheck(String), - ArgumentCountMismatch { - expected: usize, - actual: usize, - }, - NotCallable { - ty: Type, - span: Span, - }, - Unreachable, - BreakOutsideLoop { - span: Span, - }, - ContinueOutsideLoop { - span: Span, - }, - ReturnOutsideFunction { - span: Span, - }, - InvalidOperation { - message: String, - }, -} - -impl CompileError { - pub fn type_mismatch(expected: Type, actual: Type, span: Span) -> Self { - CompileError::TypeMismatch { +#[derive(Debug, Clone, Copy)] +struct LineCol { + line: usize, + col: usize, +} + +impl From<(usize, usize)> for LineCol { + fn from(value: (usize, usize)) -> Self { + Self { + line: value.0, + col: value.1, + } + } +} + +#[derive(Debug, Clone)] +pub struct CompileError<'i> { + kind: ErrorKind, + span: &'i str, + line_col: LineCol, +} + +impl<'i> CompileError<'i> { + pub fn new(input: &'i str, error: ErrorKind) -> Self { + let (line_col, span) = match error.line_col(input) { + Some(line_col) => { + let span = error.span(); + (line_col, &input[span.start..span.end]) + } + None => (LineCol { line: 0, col: 0 }, &input[0..0]), + }; + Self { + kind: error, span, - expected: Box::new(expected), - actual: Box::new(actual), + line_col, } } } -impl From for CompileError { - fn from(error: std::io::Error) -> Self { - CompileError::Io(error) +impl<'i> std::fmt::Display for CompileError<'i> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Error on {}@{:?}, detail: {:?}", + self.span, self.line_col, self.kind + ) + } +} + +impl<'i> std::error::Error for CompileError<'i> {} + +#[derive(Debug, Clone)] +pub enum ErrorKind { + // Io(std::io::Error), + Parse(ParseError), + Type(TypeError), + Semantic(SemanticError), +} + +impl ErrorKind { + fn line_col(&self, input: &str) -> Option { + match self { + ErrorKind::Parse(ParseError { span, .. }) => span.line_col(input).map(Into::into), + ErrorKind::Type(TypeError { span, .. }) => span.line_col(input).map(Into::into), + ErrorKind::Semantic(SemanticError { span, .. }) => span.line_col(input).map(Into::into), + } + } + + fn span(&self) -> Span { + match self { + ErrorKind::Parse(ParseError { span, .. }) => *span, + ErrorKind::Type(TypeError { span, .. }) => *span, + ErrorKind::Semantic(SemanticError { span, .. }) => *span, + } } } -impl From for CompileError { +impl From for ErrorKind { fn from(error: ParseError) -> Self { - CompileError::Parse(error) + ErrorKind::Parse(error) } } -impl From for CompileError { +impl From for ErrorKind { fn from(error: TypeError) -> Self { - CompileError::Type(error) + ErrorKind::Type(error) } } -impl From for CompileError { - fn from(error: SemanticsError) -> Self { - CompileError::Semantics(error) +impl From for ErrorKind { + fn from(error: SemanticError) -> Self { + ErrorKind::Semantic(error) } } -impl std::fmt::Display for CompileError { +impl std::fmt::Display for ErrorKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - CompileError::Io(error) => write!(f, "IO error: {error}"), - CompileError::Parse(error) => write!(f, "Parse error: {error}"), - CompileError::Type(error) => write!(f, "Type error: {error:?}"), - CompileError::Semantics(message) => write!(f, "Semantics error: {message}"), - CompileError::UndefinedVariable { name } => { - write!(f, "Undefined variable `{name}`") - } - CompileError::UnknownType { name } => { - write!(f, "Unknow type `{name}`") - } - CompileError::TypeMismatch { - expected, - actual, - span, - } => write!( - f, - "Type mismatch: expected `{expected:?}`, actual `{actual:?}` at {span:?}" - ), - CompileError::TypeInference(message) => write!(f, "Type inference error: {message}"), - CompileError::TypeCheck(message) => write!(f, "Type check error: {message}"), - CompileError::ArgumentCountMismatch { expected, actual } => write!( - f, - "Argument count mismatch: expected {expected}, actual {actual}" - ), - CompileError::NotCallable { ty, span } => { - write!(f, "Not callable: `{ty:?}` at {span:?}") - } - CompileError::Unreachable => write!(f, "Unreachable"), - CompileError::BreakOutsideLoop { span } => write!(f, "Break outside loop at {span:?}"), - CompileError::ContinueOutsideLoop { span } => { - write!(f, "Continue outside loop at {span:?}") - } - CompileError::ReturnOutsideFunction { span } => { - write!(f, "Return outside function at {span:?}") - } - CompileError::InvalidOperation { message } => { - write!(f, "Invalid operation, {message}") - } + ErrorKind::Parse(error) => write!(f, "Parse error: {error}"), + ErrorKind::Type(error) => write!(f, "Type error: {error:?}"), + ErrorKind::Semantic(error) => write!(f, "Semantic error: {error:?}"), } } } -impl std::error::Error for CompileError {} - pub struct FileId(usize); pub struct Context { @@ -169,12 +155,12 @@ impl Context { id } - pub fn add_file(&mut self, file: impl AsRef) -> Result { - let id = FileId(self.sources.len()); - let content = std::fs::read_to_string(file.as_ref())?; - self.sources.push(content); - Ok(id) - } + // pub fn add_file(&mut self, file: impl AsRef) -> Result { + // let id = FileId(self.sources.len()); + // let content = std::fs::read_to_string(file.as_ref())?; + // self.sources.push(content); + // Ok(id) + // } pub fn get_source(&self, file: FileId) -> Option<&str> { self.sources.get(file.0).map(|s| s.as_str()) @@ -194,24 +180,39 @@ impl Compiler { Self {} } - pub fn compile(&self, input: &str, env: &Environment) -> Result, CompileError> { + pub fn compile<'i>( + &self, + input: &'i str, + env: &Environment, + ) -> Result, CompileError<'i>> { + match self.compile_inner(input, env) { + Ok(module) => Ok(module), + Err(err) => Err(CompileError::new(input, err)), + } + } + + fn compile_inner(&self, input: &str, env: &Environment) -> Result, ErrorKind> { // 解析输入 - let mut ast = parser::parse_file(input)?; + let ast = parser::parse_file(input)?; debug!("AST: {ast:?}"); let mut type_cx = TypeContext::new(); - type_cx.analyze_type_def(&ast.stmts)?; + type_cx.check_type_def(&ast.stmts)?; // 语义分析 - let mut checker = SemanticChecker::new(&mut type_cx); - checker.check_program(&mut ast, env)?; + let mut analyzer = SemanticAnalyzer::new(&type_cx); + analyzer.analyze_program(&ast, env)?; + + // 类型检查 + let mut checker = TypeChecker::new(&type_cx); + checker.check_program(&ast, env)?; // IR生成, AST -> IR let mut unit = IrUnit::new(); let builder: &mut dyn InstBuilder = &mut IrBuilder::new(&mut unit); let mut lower = ASTLower::new(builder, SymbolTable::new(), env, &type_cx); - lower.lower_program(ast)?; + lower.lower_program(ast); // code generation, IR -> bytecode let mut codegen = Codegen::new(&Register::general()); diff --git a/src/compiler/ast.rs b/src/compiler/ast.rs index 41fa5ba..4a39d2c 100644 --- a/src/compiler/ast.rs +++ b/src/compiler/ast.rs @@ -1,2 +1 @@ pub mod syntax; -pub mod walker; \ No newline at end of file diff --git a/src/compiler/ast/syntax.rs b/src/compiler/ast/syntax.rs index 9cb5238..94bc5aa 100644 --- a/src/compiler/ast/syntax.rs +++ b/src/compiler/ast/syntax.rs @@ -1,12 +1,9 @@ use std::{ - collections::HashMap, fmt, io::{self, Error, ErrorKind}, str::FromStr, }; -use pest::{RuleType, iterators::Pair}; - #[derive(Debug, Clone, PartialEq)] pub struct AstNode { pub node: T, @@ -39,7 +36,7 @@ impl AsMut for AstNode { } } -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Default)] pub struct Span { pub start: usize, pub end: usize, @@ -57,9 +54,33 @@ impl Span { } } + pub fn is_zero(&self) -> bool { + self.start == 0 && self.start == self.end + } + pub fn is_empty(&self) -> bool { self.start == self.end } + + pub fn line_col(&self, source: &str) -> Option<(usize, usize)> { + let mut offset = 0; + for (l, line) in source.lines().enumerate() { + if self.start < offset + line.len() { + let off = self.start - offset; + let mut line_off = 0; + for (c, char) in line.char_indices() { + if line_off == off { + return Some((l + 1, c)); + } + line_off += char.len_utf8(); + } + } + + offset += line.len(); + } + + None + } } pub type StatementNode = AstNode; @@ -266,7 +287,7 @@ pub struct SliceExpression { #[derive(Debug, Clone, PartialEq)] pub enum Pattern { Wildcard, - Identifier(String), + Identifier(IdentifierExpression), Literal(LiteralExpression), Tuple(Vec), } @@ -395,7 +416,17 @@ pub struct PostfixExpression { } #[derive(Debug, Clone, PartialEq)] -pub struct IdentifierExpression(pub String); +pub struct IdentifierExpression(pub AstNode); + +impl IdentifierExpression { + pub fn new(name: String, span: Span) -> Self { + IdentifierExpression(AstNode::new(name, span)) + } + + pub fn name(&self) -> &str { + self.0.node().as_str() + } +} #[derive(Debug, Clone, PartialEq)] pub enum LiteralExpression { diff --git a/src/compiler/ast/walker.rs b/src/compiler/ast/walker.rs deleted file mode 100644 index 40aa49d..0000000 --- a/src/compiler/ast/walker.rs +++ /dev/null @@ -1,298 +0,0 @@ -use super::syntax::*; - -pub trait Walker { - type Error; - type StatementResult: Default; - type ExpressionResult: Default; - - fn walk_program(&mut self, program: &Program) -> Result<(), Self::Error> { - for stmt in &program.stmts { - self.walk_statement(stmt)?; - } - Ok(()) - } - - fn walk_statement( - &mut self, - stmt: &StatementNode, - ) -> Result { - match &stmt.node { - Statement::Empty => self.walk_empty(), - Statement::Break => self.walk_break(), - Statement::Continue => self.walk_continue(), - Statement::Block(stmt) => self.walk_block_statement(stmt), - Statement::Item(stmt) => self.walk_item_statement(stmt), - Statement::Let(stmt) => self.walk_let_statement(stmt), - Statement::For(stmt) => self.walk_for_statement(stmt), - Statement::While(stmt) => self.walk_while_statement(stmt), - Statement::Loop(stmt) => self.walk_loop_statement(stmt), - Statement::If(stmt) => self.walk_if_statement(stmt), - Statement::Return(stmt) => self.walk_return_statement(stmt), - Statement::Expression(expr) => self.walk_expression(expr).map(|_| Default::default()), - } - } - - fn walk_expression( - &mut self, - expr: &ExpressionNode, - ) -> Result { - match &expr.node { - Expression::Literal(expr) => self.walk_literal_expression(expr), - Expression::Identifier(expr) => self.walk_identifier_expression(expr), - Expression::Environment(expr) => self.walk_environment_expression(expr), - Expression::Path(expr) => self.walk_path_expression(expr), - Expression::Tuple(expr) => self.walk_tuple_expression(expr), - Expression::Array(expr) => self.walk_array_expression(expr), - Expression::Map(expr) => self.walk_map_expression(expr), - Expression::Closure(expr) => self.walk_closure_expression(expr), - Expression::Range(expr) => self.walk_range_expression(expr), - Expression::Slice(expr) => self.walk_slice_expression(expr), - Expression::Assign(expr) => self.walk_assign_expression(expr), - Expression::Call(expr) => self.walk_call_expression(expr), - Expression::Try(expr) => self.walk_expression(expr), - Expression::Await(expr) => self.walk_expression(expr), - Expression::Prefix(expr) => self.walk_prefix_expression(expr), - Expression::Binary(expr) => self.walk_binary_expression(expr), - Expression::IndexGet(expr) => self.walk_index_get_expression(expr), - Expression::IndexSet(expr) => self.walk_index_set_expression(expr), - Expression::PropertyGet(expr) => self.walk_property_get_expression(expr), - Expression::PropertySet(expr) => self.walk_property_set_expression(expr), - Expression::CallMethod(expr) => self.walk_call_method_expression(expr), - Expression::StructExpr(expr) => self.walk_struct_expression(expr), - } - } - - fn walk_empty(&mut self) -> Result { - Ok(Default::default()) - } - - fn walk_break(&mut self) -> Result { - Ok(Default::default()) - } - - fn walk_continue(&mut self) -> Result { - Ok(Default::default()) - } - - fn walk_block_statement( - &mut self, - stmt: &BlockStatement, - ) -> Result { - Ok(Default::default()) - } - - fn walk_item_statement( - &mut self, - stmt: &ItemStatement, - ) -> Result { - match stmt { - ItemStatement::Function(item) => { - self.walk_function_item(item)?; - Ok(Default::default()) - } - ItemStatement::Struct(item) => { - self.walk_struct_item(item)?; - Ok(Default::default()) - } - ItemStatement::Enum(item) => { - self.walk_enum_item(item)?; - Ok(Default::default()) - } - } - } - - fn walk_function_item(&mut self, item: &FunctionItem) -> Result<(), Self::Error> { - Ok(Default::default()) - } - - fn walk_struct_item(&mut self, item: &StructItem) -> Result<(), Self::Error> { - Ok(Default::default()) - } - - fn walk_enum_item(&mut self, item: &EnumItem) -> Result<(), Self::Error> { - Ok(Default::default()) - } - - fn walk_let_statement( - &mut self, - stmt: &LetStatement, - ) -> Result { - Ok(Default::default()) - } - - fn walk_for_statement( - &mut self, - stmt: &ForStatement, - ) -> Result { - Ok(Default::default()) - } - - fn walk_while_statement( - &mut self, - stmt: &WhileStatement, - ) -> Result { - Ok(Default::default()) - } - - fn walk_loop_statement( - &mut self, - stmt: &LoopStatement, - ) -> Result { - Ok(Default::default()) - } - - fn walk_if_statement( - &mut self, - stmt: &IfStatement, - ) -> Result { - Ok(Default::default()) - } - - fn walk_return_statement( - &mut self, - stmt: &ReturnStatement, - ) -> Result { - Ok(Default::default()) - } - - // 表达式遍历方法 - fn walk_literal_expression( - &mut self, - _expr: &LiteralExpression, - ) -> Result { - Ok(Default::default()) - } - - fn walk_identifier_expression( - &mut self, - _expr: &IdentifierExpression, - ) -> Result { - Ok(Default::default()) - } - - fn walk_environment_expression( - &mut self, - _expr: &EnvironmentExpression, - ) -> Result { - Ok(Default::default()) - } - - fn walk_path_expression( - &mut self, - _expr: &PathExpression, - ) -> Result { - Ok(Default::default()) - } - - fn walk_tuple_expression( - &mut self, - _expr: &TupleExpression, - ) -> Result { - Ok(Default::default()) - } - - fn walk_array_expression( - &mut self, - _expr: &ArrayExpression, - ) -> Result { - Ok(Default::default()) - } - - fn walk_map_expression( - &mut self, - _expr: &MapExpression, - ) -> Result { - Ok(Default::default()) - } - - fn walk_closure_expression( - &mut self, - _expr: &ClosureExpression, - ) -> Result { - Ok(Default::default()) - } - - fn walk_range_expression( - &mut self, - _expr: &RangeExpression, - ) -> Result { - Ok(Default::default()) - } - - fn walk_slice_expression( - &mut self, - _expr: &SliceExpression, - ) -> Result { - Ok(Default::default()) - } - - fn walk_assign_expression( - &mut self, - _expr: &AssignExpression, - ) -> Result { - Ok(Default::default()) - } - - fn walk_call_expression( - &mut self, - _expr: &CallExpression, - ) -> Result { - Ok(Default::default()) - } - - fn walk_prefix_expression( - &mut self, - _expr: &PrefixExpression, - ) -> Result { - Ok(Default::default()) - } - - fn walk_binary_expression( - &mut self, - _expr: &BinaryExpression, - ) -> Result { - Ok(Default::default()) - } - - fn walk_index_get_expression( - &mut self, - _expr: &IndexGetExpression, - ) -> Result { - Ok(Default::default()) - } - - fn walk_index_set_expression( - &mut self, - _expr: &IndexSetExpression, - ) -> Result { - Ok(Default::default()) - } - - fn walk_property_get_expression( - &mut self, - _expr: &PropertyGetExpression, - ) -> Result { - Ok(Default::default()) - } - - fn walk_property_set_expression( - &mut self, - _expr: &PropertySetExpression, - ) -> Result { - Ok(Default::default()) - } - - fn walk_call_method_expression( - &mut self, - _expr: &CallMethodExpression, - ) -> Result { - Ok(Default::default()) - } - - fn walk_struct_expression( - &mut self, - _expr: &StructExpression, - ) -> Result { - Ok(Default::default()) - } -} diff --git a/src/compiler/lowering.rs b/src/compiler/lowering.rs index cee1ef4..cecac06 100644 --- a/src/compiler/lowering.rs +++ b/src/compiler/lowering.rs @@ -1,10 +1,9 @@ use std::collections::HashMap; -use std::{cell::RefCell, collections::BTreeMap, rc::Rc}; -use super::CompileError; use super::ast::syntax::*; use super::ir::{builder::*, instruction::*}; use super::typing::{TypeContext, TypeDef}; +use crate::compiler::symbol::SymbolTable; use crate::compiler::typing::{FunctionDef, StructDef}; use crate::{ Environment, @@ -29,7 +28,7 @@ impl LoopContext { pub struct ASTLower<'a> { builder: &'a mut dyn InstBuilder, env: &'a Environment, - symbols: SymbolTable, + symbols: SymbolTable, loop_contexts: Vec, type_cx: &'a TypeContext, } @@ -37,7 +36,7 @@ pub struct ASTLower<'a> { impl<'a> ASTLower<'a> { pub fn new( builder: &'a mut dyn InstBuilder, - symbols: SymbolTable, + symbols: SymbolTable, env: &'a Environment, type_cx: &'a TypeContext, ) -> Self { @@ -50,7 +49,7 @@ impl<'a> ASTLower<'a> { } } - pub fn lower_program(&mut self, prog: Program) -> Result { + pub fn lower_program(&mut self, prog: Program) -> IrUnit { let mut unit = IrUnit::new(); let builder: &mut dyn InstBuilder = &mut IrBuilder::new(&mut unit); @@ -86,7 +85,7 @@ impl<'a> ASTLower<'a> { // FIXME: This is a hack to make block not empty. self.builder.make_halt(); - Ok(unit) + unit } fn lower_statement(&mut self, statement: StatementNode) { @@ -139,7 +138,7 @@ impl<'a> ASTLower<'a> { self.builder.assign(dst, value); } - self.symbols.define(&name, Variable::new(dst)); + self.symbols.insert(&name, Variable::new(dst)); } fn lower_item_stmt(&mut self, item: ItemStatement) { @@ -186,7 +185,7 @@ impl<'a> ASTLower<'a> { Pattern::Identifier(ident) => { let dst = self.builder.alloc(); self.builder.assign(dst, value); - self.symbols.define(&ident, Variable::new(dst)); + self.symbols.insert(ident.name(), Variable::new(dst)); } Pattern::Tuple(pats) => { @@ -232,19 +231,19 @@ impl<'a> ASTLower<'a> { // loop body, get next value self.builder.switch_to_block(loop_body); - let new_symbols = self.symbols.new_scope(); - let old_symbols = std::mem::replace(&mut self.symbols, new_symbols); + + self.symbols.enter_scope(); let next = self.builder.call_property(next, "unwrap", vec![]); self.lower_pattern(pat, next); self.lower_block(body); - self.symbols = old_symbols; + self.symbols.leave_scope(); self.builder.br(loop_header); // done loop - self.level_loop_context(); + self.leave_loop_context(); self.builder.switch_to_block(after_blk); } @@ -258,16 +257,13 @@ impl<'a> ASTLower<'a> { self.builder.br(loop_body); self.builder.switch_to_block(loop_body); - let new_symbols = self.symbols.new_scope(); - let old_symbols = std::mem::replace(&mut self.symbols, new_symbols); self.lower_block(body); self.builder.br(loop_body); // done loop - self.level_loop_context(); - self.symbols = old_symbols; + self.leave_loop_context(); self.builder.switch_to_block(after_blk); } @@ -287,15 +283,11 @@ impl<'a> ASTLower<'a> { self.builder.br_if(cond, body_blk, after_blk); self.builder.switch_to_block(body_blk); - let new_symbols = self.symbols.new_scope(); - let old_symbols = std::mem::replace(&mut self.symbols, new_symbols); self.lower_block(body); self.builder.br(cond_blk); - self.level_loop_context(); - self.symbols = old_symbols; self.builder.switch_to_block(after_blk); } @@ -308,12 +300,13 @@ impl<'a> ASTLower<'a> { } fn lower_block(&mut self, block: BlockStatement) { - let new_symbols = self.symbols.new_scope(); - let old_symbols = std::mem::replace(&mut self.symbols, new_symbols); + self.symbols.enter_scope(); + for statement in block.0 { self.lower_statement(statement); } - self.symbols = old_symbols; + + self.symbols.leave_scope(); } fn lower_function_item(&mut self, fn_item: FunctionItem) -> Value { @@ -322,7 +315,7 @@ impl<'a> ASTLower<'a> { } = fn_item; let value = self.lower_function(Some(name.to_string()), params, body); - self.symbols.define(name, Variable::new(value)); + self.symbols.insert(name, Variable::new(value)); value } @@ -345,7 +338,7 @@ impl<'a> ASTLower<'a> { let mut func = IrFunction::new(func_id, func_sig); - let symbols = self.symbols.new_scope(); + let symbols = self.symbols.clone(); let mut func_builder = FunctionBuilder::new(self.builder.module_mut(), &mut func); @@ -360,7 +353,7 @@ impl<'a> ASTLower<'a> { func_lower .symbols - .define(param.name.as_str(), Variable::new(arg)); + .insert(param.name.as_str(), Variable::new(arg)); } func_lower.lower_block(body); @@ -464,7 +457,7 @@ impl<'a> ASTLower<'a> { let decl = self .type_cx - .get_type_def(&name.node()) + .get_type_def(name.node()) .expect("struct not found"); if let TypeDef::Struct(StructDef { fields: decl_fields, @@ -512,19 +505,20 @@ impl<'a> ASTLower<'a> { .collect(); match func.node { - Expression::Identifier(IdentifierExpression(ref ident)) => { - match self.builder.module().find_function(ident) { + Expression::Identifier(ident) => { + match self.builder.module().find_function(ident.name()) { Some(func) => self.builder.call_function(func.id, args), - None => match self.symbols.get(ident) { + None => match self.symbols.lookup(ident.name()) { Some(var) => self.builder.make_call(var.0, args), - None => match self.env.get(ident) { + None => match self.env.get(ident.name()) { Some(EnvVariable::Function(_)) => { - let callable = - self.builder.load_external_variable(ident.to_string()); + let callable = self + .builder + .load_external_variable(ident.name().to_string()); self.builder.make_call_native(callable, args) } _ => { - panic!("unknown identifier: {ident}"); + panic!("unknown identifier: {}", ident.name()); } }, }, @@ -624,15 +618,15 @@ impl<'a> ASTLower<'a> { } fn lower_identifier(&mut self, identifier: IdentifierExpression) -> Value { - match self.symbols.get(&identifier.0) { - Some(Variable(addr)) => addr, + match self.symbols.lookup(identifier.name()) { + Some(Variable(addr)) => *addr, None => { - if let Some(_env) = self.env.get(&identifier.0) { + if let Some(_env) = self.env.get(identifier.name()) { return self .builder - .load_external_variable(identifier.0.to_string()); + .load_external_variable(identifier.name().to_string()); } - panic!("Undefined identifier: {}", identifier.0) + panic!("Undefined identifier: {}", identifier.name()) } } } @@ -707,7 +701,7 @@ impl<'a> ASTLower<'a> { .push(LoopContext::new(break_point, continue_point)); } - fn level_loop_context(&mut self) { + fn leave_loop_context(&mut self) { self.loop_contexts.pop().expect("not in loop context"); } } @@ -721,41 +715,41 @@ impl Variable { } } -#[derive(Debug, Clone)] -pub struct SymbolNode { - parent: Option, - symbols: BTreeMap, -} - -#[derive(Debug, Clone)] -pub struct SymbolTable(Rc>); - -impl SymbolTable { - pub fn new() -> Self { - SymbolTable(Rc::new(RefCell::new(SymbolNode { - parent: None, - symbols: BTreeMap::new(), - }))) - } - - fn get(&self, name: &str) -> Option { - if let Some(value) = self.0.borrow().symbols.get(name) { - return Some(*value); - } - if let Some(parent) = &self.0.borrow().parent { - return parent.get(name); - } - None - } - - fn define(&mut self, name: impl Into, value: Variable) { - self.0.borrow_mut().symbols.insert(name.into(), value); - } - - fn new_scope(&self) -> SymbolTable { - SymbolTable(Rc::new(RefCell::new(SymbolNode { - parent: Some(self.clone()), - symbols: BTreeMap::new(), - }))) - } -} +// #[derive(Debug, Clone)] +// pub struct SymbolNode { +// parent: Option, +// symbols: BTreeMap, +// } + +// #[derive(Debug, Clone)] +// pub struct SymbolTable(Rc>); + +// impl SymbolTable { +// pub fn new() -> Self { +// SymbolTable(Rc::new(RefCell::new(SymbolNode { +// parent: None, +// symbols: BTreeMap::new(), +// }))) +// } + +// fn get(&self, name: &str) -> Option { +// if let Some(value) = self.0.borrow().symbols.get(name) { +// return Some(*value); +// } +// if let Some(parent) = &self.0.borrow().parent { +// return parent.get(name); +// } +// None +// } + +// fn insert(&mut self, name: impl Into, value: Variable) { +// self.0.borrow_mut().symbols.insert(name.into(), value); +// } + +// fn new_scope(&self) -> SymbolTable { +// SymbolTable(Rc::new(RefCell::new(SymbolNode { +// parent: Some(self.clone()), +// symbols: BTreeMap::new(), +// }))) +// } +// } diff --git a/src/compiler/parser.rs b/src/compiler/parser.rs index d0b3321..04ca663 100644 --- a/src/compiler/parser.rs +++ b/src/compiler/parser.rs @@ -8,29 +8,44 @@ use pest::{ use super::ast::syntax::*; -#[derive(Debug)] -pub struct ParseError(Box>); +#[derive(Debug, Clone)] +pub struct ParseError { + pub error: Box>, + pub span: Span, +} impl ParseError { pub fn with_message(span: pest::Span<'_>, message: impl ToString) -> Self { - Self(Box::new(pest::error::Error::new_from_span( + let error = pest::error::Error::::new_from_span( pest::error::ErrorVariant::CustomError { message: message.to_string(), }, span, - ))) + ); + ParseError { + error: Box::new(error), + span: Span::new(span.start(), span.end()), + } } } impl std::fmt::Display for ParseError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) + write!(f, "{}", self.error) } } impl From> for ParseError { fn from(e: pest::error::Error) -> Self { - Self(Box::new(e)) + let span = match e.location { + pest::error::InputLocation::Pos(pos) => Span::new(pos, pos + 1), + pest::error::InputLocation::Span(span) => Span::new(span.0, span.1), + }; + + Self { + error: Box::new(e), + span, + } } } @@ -357,7 +372,7 @@ fn parse_pattern(pair: Pair) -> Result { match pat.as_rule() { Rule::wildcard_pattern => Ok(Pattern::Wildcard), - Rule::identifier => Ok(Pattern::Identifier(pat.as_str().to_string())), + Rule::identifier => Ok(Pattern::Identifier(parse_identifier(pat)?)), Rule::literal => Ok(Pattern::Literal(parse_literal(pat)?)), Rule::tuple_pattern => { let mut tuple = Vec::new(); @@ -764,7 +779,10 @@ fn parse_identifier(pair: Pair) -> Result { let mut pairs = pair.into_inner(); let name = pairs.next().unwrap(); - Ok(IdentifierExpression(name.as_str().to_string())) + Ok(IdentifierExpression(AstNode::new( + name.as_str().to_string(), + Span::from_pair(&name), + ))) } fn unescape_string(s: &str) -> String { @@ -927,6 +945,14 @@ fn parse_path_segment(pair: Pair) -> Result { mod test { use super::*; + fn check_identifier_expression(input: &ExpressionNode, expected: &str) { + if let Expression::Identifier(identifier) = input.node() { + assert_eq!(identifier.name(), expected); + } else { + panic!("Expected Identifier expression"); + } + } + #[test] fn test_literal_expression() { let input = r#"1"#; @@ -951,34 +977,22 @@ mod test { let input = r#"foo"#; let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); - assert_eq!( - expression.node, - Expression::Identifier(IdentifierExpression("foo".to_string())) - ); + check_identifier_expression(&expression, "foo"); let input = r#"foo_bar"#; let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); - assert_eq!( - expression.node, - Expression::Identifier(IdentifierExpression("foo_bar".to_string())) - ); + check_identifier_expression(&expression, "foo_bar"); let input = r#"_foo"#; let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); - assert_eq!( - expression.node, - Expression::Identifier(IdentifierExpression("_foo".to_string())) - ); + check_identifier_expression(&expression, "_foo"); let input = r#"_"#; let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); - assert_eq!( - expression.node, - Expression::Identifier(IdentifierExpression("_".to_string())) - ); + check_identifier_expression(&expression, "_"); } #[test] @@ -1085,10 +1099,7 @@ mod test { if let Some(value) = &return_stmt.value { if let Expression::Binary(binary) = &value.node { assert_eq!(binary.op, BinOp::Add); - assert_eq!( - binary.lhs.node, - Expression::Identifier(IdentifierExpression("x".to_string())) - ); + check_identifier_expression(&binary.lhs, "x"); assert_eq!( binary.rhs.node, Expression::Literal(LiteralExpression::Integer(1)) @@ -1113,10 +1124,7 @@ mod test { let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::IndexGet(index) = expression.node { - assert_eq!( - index.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&index.object, "a"); assert_eq!( index.index.node, Expression::Literal(LiteralExpression::Integer(1)) @@ -1133,10 +1141,7 @@ mod test { let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::Slice(slice) = expression.node { - assert_eq!( - slice.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&slice.object, "a"); let range = &slice.range.node; assert_eq!(range.op, BinOp::Range); assert_eq!( @@ -1156,10 +1161,7 @@ mod test { let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::Slice(slice) = expression.node { - assert_eq!( - slice.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&slice.object, "a"); let range = &slice.range.node; assert_eq!(range.op, BinOp::RangeInclusive); assert_eq!( @@ -1179,10 +1181,7 @@ mod test { let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::Slice(slice) = expression.node { - assert_eq!( - slice.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&slice.object, "a"); let range = &slice.range.node; assert_eq!(range.op, BinOp::Range); assert!(range.begin.is_none()); @@ -1196,10 +1195,7 @@ mod test { let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::Slice(slice) = expression.node { - assert_eq!( - slice.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&slice.object, "a"); let range = &slice.range.node; assert_eq!(range.op, BinOp::Range); assert_eq!( @@ -1216,10 +1212,7 @@ mod test { let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::Slice(slice) = expression.node { - assert_eq!( - slice.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&slice.object, "a"); let range = &slice.range.node; assert_eq!(range.op, BinOp::RangeInclusive); assert!(range.begin.is_none()); @@ -1238,10 +1231,7 @@ mod test { let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::Call(call) = expression.node { - assert_eq!( - call.func.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&call.func, "a"); assert_eq!(call.args.len(), 3); assert_eq!( call.args[0].node, @@ -1263,10 +1253,7 @@ mod test { let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::Call(call) = expression.node { - assert_eq!( - call.func.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&call.func, "a"); assert_eq!(call.args.len(), 0); } else { panic!("Expected call expression"); @@ -1279,10 +1266,7 @@ mod test { let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::Try(expr) = expression.node { - assert_eq!( - expr.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&expr, "a"); } else { panic!("Expected try expression"); } @@ -1294,10 +1278,7 @@ mod test { let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::Await(expr) = expression.node { - assert_eq!( - expr.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&expr, "a"); } else { panic!("Expected await expression"); } @@ -1323,10 +1304,7 @@ mod test { let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::Prefix(prefix) = expression.node { assert_eq!(prefix.op, PrefixOp::Not); - assert_eq!( - prefix.rhs.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&prefix.rhs, "a"); } else { panic!("Expected prefix expression"); } @@ -1540,7 +1518,7 @@ mod test { // 检查第二个变体(值变体) assert_eq!( - enum_item.variants[0], + enum_item.variants[1], EnumVariant { name: "BB".to_string(), variant: Some(TypeExpression::Integer) @@ -1596,14 +1574,8 @@ mod test { if let Some(value) = &return_stmt.value { if let Expression::Binary(binary) = &value.node { assert_eq!(binary.op, BinOp::Add); - assert_eq!( - binary.lhs.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); - assert_eq!( - binary.rhs.node, - Expression::Identifier(IdentifierExpression("b".to_string())) - ); + check_identifier_expression(&binary.lhs, "a"); + check_identifier_expression(&binary.rhs, "b"); } else { panic!("Expected binary expression in return value"); } @@ -1678,10 +1650,7 @@ mod test { // 检查条件 if let Expression::Binary(binary) = while_stmt.condition.node { assert_eq!(binary.op, BinOp::Equal); - assert_eq!( - binary.lhs.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&binary.lhs, "a"); assert_eq!( binary.rhs.node, Expression::Literal(LiteralExpression::Integer(1)) @@ -1704,7 +1673,12 @@ mod test { let statement = parse_statement_input(input).unwrap(); if let Statement::For(for_stmt) = statement.node { // 检查模式 - assert_eq!(for_stmt.pat, Pattern::Identifier("i".to_string())); + + if let Pattern::Identifier(ident) = for_stmt.pat { + assert_eq!(ident.name(), "i"); + } else { + panic!("Expected identifier pattern in for statement"); + } // 检查迭代器表达式 if let Expression::Binary(binary) = for_stmt.iterable.node { @@ -1737,10 +1711,7 @@ mod test { // 检查条件 if let Expression::Binary(binary) = if_stmt.condition.node { assert_eq!(binary.op, BinOp::Equal); - assert_eq!( - binary.lhs.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&binary.lhs, "a"); assert_eq!( binary.rhs.node, Expression::Literal(LiteralExpression::Integer(1)) @@ -1777,10 +1748,7 @@ mod test { // 检查条件 if let Expression::Binary(binary) = if_stmt.condition.node { assert_eq!(binary.op, BinOp::Equal); - assert_eq!( - binary.lhs.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&binary.lhs, "a"); assert_eq!( binary.rhs.node, Expression::Literal(LiteralExpression::Integer(2)) @@ -1837,10 +1805,7 @@ mod test { // 检查左侧条件 (a > 0) if let Expression::Binary(left_binary) = binary.lhs.node { assert_eq!(left_binary.op, BinOp::Greater); - assert_eq!( - left_binary.lhs.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&left_binary.lhs, "a"); assert_eq!( left_binary.rhs.node, Expression::Literal(LiteralExpression::Integer(0)) @@ -1852,10 +1817,7 @@ mod test { // 检查右侧条件 (b < 10) if let Expression::Binary(right_binary) = binary.rhs.node { assert_eq!(right_binary.op, BinOp::Less); - assert_eq!( - right_binary.lhs.node, - Expression::Identifier(IdentifierExpression("b".to_string())) - ); + check_identifier_expression(&right_binary.lhs, "b"); assert_eq!( right_binary.rhs.node, Expression::Literal(LiteralExpression::Integer(10)) @@ -1945,10 +1907,7 @@ mod test { let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::PropertyGet(property_get) = expression.node { assert_eq!(property_get.property, "b"); - assert_eq!( - property_get.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&property_get.object, "a"); } else { panic!("Expected property get expression"); } @@ -1960,10 +1919,7 @@ mod test { assert_eq!(property_get.property, "c"); if let Expression::PropertyGet(inner_property_get) = property_get.object.node { assert_eq!(inner_property_get.property, "b"); - assert_eq!( - inner_property_get.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&inner_property_get.object, "a"); } else { panic!("Expected inner member expression"); } @@ -1979,10 +1935,7 @@ mod test { let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::PropertySet(property_set) = expression.node { assert_eq!(property_set.property, "b"); - assert_eq!( - property_set.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&property_set.object, "a"); assert_eq!( property_set.value.node, Expression::Literal(LiteralExpression::Integer(1)) @@ -2004,10 +1957,7 @@ mod test { value, }) = expression.node { - assert_eq!( - object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&object, "a"); assert_eq!( index.node, Expression::Literal(LiteralExpression::Integer(1)) @@ -2029,10 +1979,7 @@ mod test { let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::CallMethod(call_method) = expression.node { assert_eq!(call_method.method, "b"); - assert_eq!( - call_method.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&call_method.object, "a"); assert_eq!(call_method.args.len(), 2); assert_eq!( call_method.args[0].node, @@ -2055,18 +2002,12 @@ mod test { let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::PropertySet(property_set) = expression.node { assert_eq!(property_set.property, "b"); - assert_eq!( - property_set.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&property_set.object, "a"); if let Expression::Binary(binary) = property_set.value.node { assert_eq!(binary.op, BinOp::Add); if let Expression::PropertyGet(property_get) = binary.lhs.node { assert_eq!(property_get.property, "b"); - assert_eq!( - property_get.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&property_get.object, "a"); } else { panic!("Expected PropertyGet on lhs"); } @@ -2086,16 +2027,10 @@ mod test { let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::Assign(assign) = expression.node { - assert_eq!( - assign.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&assign.object, "a"); if let Expression::Binary(binary) = assign.value.node { assert_eq!(binary.op, BinOp::Add); - assert_eq!( - binary.lhs.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&binary.lhs, "a"); assert_eq!( binary.rhs.node, Expression::Literal(LiteralExpression::Integer(1)) diff --git a/src/compiler/semantic.rs b/src/compiler/semantic.rs index 4594b38..2d0f8b1 100644 --- a/src/compiler/semantic.rs +++ b/src/compiler/semantic.rs @@ -1,9 +1,47 @@ -use super::CompileError; use super::ast::syntax::*; use super::symbol::SymbolTable; use super::typing::TypeContext; use crate::Environment; +#[derive(Debug, Clone)] +pub struct SemanticError { + pub span: Span, + pub kind: ErrKind, +} + +impl SemanticError { + pub fn new(span: Span, kind: ErrKind) -> SemanticError { + SemanticError { span, kind } + } + + pub fn with_span(mut self, span: Span) -> Self { + self.span = span; + self + } +} + +impl From for SemanticError { + fn from(value: ErrKind) -> Self { + SemanticError { + span: Span::new(0, 0), + kind: value, + } + } +} + +#[derive(Debug, Clone)] +pub enum ErrKind { + UndefinedVariable(String), + BreakOutsideLoop, + ContinueOutsideLoop, +} + +impl ErrKind { + pub fn with_span(self, span: Span) -> SemanticError { + SemanticError { span, kind: self } + } +} + pub struct SemanticAnalyzer<'a> { type_cx: &'a TypeContext, loop_depth: usize, @@ -22,23 +60,23 @@ impl<'a> SemanticAnalyzer<'a> { /// 对Program进行语义检查 pub fn analyze_program( &mut self, - program: &mut Program, + program: &Program, env: &Environment, - ) -> Result<(), CompileError> { + ) -> Result<(), SemanticError> { // 第一阶段:收集环境变量 for name in env.symbols.keys() { self.symbol_table.insert(name.clone(), ()); } // 第二阶段:分析所有语句 - for stmt in &mut program.stmts { + for stmt in &program.stmts { self.analyze_statement(stmt)?; } Ok(()) } /// 分析语句并推断类型 - fn analyze_statement(&mut self, stmt: &StatementNode) -> Result<(), CompileError> { + fn analyze_statement(&mut self, stmt: &StatementNode) -> Result<(), SemanticError> { let span = stmt.span; match &stmt.node { @@ -61,16 +99,19 @@ impl<'a> SemanticAnalyzer<'a> { } } - fn analyze_let_statement(&mut self, let_stmt: &LetStatement) -> Result<(), CompileError> { - let LetStatement { name, ty, value } = let_stmt; + fn analyze_let_statement(&mut self, let_stmt: &LetStatement) -> Result<(), SemanticError> { + let LetStatement { name, value, .. } = let_stmt; self.symbol_table.insert(name, ()); + if let Some(value) = value { + self.analyze_expression(value)?; + } Ok(()) } /// 分析If语句 - fn analyze_if_statement(&mut self, if_stmt: &IfStatement) -> Result<(), CompileError> { + fn analyze_if_statement(&mut self, if_stmt: &IfStatement) -> Result<(), SemanticError> { // 分析条件表达式 self.analyze_expression(&if_stmt.condition)?; @@ -85,7 +126,7 @@ impl<'a> SemanticAnalyzer<'a> { Ok(()) } - fn analyze_loop_statement(&mut self, loop_stmt: &LoopStatement) -> Result<(), CompileError> { + fn analyze_loop_statement(&mut self, loop_stmt: &LoopStatement) -> Result<(), SemanticError> { // 增加循环深度 self.loop_depth += 1; @@ -99,7 +140,10 @@ impl<'a> SemanticAnalyzer<'a> { } /// 分析While语句 - fn analyze_while_statement(&mut self, while_stmt: &WhileStatement) -> Result<(), CompileError> { + fn analyze_while_statement( + &mut self, + while_stmt: &WhileStatement, + ) -> Result<(), SemanticError> { // 分析条件表达式 self.analyze_expression(&while_stmt.condition)?; @@ -116,7 +160,7 @@ impl<'a> SemanticAnalyzer<'a> { } /// 分析For语句 - fn analyze_for_statement(&mut self, for_stmt: &ForStatement) -> Result<(), CompileError> { + fn analyze_for_statement(&mut self, for_stmt: &ForStatement) -> Result<(), SemanticError> { // 创建新的作用域 self.symbol_table.enter_scope(); @@ -144,24 +188,37 @@ impl<'a> SemanticAnalyzer<'a> { &mut self, pattern: &Pattern, expr: &ExpressionNode, - ) -> Result<(), CompileError> { + ) -> Result<(), SemanticError> { self.analyze_expression(expr)?; + match pattern { + Pattern::Identifier(ident) => { + // 将变量添加到符号表中 + self.symbol_table.insert(ident.name(), ()); + } + Pattern::Tuple(patterns) => { + for pattern in patterns { + self.analyze_pattern(pattern, expr)?; + } + } + _ => {} + } + Ok(()) } /// 分析Break语句 - fn analyze_break_statement(&mut self, span: Span) -> Result<(), CompileError> { + fn analyze_break_statement(&mut self, span: Span) -> Result<(), SemanticError> { if self.loop_depth == 0 { - return Err(CompileError::BreakOutsideLoop { span }); + return Err(ErrKind::BreakOutsideLoop.with_span(span)); } Ok(()) } /// 分析Continue语句 - fn analyze_continue_statement(&mut self, span: Span) -> Result<(), CompileError> { + fn analyze_continue_statement(&mut self, span: Span) -> Result<(), SemanticError> { if self.loop_depth == 0 { - return Err(CompileError::ContinueOutsideLoop { span }); + return Err(ErrKind::ContinueOutsideLoop.with_span(span)); } Ok(()) } @@ -171,7 +228,7 @@ impl<'a> SemanticAnalyzer<'a> { &mut self, return_stmt: &ReturnStatement, _span: Span, - ) -> Result<(), CompileError> { + ) -> Result<(), SemanticError> { if let Some(expr) = &return_stmt.value { self.analyze_expression(expr)?; } @@ -179,7 +236,7 @@ impl<'a> SemanticAnalyzer<'a> { } /// 分析代码块 - fn analyze_block(&mut self, block: &BlockStatement) -> Result<(), CompileError> { + fn analyze_block(&mut self, block: &BlockStatement) -> Result<(), SemanticError> { // 创建新的作用域 self.symbol_table.enter_scope(); @@ -195,7 +252,7 @@ impl<'a> SemanticAnalyzer<'a> { } /// 分析函数定义 - fn analyze_function_item(&mut self, func: &FunctionItem) -> Result<(), CompileError> { + fn analyze_function_item(&mut self, func: &FunctionItem) -> Result<(), SemanticError> { // 创建新的作用域 self.symbol_table.enter_scope(); @@ -213,74 +270,81 @@ impl<'a> SemanticAnalyzer<'a> { Ok(()) } - fn analyze_struct_item(&mut self, _item: &StructItem) -> Result<(), CompileError> { + fn analyze_struct_item(&mut self, _item: &StructItem) -> Result<(), SemanticError> { Ok(()) } /// 分析表达式并推断类型 - fn analyze_expression(&mut self, expr: &ExpressionNode) -> Result<(), CompileError> { - match &expr.node { - Expression::Identifier(ident) => self.anlyze_identifier_expression(ident)?, - Expression::Binary(expr) => self.analyze_binary_expression(expr)?, - Expression::Prefix(expr) => self.analyze_prefix_expression(expr)?, - Expression::Call(expr) => self.analyze_call_expression(expr)?, - Expression::Array(expr) => self.analyze_array_expression(expr)?, - Expression::Map(expr) => self.analyze_map_expression(expr)?, - Expression::IndexGet(expr) => self.analyze_index_get_expression(expr)?, - Expression::IndexSet(expr) => self.analyze_index_set_expression(expr)?, - Expression::PropertyGet(expr) => self.analyze_property_get_expression(expr)?, - Expression::PropertySet(expr) => self.analyze_property_set_expression(expr)?, - Expression::Assign(expr) => self.analyze_assign_expression(expr)?, - Expression::Range(expr) => self.analyze_range_expression(expr)?, - Expression::Slice(expr) => self.analyze_slice_expression(expr)?, - Expression::Try(expr) => self.analyze_try_expression(expr)?, - Expression::Await(expr) => self.analyze_await_expression(expr)?, - Expression::CallMethod(expr) => self.analyze_call_method_expression(expr)?, + fn analyze_expression(&mut self, expr: &ExpressionNode) -> Result<(), SemanticError> { + let ret = match &expr.node { + Expression::Identifier(ident) => self.anlyze_identifier_expression(ident), + Expression::Binary(expr) => self.analyze_binary_expression(expr), + Expression::Prefix(expr) => self.analyze_prefix_expression(expr), + Expression::Call(expr) => self.analyze_call_expression(expr), + Expression::Array(expr) => self.analyze_array_expression(expr), + Expression::Map(expr) => self.analyze_map_expression(expr), + Expression::IndexGet(expr) => self.analyze_index_get_expression(expr), + Expression::IndexSet(expr) => self.analyze_index_set_expression(expr), + Expression::PropertyGet(expr) => self.analyze_property_get_expression(expr), + Expression::PropertySet(expr) => self.analyze_property_set_expression(expr), + Expression::Assign(expr) => self.analyze_assign_expression(expr), + Expression::Range(expr) => self.analyze_range_expression(expr), + Expression::Slice(expr) => self.analyze_slice_expression(expr), + Expression::Try(expr) => self.analyze_try_expression(expr), + Expression::Await(expr) => self.analyze_await_expression(expr), + Expression::CallMethod(expr) => self.analyze_call_method_expression(expr), _ => { // 处理其他未实现的表达式类型 + Ok(()) } }; - Ok(()) + ret.map_err(|err| { + if err.span.is_zero() { + err.with_span(expr.span) + } else { + err + } + }) } fn anlyze_identifier_expression( &mut self, ident: &IdentifierExpression, - ) -> Result<(), CompileError> { - if !self.type_cx.type_is_defined(&ident.0) { - return Err(CompileError::UndefinedVariable { - name: ident.0.clone(), - }); + ) -> Result<(), SemanticError> { + if self.symbol_table.lookup(ident.name()).is_none() + && self.type_cx.get_function_def(ident.name()).is_none() + { + return Err(ErrKind::UndefinedVariable(ident.name().to_string()).into()); } Ok(()) } - fn analyze_binary_expression(&mut self, expr: &BinaryExpression) -> Result<(), CompileError> { - let lhs_ty = self.analyze_expression(&expr.lhs)?; - let rhs_ty = self.analyze_expression(&expr.rhs)?; + fn analyze_binary_expression(&mut self, expr: &BinaryExpression) -> Result<(), SemanticError> { + self.analyze_expression(&expr.lhs)?; + self.analyze_expression(&expr.rhs)?; Ok(()) } - fn analyze_prefix_expression(&mut self, expr: &PrefixExpression) -> Result<(), CompileError> { - let rhs_ty = self.analyze_expression(&expr.rhs)?; + fn analyze_prefix_expression(&mut self, expr: &PrefixExpression) -> Result<(), SemanticError> { + self.analyze_expression(&expr.rhs)?; Ok(()) } - fn analyze_call_expression(&mut self, expr: &CallExpression) -> Result<(), CompileError> { + fn analyze_call_expression(&mut self, expr: &CallExpression) -> Result<(), SemanticError> { self.analyze_expression(&expr.func)?; for arg in expr.args.iter() { - self.analyze_expression(&arg)?; + self.analyze_expression(arg)?; } Ok(()) } - fn analyze_array_expression(&mut self, expr: &ArrayExpression) -> Result<(), CompileError> { + fn analyze_array_expression(&mut self, expr: &ArrayExpression) -> Result<(), SemanticError> { // 为每个元素创建临时变量并分析类型 for elem in expr.0.iter() { self.analyze_expression(elem)?; @@ -289,7 +353,7 @@ impl<'a> SemanticAnalyzer<'a> { Ok(()) } - fn analyze_map_expression(&mut self, expr: &MapExpression) -> Result<(), CompileError> { + fn analyze_map_expression(&mut self, expr: &MapExpression) -> Result<(), SemanticError> { for (key, value) in expr.0.iter() { self.analyze_expression(key)?; self.analyze_expression(value)?; @@ -301,7 +365,7 @@ impl<'a> SemanticAnalyzer<'a> { fn analyze_index_get_expression( &mut self, expr: &IndexGetExpression, - ) -> Result<(), CompileError> { + ) -> Result<(), SemanticError> { self.analyze_expression(&expr.object)?; self.analyze_expression(&expr.index)?; @@ -311,7 +375,7 @@ impl<'a> SemanticAnalyzer<'a> { fn analyze_index_set_expression( &mut self, expr: &IndexSetExpression, - ) -> Result<(), CompileError> { + ) -> Result<(), SemanticError> { self.analyze_expression(&expr.object)?; self.analyze_expression(&expr.value)?; @@ -321,7 +385,7 @@ impl<'a> SemanticAnalyzer<'a> { fn analyze_property_get_expression( &mut self, expr: &PropertyGetExpression, - ) -> Result<(), CompileError> { + ) -> Result<(), SemanticError> { self.analyze_expression(&expr.object)?; Ok(()) } @@ -329,21 +393,21 @@ impl<'a> SemanticAnalyzer<'a> { fn analyze_property_set_expression( &mut self, expr: &PropertySetExpression, - ) -> Result<(), CompileError> { + ) -> Result<(), SemanticError> { self.analyze_expression(&expr.object)?; self.analyze_expression(&expr.value)?; Ok(()) } - fn analyze_assign_expression(&mut self, expr: &AssignExpression) -> Result<(), CompileError> { + fn analyze_assign_expression(&mut self, expr: &AssignExpression) -> Result<(), SemanticError> { self.analyze_expression(&expr.object)?; self.analyze_expression(&expr.value)?; Ok(()) } - fn analyze_range_expression(&mut self, expr: &RangeExpression) -> Result<(), CompileError> { + fn analyze_range_expression(&mut self, expr: &RangeExpression) -> Result<(), SemanticError> { if let Some(ref begin_expr) = expr.begin { self.analyze_expression(begin_expr)?; } @@ -355,20 +419,20 @@ impl<'a> SemanticAnalyzer<'a> { Ok(()) } - fn analyze_slice_expression(&mut self, expr: &SliceExpression) -> Result<(), CompileError> { + fn analyze_slice_expression(&mut self, expr: &SliceExpression) -> Result<(), SemanticError> { self.analyze_expression(&expr.object)?; self.analyze_range_expression(&expr.range.node)?; Ok(()) } - fn analyze_try_expression(&mut self, expr: &ExpressionNode) -> Result<(), CompileError> { + fn analyze_try_expression(&mut self, expr: &ExpressionNode) -> Result<(), SemanticError> { self.analyze_expression(expr)?; Ok(()) } - fn analyze_await_expression(&mut self, expr: &ExpressionNode) -> Result<(), CompileError> { + fn analyze_await_expression(&mut self, expr: &ExpressionNode) -> Result<(), SemanticError> { self.analyze_expression(expr)?; Ok(()) @@ -377,7 +441,7 @@ impl<'a> SemanticAnalyzer<'a> { fn analyze_call_method_expression( &mut self, expr: &CallMethodExpression, - ) -> Result<(), CompileError> { + ) -> Result<(), SemanticError> { self.analyze_expression(&expr.object)?; // 分析方法参数 diff --git a/src/compiler/symbol.rs b/src/compiler/symbol.rs index 9a27f7c..8d47fa8 100644 --- a/src/compiler/symbol.rs +++ b/src/compiler/symbol.rs @@ -38,6 +38,14 @@ impl SymbolTable { } } +impl Clone for SymbolTable { + fn clone(&self) -> Self { + SymbolTable { + scopes: self.scopes.clone(), + } + } +} + #[derive(Debug)] struct Scope { variables: HashMap, @@ -50,3 +58,11 @@ impl Scope { } } } + +impl Clone for Scope { + fn clone(&self) -> Self { + Scope { + variables: self.variables.clone(), + } + } +} diff --git a/src/compiler/typing.rs b/src/compiler/typing.rs index 5781281..5e18a20 100644 --- a/src/compiler/typing.rs +++ b/src/compiler/typing.rs @@ -1,8 +1,9 @@ -use std::{collections::HashMap, default}; +use std::collections::HashMap; + use crate::{Environment, compiler::symbol::SymbolTable}; -use super::ast::{syntax::*, walker::Walker}; +use super::ast::syntax::*; #[derive(Debug, Clone)] pub struct TypeError { @@ -16,7 +17,9 @@ impl TypeError { } pub fn with_span(mut self, span: Span) -> Self { - self.span = span; + if !span.is_zero() { + self.span = span; + } self } } @@ -31,7 +34,7 @@ impl From for TypeError { } #[derive(Debug, Clone)] -enum ErrKind { +pub enum ErrKind { Message(String), UnresovledType(String), DuplicateName(String), @@ -57,6 +60,7 @@ pub enum Type { String, Array, Tuple, + Range, Enum(TypeId), Struct(TypeId), Function(Box), @@ -107,6 +111,7 @@ pub struct StructDef { pub fields: HashMap, } +#[derive(Debug, Clone)] pub struct TypeContext { type_defs: HashMap, name_to_id: HashMap, @@ -163,7 +168,7 @@ impl TypeContext { self.functions.get(name).map(|v| &**v) } - pub fn analyze_type_def(&mut self, stmts: &[StatementNode]) -> Result<(), TypeError> { + pub fn check_type_def(&mut self, stmts: &[StatementNode]) -> Result<(), TypeError> { // round 1, decl types for stmt in stmts { match &stmt.node { @@ -184,7 +189,8 @@ impl TypeContext { name: item.name.clone(), variants: Vec::new(), }), - ); + ) + .map_err(|err| err.with_span(stmt.span()))?; } _ => {} } @@ -193,12 +199,14 @@ impl TypeContext { // round 2, resolve types for stmt in stmts { match &stmt.node { - Statement::Item(ItemStatement::Function(func)) => {} + Statement::Item(ItemStatement::Function(func)) => { + self.check_function_item(func)?; + } Statement::Item(ItemStatement::Struct(item)) => { - self.analyze_struct_item(item)?; + self.check_struct_item(item)?; } Statement::Item(ItemStatement::Enum(item)) => { - self.analyze_enum_item(item)?; + self.check_enum_item(item)?; } _ => { continue; @@ -209,7 +217,7 @@ impl TypeContext { Ok(()) } - fn analyze_function_item(&mut self, item: &FunctionItem) -> Result<(), TypeError> { + fn check_function_item(&mut self, item: &FunctionItem) -> Result<(), TypeError> { let FunctionItem { name, params, @@ -230,7 +238,7 @@ impl TypeContext { let param_type = param .ty .as_ref() - .map(|ty| self.resolve_type(&ty)) + .map(|ty| self.resolve_type(ty)) .transpose()?; func.params.push((param.name.clone(), param_type)); } @@ -240,7 +248,7 @@ impl TypeContext { Ok(()) } - fn analyze_struct_item(&mut self, item: &StructItem) -> Result { + fn check_struct_item(&mut self, item: &StructItem) -> Result { let StructItem { name, fields } = item; let type_id = self.name_to_id.get(name).unwrap(); @@ -266,7 +274,7 @@ impl TypeContext { Ok(*type_id) } - fn analyze_enum_item(&mut self, item: &EnumItem) -> Result { + fn check_enum_item(&mut self, item: &EnumItem) -> Result { let EnumItem { name, variants } = item; let type_id = self.name_to_id.get(name).unwrap(); @@ -278,7 +286,7 @@ impl TypeContext { let variant_type = variant .as_ref() - .map(|ty| self.resolve_type(&ty)) + .map(|ty| self.resolve_type(ty)) .transpose()?; enum_variants.push((name.to_string(), variant_type)); @@ -374,7 +382,7 @@ impl<'a> TypeChecker<'a> { } for item in &program.stmts { - self.check_statement(&item)?; + self.check_statement(item)?; } Ok(()) } @@ -388,7 +396,7 @@ impl<'a> TypeChecker<'a> { Statement::For(for_stmt) => self.check_for_statement(for_stmt), Statement::Loop(loop_stmt) => self.check_loop_statement(loop_stmt), Statement::Return(return_stmt) => self.check_return_statement(return_stmt), - Statement::Expression(expr) => self.analyze_expression(expr).map(|_| ()), + Statement::Expression(expr) => self.check_expression(expr).map(|_| ()), Statement::Empty => Ok(()), Statement::Break => Ok(()), Statement::Continue => Ok(()), @@ -410,7 +418,7 @@ impl<'a> TypeChecker<'a> { // 新增方法:检查条件语句 fn check_if_statement(&mut self, if_stmt: &IfStatement) -> Result<(), TypeError> { - let condition_type = self.analyze_expression(&if_stmt.condition)?; + let condition_type = self.check_expression(&if_stmt.condition)?; if condition_type != Type::Boolean && condition_type != Type::Any { return Err(ErrKind::TypeMismatch { @@ -429,7 +437,7 @@ impl<'a> TypeChecker<'a> { // 新增方法:检查循环语句 fn check_while_statement(&mut self, while_stmt: &WhileStatement) -> Result<(), TypeError> { - let condition_type = self.analyze_expression(&while_stmt.condition)?; + let condition_type = self.check_expression(&while_stmt.condition)?; if condition_type != Type::Boolean && condition_type != Type::Any { return Err(ErrKind::TypeMismatch { @@ -445,11 +453,29 @@ impl<'a> TypeChecker<'a> { // 新增方法:检查 for 循环语句 fn check_for_statement(&mut self, for_stmt: &ForStatement) -> Result<(), TypeError> { - self.analyze_expression(&for_stmt.iterable)?; + self.check_expression(&for_stmt.iterable)?; + self.check_pattern(&for_stmt.pat)?; self.check_block_statement(&for_stmt.body)?; Ok(()) } + fn check_pattern(&mut self, pattern: &Pattern) -> Result<(), TypeError> { + match pattern { + Pattern::Identifier(identifier) => { + self.symbols.insert(identifier.name(), Type::Any); + Ok(()) + } + Pattern::Tuple(tuple) => { + for pattern in tuple { + self.check_pattern(pattern)?; + } + Ok(()) + } + Pattern::Wildcard => Ok(()), + Pattern::Literal(literal) => Ok(()), + } + } + // 新增方法:检查无限循环语句 fn check_loop_statement(&mut self, loop_stmt: &LoopStatement) -> Result<(), TypeError> { self.check_block_statement(&loop_stmt.body)?; @@ -459,7 +485,7 @@ impl<'a> TypeChecker<'a> { // 新增方法:检查返回语句 fn check_return_statement(&mut self, return_stmt: &ReturnStatement) -> Result<(), TypeError> { if let Some(expr) = &return_stmt.value { - let return_ty = self.analyze_expression(expr)?; + let return_ty = self.check_expression(expr)?; if let Some(expected_ty) = &self.current_function_return_type { if return_ty != *expected_ty { @@ -490,7 +516,7 @@ impl<'a> TypeChecker<'a> { for param in &func_item.params { let param_type = match param.ty.as_ref() { - Some(ty) => self.type_cx.resolve_type(&ty)?, + Some(ty) => self.type_cx.resolve_type(ty)?, None => Type::Any, }; @@ -513,78 +539,86 @@ impl<'a> TypeChecker<'a> { } fn check_let_statement(&mut self, let_stmt: &LetStatement) -> Result<(), TypeError> { - let ty = let_stmt + let decl_ty = let_stmt .ty .as_ref() - .map(|ty| self.type_cx.resolve_type(&ty)) + .map(|ty| self.type_cx.resolve_type(ty)) + .transpose()?; + + let value_ty = let_stmt + .value + .as_ref() + .map(|expr| self.check_expression(expr)) .transpose()?; - if let Some(expr) = let_stmt.value.as_ref() { - let value_type = self.analyze_expression(expr)?; - if ty.is_some() && (ty.as_ref() != Some(&value_type)) { - return Err(ErrKind::TypeMismatch { - expected: ty.unwrap(), - actual: value_type, + let ty = match (decl_ty, value_ty) { + (Some(decl_ty), Some(value_ty)) => { + if decl_ty != value_ty { + return Err(ErrKind::TypeMismatch { + expected: decl_ty, + actual: value_ty, + } + .with_span(let_stmt.value.as_ref().unwrap().span)); + } else { + decl_ty } - .with_span(expr.span)); } - } + (Some(decl_ty), None) => decl_ty, + (None, Some(value_ty)) => value_ty, + (None, None) => Type::Any, + }; + + self.symbols.insert(&let_stmt.name, ty); Ok(()) } - fn analyze_expression(&mut self, expr: &ExpressionNode) -> Result { + fn check_expression(&mut self, expr: &ExpressionNode) -> Result { let ret = match &expr.node { - Expression::Literal(lit) => self.analyze_literal(lit), - Expression::Identifier(id) => self.analyze_identifier(id), - Expression::Binary(bin) => self.analyze_binary(bin), - Expression::Prefix(prefix) => self.analyze_prefix(prefix), - Expression::Call(call) => self.analyze_call(call), + Expression::Literal(lit) => self.check_literal(lit), + Expression::Identifier(id) => self.check_identifier(id), + Expression::Binary(bin) => self.check_binary(bin), + Expression::Prefix(prefix) => self.check_prefix(prefix), + Expression::Call(call) => self.check_call(call), Expression::Environment(env) => Ok(Type::String), - Expression::Path(path) => self.analyze_path(path), - Expression::Tuple(tuple) => self.analyze_tuple(tuple), - Expression::Array(arr) => self.analyze_array(arr), + Expression::Path(path) => self.check_path(path), + Expression::Tuple(tuple) => self.check_tuple(tuple), + Expression::Array(arr) => self.check_array(arr), Expression::Map(map) => Ok(Type::Any), // 暂定Map类型为Any - Expression::Closure(closure) => self.analyze_closure(closure), + Expression::Closure(closure) => self.check_closure(closure), Expression::Range(range) => Ok(Type::Any), // 暂定Range类型为Any - Expression::Slice(slice) => self.analyze_slice(slice), - Expression::Assign(assign) => self.analyze_assign(assign), - Expression::IndexGet(index) => self.analyze_index_get(index), - Expression::IndexSet(index) => self.analyze_index_set(index), - Expression::PropertyGet(prop) => self.analyze_property_get(prop), - Expression::PropertySet(prop) => self.analyze_property_set(prop), - Expression::CallMethod(call) => self.analyze_call_method(call), - Expression::StructExpr(struct_) => self.analyze_struct_expr(struct_), - Expression::Await(expr) => self.analyze_expression(expr), - Expression::Try(expr) => self.analyze_expression(expr), + Expression::Slice(slice) => self.check_slice(slice), + Expression::Assign(assign) => self.check_assign(assign), + Expression::IndexGet(index) => self.check_index_get(index), + Expression::IndexSet(index) => self.check_index_set(index), + Expression::PropertyGet(prop) => self.check_property_get(prop), + Expression::PropertySet(prop) => self.check_property_set(prop), + Expression::CallMethod(call) => self.check_call_method(call), + Expression::StructExpr(struct_) => self.check_struct_expr(struct_), + Expression::Await(expr) => self.check_expression(expr), + Expression::Try(expr) => self.check_expression(expr), // _ => Err(ErrKind::Message(format!("Unsupported expression: {:?}", expr.node)).into()), }; - ret.map_err(|err| { - if !err.span.is_empty() { - return err.with_span(expr.span); - } else { - err - } - }) + ret.map_err(|err| err.with_span(expr.span)) } - fn analyze_path(&self, path: &PathExpression) -> Result { + fn check_path(&self, path: &PathExpression) -> Result { // 路径表达式类型解析逻辑 Ok(Type::Any) // 暂定返回Any类型 } - fn analyze_tuple(&mut self, tuple: &TupleExpression) -> Result { + fn check_tuple(&mut self, tuple: &TupleExpression) -> Result { // 元组类型解析逻辑 Ok(Type::Tuple) } - fn analyze_array(&mut self, arr: &ArrayExpression) -> Result { + fn check_array(&mut self, arr: &ArrayExpression) -> Result { // 数组类型解析逻辑 Ok(Type::Array) } - fn analyze_closure(&mut self, closure: &ClosureExpression) -> Result { + fn check_closure(&mut self, closure: &ClosureExpression) -> Result { // // 闭包类型解析逻辑 // Ok(Type::Function(Box::new(FunctionDef { // name: "".to_string(), @@ -595,63 +629,60 @@ impl<'a> TypeChecker<'a> { Ok(Type::Any) // 临时返回 Any 类型 } - fn analyze_slice(&mut self, slice: &SliceExpression) -> Result { + fn check_slice(&mut self, slice: &SliceExpression) -> Result { // 切片类型解析逻辑 Ok(Type::Array) } - fn analyze_assign(&mut self, assign: &AssignExpression) -> Result { + fn check_assign(&mut self, assign: &AssignExpression) -> Result { // 赋值表达式类型解析逻辑 - self.analyze_expression(&assign.value) + self.check_expression(&assign.value) } - fn analyze_index_get(&mut self, index: &IndexGetExpression) -> Result { + fn check_index_get(&mut self, index: &IndexGetExpression) -> Result { // 索引获取类型解析逻辑 Ok(Type::Any) // 暂定返回Any类型 } - fn analyze_index_set(&mut self, index: &IndexSetExpression) -> Result { + fn check_index_set(&mut self, index: &IndexSetExpression) -> Result { // 索引设置类型解析逻辑 - self.analyze_expression(&index.value) + self.check_expression(&index.value) } - fn analyze_property_get(&mut self, prop: &PropertyGetExpression) -> Result { + fn check_property_get(&mut self, prop: &PropertyGetExpression) -> Result { // 属性获取类型解析逻辑 Ok(Type::Any) // 暂定返回Any类型 } - fn analyze_property_set(&mut self, prop: &PropertySetExpression) -> Result { + fn check_property_set(&mut self, prop: &PropertySetExpression) -> Result { // 属性设置类型解析逻辑 - self.analyze_expression(&prop.value) + self.check_expression(&prop.value) } - fn analyze_call_method(&mut self, call: &CallMethodExpression) -> Result { + fn check_call_method(&mut self, call: &CallMethodExpression) -> Result { // 调用方法类型解析逻辑 Ok(Type::Any) // 暂定返回Any类型 } - fn analyze_struct_expr(&mut self, struct_expr: &StructExpression) -> Result { - match self.type_cx.get_type_def(&struct_expr.name.node()) { + fn check_struct_expr(&mut self, struct_expr: &StructExpression) -> Result { + match self.type_cx.get_type_def(struct_expr.name.node()) { Some(TypeDef::Struct(struct_def)) => { for field in &struct_expr.fields { - let field_type = self.analyze_expression(&field.value)?; - match struct_def.fields.get(&field.name.node) { - Some(expected_type) => { - if field_type != *expected_type { - return Err(TypeError::new( - field.value.span(), - ErrKind::Message(format!( - "Expected type {:?} for field {:?}, found {:?}", - expected_type, field.name, field_type - )), - )); - } + let field_type = self.check_expression(&field.value)?; + if let Some(expected_type) = struct_def.fields.get(&field.name.node) { + if field_type != *expected_type && field_type != Type::Any { + return Err(TypeError::new( + field.value.span(), + ErrKind::Message(format!( + "Expected type {:?} for field {:?}, found {:?}", + expected_type, field.name, field_type + )), + )); } - None => {} } } - Ok(self.type_cx.get_type(&struct_expr.name.node()).unwrap()) + Ok(self.type_cx.get_type(struct_expr.name.node()).unwrap()) } Some(ty) => Err(TypeError::new( struct_expr.name.span(), @@ -661,7 +692,7 @@ impl<'a> TypeChecker<'a> { } } - fn analyze_literal(&self, lit: &LiteralExpression) -> Result { + fn check_literal(&self, lit: &LiteralExpression) -> Result { match lit { LiteralExpression::Null => Ok(Type::Any), LiteralExpression::Boolean(_) => Ok(Type::Boolean), @@ -672,16 +703,16 @@ impl<'a> TypeChecker<'a> { } } - fn analyze_identifier(&self, id: &IdentifierExpression) -> Result { - match self.symbols.lookup(&id.0) { + fn check_identifier(&self, ident: &IdentifierExpression) -> Result { + match self.symbols.lookup(ident.name()) { Some(ty) => Ok(ty.clone()), - None => Err(ErrKind::Message(format!("Undefined identifier: {}", id.0)).into()), + None => Err(ErrKind::Message(format!("Undefined identifier: {}", ident.name())).into()), } } - fn analyze_binary(&mut self, bin: &BinaryExpression) -> Result { - let lhs_type = self.analyze_expression(&bin.lhs)?; - let rhs_type = self.analyze_expression(&bin.rhs)?; + fn check_binary(&mut self, bin: &BinaryExpression) -> Result { + let lhs_type = self.check_expression(&bin.lhs)?; + let rhs_type = self.check_expression(&bin.rhs)?; // Handle Type::Any as compatible with any type if lhs_type == Type::Any || rhs_type == Type::Any { @@ -700,7 +731,12 @@ impl<'a> TypeChecker<'a> { .into()) } } - BinOp::Equal | BinOp::NotEqual => { + BinOp::Equal + | BinOp::NotEqual + | BinOp::Less + | BinOp::LessEqual + | BinOp::Greater + | BinOp::GreaterEqual => { if lhs_type == rhs_type { Ok(Type::Boolean) } else { @@ -711,12 +747,34 @@ impl<'a> TypeChecker<'a> { .into()) } } + BinOp::LogicAnd | BinOp::LogicOr => { + if lhs_type == Type::Boolean && rhs_type == Type::Boolean { + Ok(Type::Boolean) + } else { + Err(ErrKind::Message(format!( + "Type mismatch in logical operation: {:?} and {:?}", + lhs_type, rhs_type + )) + .into()) + } + } + BinOp::Range | BinOp::RangeInclusive => { + if lhs_type == Type::Integer && rhs_type == Type::Integer { + Ok(Type::Range) + } else { + Err(ErrKind::Message(format!( + "Type mismatch in range operation: {:?} and {:?}", + lhs_type, rhs_type + )) + .into()) + } + } _ => Err(ErrKind::Message(format!("Unsupported binary operator: {:?}", bin.op)).into()), } } - fn analyze_prefix(&mut self, prefix: &PrefixExpression) -> Result { - let rhs_type = self.analyze_expression(&prefix.rhs)?; + fn check_prefix(&mut self, prefix: &PrefixExpression) -> Result { + let rhs_type = self.check_expression(&prefix.rhs)?; // Handle Type::Any as compatible with any type if rhs_type == Type::Any { @@ -749,8 +807,8 @@ impl<'a> TypeChecker<'a> { } } - fn analyze_call(&mut self, call: &CallExpression) -> Result { - let func_type = self.analyze_expression(&call.func)?; + fn check_call(&mut self, call: &CallExpression) -> Result { + let func_type = self.check_expression(&call.func)?; // Handle Type::Any as compatible with any type if func_type == Type::Any { @@ -771,7 +829,7 @@ impl<'a> TypeChecker<'a> { // Check each argument type for (i, (param_name, param_type)) in func_def.params.iter().enumerate() { - let arg_type = self.analyze_expression(&call.args[i])?; + let arg_type = self.check_expression(&call.args[i])?; if let Some(expected_type) = param_type { // Handle Type::Any as compatible with any type if *expected_type != Type::Any && arg_type != *expected_type { diff --git a/src/error.rs b/src/error.rs index bdcdd6f..d58d5cb 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,18 +1,18 @@ use crate::{compiler::CompileError, runtime::RuntimeError}; #[derive(Debug)] -pub enum Error { - Compile(CompileError), +pub enum Error<'i> { + Compile(CompileError<'i>), Runtime(RuntimeError), } -impl From for Error { - fn from(error: CompileError) -> Self { +impl<'i> From> for Error<'i> { + fn from(error: CompileError<'i>) -> Self { Error::Compile(error) } } -impl From for Error { +impl From for Error<'_> { fn from(error: RuntimeError) -> Self { Error::Runtime(error) } diff --git a/src/runtime/object/map.rs b/src/runtime/object/map.rs index d8f57f1..207cf3a 100644 --- a/src/runtime/object/map.rs +++ b/src/runtime/object/map.rs @@ -27,7 +27,9 @@ where } fn index_set(&mut self, index: &Value, value: ValueRef) -> Result<(), RuntimeError> { - if let (Some(key), Some(value)) = (index.downcast_ref::(), value.value().downcast_ref::()) { + if let (Some(key), Some(value)) = + (index.downcast_ref::(), value.value().downcast_ref::()) + { self.insert(key.clone(), value.clone()); return Ok(()); } diff --git a/tests/test_embed.rs b/tests/test_embed.rs index ae4f864..acc148e 100644 --- a/tests/test_embed.rs +++ b/tests/test_embed.rs @@ -1,10 +1,10 @@ mod utils; use std::sync::Arc; -use evalit::{Environment, Error, Module, Object, RuntimeError, VM, Value, ValueRef, compile}; +use evalit::{Environment, Module, Object, RuntimeError, VM, Value, ValueRef, compile}; use utils::init_logger; -fn run_vm(program: Arc, env: Environment) -> Result { +fn run_vm(program: Arc, env: Environment) -> Result { let mut vm = VM::new(program, env); #[cfg(not(feature = "async"))] let ret = vm.run().unwrap();