diff --git a/examples/scripting.rs b/examples/scripting.rs index 02610e5..56bef81 100644 --- a/examples/scripting.rs +++ b/examples/scripting.rs @@ -136,11 +136,9 @@ mod scripting { Ok(None) } - None => { - Err(evalit::RuntimeError::invalid_argument::( - 0, &args[0], - )) - } + None => Err(evalit::RuntimeError::invalid_argument::( + 0, &args[0], + )), } } @@ -161,11 +159,9 @@ mod scripting { Ok(None) } - _ => { - Err(evalit::RuntimeError::invalid_argument::( - 0, &args[0], - )) - } + _ => Err(evalit::RuntimeError::invalid_argument::( + 0, &args[0], + )), } } @@ -180,11 +176,9 @@ mod scripting { Ok(None) } - None => { - Err(evalit::RuntimeError::invalid_argument::( - 0, &args[0], - )) - } + None => Err(evalit::RuntimeError::invalid_argument::( + 0, &args[0], + )), } } @@ -195,9 +189,7 @@ mod scripting { Ok(Some(ValueRef::new(self.remote_addr.clone()))) } - _ => { - Err(evalit::RuntimeError::missing_method::(method)) - } + _ => Err(evalit::RuntimeError::missing_method::(method)), } } } diff --git a/src/compiler.rs b/src/compiler.rs index 868d6d6..381c930 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -5,6 +5,7 @@ mod lowering; mod parser; mod regalloc; mod semantic; +mod symbol; mod typing; use std::collections::HashMap; @@ -13,121 +14,158 @@ use std::sync::Arc; use ir::builder::{InstBuilder, IrBuilder}; use ir::instruction::IrUnit; use log::debug; -use typing::TypeContext; +use typing::{TypeChecker, TypeContext, TypeError}; use crate::Environment; use crate::bytecode::{Module, Register}; -use ast::syntax::{Span, Type}; +use crate::compiler::ast::syntax::Span; +use crate::compiler::symbol::SymbolTable; use parser::ParseError; use codegen::Codegen; -use lowering::{ASTLower, SymbolTable}; -use semantic::SemanticAnalyzer; +use lowering::ASTLower; +use semantic::{SemanticAnalyzer, SemanticError}; -pub fn compile(script: &str, env: &crate::Environment) -> Result, CompileError> { +pub fn compile<'i>( + script: &'i str, + env: &crate::Environment, +) -> Result, CompileError<'i>> { Compiler::new().compile(script, env) } -#[derive(Debug)] -pub enum CompileError { - Parse(ParseError), - Semantics(String), - UndefinedVariable { - name: String, - }, - UnknownType { - name: String, - }, - TypeMismatch { - expected: Box, - actual: Box, - span: Span, - }, - TypeInference(String), - TypeCheck(String), - ArgumentCountMismatch { - expected: usize, - actual: usize, - }, - NotCallable { - ty: Type, - span: Span, - }, - Unreachable, - BreakOutsideLoop { - span: Span, - }, - ContinueOutsideLoop { - span: Span, - }, - ReturnOutsideFunction { - span: Span, - }, - InvalidOperation { - message: String, - }, -} - -impl CompileError { - pub fn type_mismatch(expected: Type, actual: Type, span: Span) -> Self { - CompileError::TypeMismatch { +#[derive(Debug, Clone, Copy)] +struct LineCol { + line: usize, + col: usize, +} + +impl From<(usize, usize)> for LineCol { + fn from(value: (usize, usize)) -> Self { + Self { + line: value.0, + col: value.1, + } + } +} + +#[derive(Debug, Clone)] +pub struct CompileError<'i> { + kind: ErrorKind, + span: &'i str, + line_col: LineCol, +} + +impl<'i> CompileError<'i> { + pub fn new(input: &'i str, error: ErrorKind) -> Self { + let (line_col, span) = match error.line_col(input) { + Some(line_col) => { + let span = error.span(); + (line_col, &input[span.start..span.end]) + } + None => (LineCol { line: 0, col: 0 }, &input[0..0]), + }; + Self { + kind: error, span, - expected: Box::new(expected), - actual: Box::new(actual), + line_col, + } + } +} + +impl<'i> std::fmt::Display for CompileError<'i> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Error on {}@{:?}, detail: {:?}", + self.span, self.line_col, self.kind + ) + } +} + +impl<'i> std::error::Error for CompileError<'i> {} + +#[derive(Debug, Clone)] +pub enum ErrorKind { + // Io(std::io::Error), + Parse(ParseError), + Type(TypeError), + Semantic(SemanticError), +} + +impl ErrorKind { + fn line_col(&self, input: &str) -> Option { + match self { + ErrorKind::Parse(ParseError { span, .. }) => span.line_col(input).map(Into::into), + ErrorKind::Type(TypeError { span, .. }) => span.line_col(input).map(Into::into), + ErrorKind::Semantic(SemanticError { span, .. }) => span.line_col(input).map(Into::into), + } + } + + fn span(&self) -> Span { + match self { + ErrorKind::Parse(ParseError { span, .. }) => *span, + ErrorKind::Type(TypeError { span, .. }) => *span, + ErrorKind::Semantic(SemanticError { span, .. }) => *span, } } } -impl From for CompileError { +impl From for ErrorKind { fn from(error: ParseError) -> Self { - CompileError::Parse(error) + ErrorKind::Parse(error) } } -impl std::fmt::Display for CompileError { +impl From for ErrorKind { + fn from(error: TypeError) -> Self { + ErrorKind::Type(error) + } +} + +impl From for ErrorKind { + fn from(error: SemanticError) -> Self { + ErrorKind::Semantic(error) + } +} + +impl std::fmt::Display for ErrorKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - CompileError::Parse(error) => write!(f, "Parse error: {error}"), - CompileError::Semantics(message) => write!(f, "Semantics error: {message}"), - CompileError::UndefinedVariable { name } => { - write!(f, "Undefined variable `{name}`") - } - CompileError::UnknownType { name } => { - write!(f, "Unknow type `{name}`") - } - CompileError::TypeMismatch { - expected, - actual, - span, - } => write!( - f, - "Type mismatch: expected `{expected:?}`, actual `{actual:?}` at {span:?}" - ), - CompileError::TypeInference(message) => write!(f, "Type inference error: {message}"), - CompileError::TypeCheck(message) => write!(f, "Type check error: {message}"), - CompileError::ArgumentCountMismatch { expected, actual } => write!( - f, - "Argument count mismatch: expected {expected}, actual {actual}" - ), - CompileError::NotCallable { ty, span } => { - write!(f, "Not callable: `{ty:?}` at {span:?}") - } - CompileError::Unreachable => write!(f, "Unreachable"), - CompileError::BreakOutsideLoop { span } => write!(f, "Break outside loop at {span:?}"), - CompileError::ContinueOutsideLoop { span } => { - write!(f, "Continue outside loop at {span:?}") - } - CompileError::ReturnOutsideFunction { span } => { - write!(f, "Return outside function at {span:?}") - } - CompileError::InvalidOperation { message } => { - write!(f, "Invalid operation, {message}") - } + ErrorKind::Parse(error) => write!(f, "Parse error: {error}"), + ErrorKind::Type(error) => write!(f, "Type error: {error:?}"), + ErrorKind::Semantic(error) => write!(f, "Semantic error: {error:?}"), } } } -impl std::error::Error for CompileError {} +pub struct FileId(usize); + +pub struct Context { + sources: Vec, +} + +impl Context { + pub fn new() -> Self { + Self { sources: vec![] } + } + + pub fn add_source(&mut self, source: String) -> FileId { + let id = FileId(self.sources.len()); + self.sources.push(source); + id + } + + // pub fn add_file(&mut self, file: impl AsRef) -> Result { + // let id = FileId(self.sources.len()); + // let content = std::fs::read_to_string(file.as_ref())?; + // self.sources.push(content); + // Ok(id) + // } + + pub fn get_source(&self, file: FileId) -> Option<&str> { + self.sources.get(file.0).map(|s| s.as_str()) + } +} pub struct Compiler {} @@ -142,24 +180,39 @@ impl Compiler { Self {} } - pub fn compile(&self, input: &str, env: &Environment) -> Result, CompileError> { + pub fn compile<'i>( + &self, + input: &'i str, + env: &Environment, + ) -> Result, CompileError<'i>> { + match self.compile_inner(input, env) { + Ok(module) => Ok(module), + Err(err) => Err(CompileError::new(input, err)), + } + } + + fn compile_inner(&self, input: &str, env: &Environment) -> Result, ErrorKind> { // 解析输入 - let mut ast = parser::parse_file(input)?; + let ast = parser::parse_file(input)?; debug!("AST: {ast:?}"); let mut type_cx = TypeContext::new(); - type_cx.process_env(env); + type_cx.check_type_def(&ast.stmts)?; // 语义分析 - let mut analyzer = SemanticAnalyzer::new(&mut type_cx); - analyzer.analyze_program(&mut ast, env)?; + let mut analyzer = SemanticAnalyzer::new(&type_cx); + analyzer.analyze_program(&ast, env)?; + + // 类型检查 + let mut checker = TypeChecker::new(&type_cx); + checker.check_program(&ast, env)?; // IR生成, AST -> IR let mut unit = IrUnit::new(); let builder: &mut dyn InstBuilder = &mut IrBuilder::new(&mut unit); let mut lower = ASTLower::new(builder, SymbolTable::new(), env, &type_cx); - lower.lower_program(ast)?; + lower.lower_program(ast); // code generation, IR -> bytecode let mut codegen = Codegen::new(&Register::general()); diff --git a/src/compiler/ast/grammar.pest b/src/compiler/ast/grammar.pest index 758ae72..166a501 100644 --- a/src/compiler/ast/grammar.pest +++ b/src/compiler/ast/grammar.pest @@ -47,11 +47,8 @@ enum_item = { "enum" ~ identifier ~ "{" ~ (enum_item_list ~ ","?)? ~ "}" } enum_item_list = { enum_field ~ ("," ~ enum_field)* } -enum_field = { simple_enum_field | tuple_enum_field } +enum_field = { identifier ~ ("(" ~ type_expression ~ ")")?} -simple_enum_field = { identifier ~ !"(" } - -tuple_enum_field = { identifier ~ "(" ~ type_expression ~ ("," ~ type_expression)* ~ ")" } /// Expression Statment expression_statement = { expression ~ ";" } diff --git a/src/compiler/ast/syntax.rs b/src/compiler/ast/syntax.rs index 4ca6355..94bc5aa 100644 --- a/src/compiler/ast/syntax.rs +++ b/src/compiler/ast/syntax.rs @@ -1,22 +1,18 @@ use std::{ - collections::HashMap, fmt, io::{self, Error, ErrorKind}, str::FromStr, }; -use pest::{RuleType, iterators::Pair}; - #[derive(Debug, Clone, PartialEq)] pub struct AstNode { pub node: T, pub span: Span, - pub ty: Type, } impl AstNode { - pub fn new(node: T, span: Span, ty: Type) -> AstNode { - AstNode { span, node, ty } + pub fn new(node: T, span: Span) -> AstNode { + AstNode { span, node } } pub fn span(&self) -> Span { @@ -26,10 +22,6 @@ impl AstNode { pub fn node(&self) -> &T { &self.node } - - pub fn ty(&self) -> &Type { - &self.ty - } } impl AsRef for AstNode { @@ -44,7 +36,7 @@ impl AsMut for AstNode { } } -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Default)] pub struct Span { pub start: usize, pub end: usize, @@ -55,78 +47,39 @@ impl Span { Span { start, end } } - pub fn from_pair(pair: &Pair) -> Span { - Span { - start: pair.as_span().start(), - end: pair.as_span().end(), - } - } - pub fn merge(&self, other: &Span) -> Span { Span { start: self.start.min(other.start), end: self.end.max(other.end), } } -} - -#[derive(Debug, Clone, PartialEq)] -pub enum Type { - Any, - Boolean, - Byte, - Integer, - Float, - Char, - String, - Range, - Tuple(Vec), - Array(Box), - Map(Box), - Unknown, - UserDefined(String), - Decl(Declaration), -} - -impl Type { - pub fn is_boolean(&self) -> bool { - matches!(self, Type::Boolean) - } - pub fn is_numeric(&self) -> bool { - matches!(self, Type::Byte | Type::Integer | Type::Float) + pub fn is_zero(&self) -> bool { + self.start == 0 && self.start == self.end } - pub fn is_string(&self) -> bool { - matches!(self, Type::String) + pub fn is_empty(&self) -> bool { + self.start == self.end } - pub fn is_unknown(&self) -> bool { - matches!(self, Type::Unknown) - } - - pub fn is_collection(&self) -> bool { - matches!(self, Type::Array(_) | Type::Map(_)) - } - - pub fn is_any(&self) -> bool { - matches!(self, Type::Any) - } - - pub fn get_array_element_type(&self) -> Option<&Type> { - if let Type::Array(ty) = self { - Some(ty.as_ref()) - } else { - None + pub fn line_col(&self, source: &str) -> Option<(usize, usize)> { + let mut offset = 0; + for (l, line) in source.lines().enumerate() { + if self.start < offset + line.len() { + let off = self.start - offset; + let mut line_off = 0; + for (c, char) in line.char_indices() { + if line_off == off { + return Some((l + 1, c)); + } + line_off += char.len_utf8(); + } + } + + offset += line.len(); } - } - pub fn get_map_value_type(&self) -> Option<&Type> { - if let Type::Map(ty) = self { - Some(ty.as_ref()) - } else { - None - } + None } } @@ -167,7 +120,7 @@ pub struct BlockStatement(pub Vec); pub enum ItemStatement { Enum(EnumItem), Struct(StructItem), - Fn(FunctionItem), + Function(FunctionItem), } #[derive(Debug, Clone, PartialEq)] @@ -177,9 +130,9 @@ pub struct EnumItem { } #[derive(Debug, Clone, PartialEq)] -pub enum EnumVariant { - Simple(String), - Tuple(String, Vec), +pub struct EnumVariant { + pub name: String, + pub variant: Option, } #[derive(Debug, Clone, PartialEq)] @@ -256,8 +209,8 @@ pub enum TypeExpression { String, Array(Box), Tuple(Vec), - Generic(String, Vec), UserDefined(String), + Generic(String, Vec), Impl(Box), } @@ -334,7 +287,7 @@ pub struct SliceExpression { #[derive(Debug, Clone, PartialEq)] pub enum Pattern { Wildcard, - Identifier(String), + Identifier(IdentifierExpression), Literal(LiteralExpression), Tuple(Vec), } @@ -463,7 +416,17 @@ pub struct PostfixExpression { } #[derive(Debug, Clone, PartialEq)] -pub struct IdentifierExpression(pub String); +pub struct IdentifierExpression(pub AstNode); + +impl IdentifierExpression { + pub fn new(name: String, span: Span) -> Self { + IdentifierExpression(AstNode::new(name, span)) + } + + pub fn name(&self) -> &str { + self.0.node().as_str() + } +} #[derive(Debug, Clone, PartialEq)] pub enum LiteralExpression { @@ -533,44 +496,12 @@ pub struct CallMethodExpression { #[derive(Debug, Clone, PartialEq)] pub struct StructExpression { - pub name: String, + pub name: AstNode, pub fields: Vec, } #[derive(Debug, Clone, PartialEq)] pub struct StructExprField { - pub name: String, + pub name: AstNode, pub value: ExpressionNode, } - -#[derive(Debug, Clone, PartialEq)] -pub enum Declaration { - Function(FunctionDeclaration), - Struct(StructDeclaration), - Enum(EnumDeclaration), -} - -impl Declaration { - pub fn is_function(&self) -> bool { - matches!(self, Declaration::Function(_)) - } -} - -#[derive(Debug, Clone, PartialEq)] -pub struct FunctionDeclaration { - pub name: String, - pub params: Vec<(String, Option)>, - pub return_type: Option>, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct StructDeclaration { - pub name: String, - pub fields: HashMap, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct EnumDeclaration { - pub name: String, - pub variants: Vec<(String, Option)>, -} diff --git a/src/compiler/lowering.rs b/src/compiler/lowering.rs index a4f514b..cecac06 100644 --- a/src/compiler/lowering.rs +++ b/src/compiler/lowering.rs @@ -1,10 +1,10 @@ use std::collections::HashMap; -use std::{cell::RefCell, collections::BTreeMap, rc::Rc}; -use super::CompileError; use super::ast::syntax::*; use super::ir::{builder::*, instruction::*}; -use super::typing::TypeContext; +use super::typing::{TypeContext, TypeDef}; +use crate::compiler::symbol::SymbolTable; +use crate::compiler::typing::{FunctionDef, StructDef}; use crate::{ Environment, bytecode::{FunctionId, Opcode, Primitive}, @@ -28,7 +28,7 @@ impl LoopContext { pub struct ASTLower<'a> { builder: &'a mut dyn InstBuilder, env: &'a Environment, - symbols: SymbolTable, + symbols: SymbolTable, loop_contexts: Vec, type_cx: &'a TypeContext, } @@ -36,7 +36,7 @@ pub struct ASTLower<'a> { impl<'a> ASTLower<'a> { pub fn new( builder: &'a mut dyn InstBuilder, - symbols: SymbolTable, + symbols: SymbolTable, env: &'a Environment, type_cx: &'a TypeContext, ) -> Self { @@ -49,7 +49,7 @@ impl<'a> ASTLower<'a> { } } - pub fn lower_program(&mut self, prog: Program) -> Result { + pub fn lower_program(&mut self, prog: Program) -> IrUnit { let mut unit = IrUnit::new(); let builder: &mut dyn InstBuilder = &mut IrBuilder::new(&mut unit); @@ -58,10 +58,8 @@ impl<'a> ASTLower<'a> { builder.switch_to_block(entry); // declare functions - for decl in self.type_cx.function_decls() { - if let Declaration::Function(func) = decl { - self.declare_function(func); - } + for func in self.type_cx.functions() { + self.declare_function(func); } let entry = self.create_block("main"); @@ -71,7 +69,7 @@ impl<'a> ASTLower<'a> { // split program into statements and items for stmt in prog.stmts { match stmt.node { - Statement::Item(ItemStatement::Fn(func)) => { + Statement::Item(ItemStatement::Function(func)) => { self.lower_function_item(func); } Statement::Item(_) => { @@ -87,7 +85,7 @@ impl<'a> ASTLower<'a> { // FIXME: This is a hack to make block not empty. self.builder.make_halt(); - Ok(unit) + unit } fn lower_statement(&mut self, statement: StatementNode) { @@ -140,12 +138,12 @@ impl<'a> ASTLower<'a> { self.builder.assign(dst, value); } - self.symbols.define(&name, Variable::new(dst)); + self.symbols.insert(&name, Variable::new(dst)); } fn lower_item_stmt(&mut self, item: ItemStatement) { match item { - ItemStatement::Fn(fn_item) => { + ItemStatement::Function(fn_item) => { self.lower_function_item(fn_item); } _ => unimplemented!("statement {:?}", item), @@ -187,7 +185,7 @@ impl<'a> ASTLower<'a> { Pattern::Identifier(ident) => { let dst = self.builder.alloc(); self.builder.assign(dst, value); - self.symbols.define(&ident, Variable::new(dst)); + self.symbols.insert(ident.name(), Variable::new(dst)); } Pattern::Tuple(pats) => { @@ -233,19 +231,19 @@ impl<'a> ASTLower<'a> { // loop body, get next value self.builder.switch_to_block(loop_body); - let new_symbols = self.symbols.new_scope(); - let old_symbols = std::mem::replace(&mut self.symbols, new_symbols); + + self.symbols.enter_scope(); let next = self.builder.call_property(next, "unwrap", vec![]); self.lower_pattern(pat, next); self.lower_block(body); - self.symbols = old_symbols; + self.symbols.leave_scope(); self.builder.br(loop_header); // done loop - self.level_loop_context(); + self.leave_loop_context(); self.builder.switch_to_block(after_blk); } @@ -259,16 +257,13 @@ impl<'a> ASTLower<'a> { self.builder.br(loop_body); self.builder.switch_to_block(loop_body); - let new_symbols = self.symbols.new_scope(); - let old_symbols = std::mem::replace(&mut self.symbols, new_symbols); self.lower_block(body); self.builder.br(loop_body); // done loop - self.level_loop_context(); - self.symbols = old_symbols; + self.leave_loop_context(); self.builder.switch_to_block(after_blk); } @@ -288,15 +283,11 @@ impl<'a> ASTLower<'a> { self.builder.br_if(cond, body_blk, after_blk); self.builder.switch_to_block(body_blk); - let new_symbols = self.symbols.new_scope(); - let old_symbols = std::mem::replace(&mut self.symbols, new_symbols); self.lower_block(body); self.builder.br(cond_blk); - self.level_loop_context(); - self.symbols = old_symbols; self.builder.switch_to_block(after_blk); } @@ -309,12 +300,13 @@ impl<'a> ASTLower<'a> { } fn lower_block(&mut self, block: BlockStatement) { - let new_symbols = self.symbols.new_scope(); - let old_symbols = std::mem::replace(&mut self.symbols, new_symbols); + self.symbols.enter_scope(); + for statement in block.0 { self.lower_statement(statement); } - self.symbols = old_symbols; + + self.symbols.leave_scope(); } fn lower_function_item(&mut self, fn_item: FunctionItem) -> Value { @@ -323,7 +315,7 @@ impl<'a> ASTLower<'a> { } = fn_item; let value = self.lower_function(Some(name.to_string()), params, body); - self.symbols.define(name, Variable::new(value)); + self.symbols.insert(name, Variable::new(value)); value } @@ -346,7 +338,7 @@ impl<'a> ASTLower<'a> { let mut func = IrFunction::new(func_id, func_sig); - let symbols = self.symbols.new_scope(); + let symbols = self.symbols.clone(); let mut func_builder = FunctionBuilder::new(self.builder.module_mut(), &mut func); @@ -361,7 +353,7 @@ impl<'a> ASTLower<'a> { func_lower .symbols - .define(param.name.as_str(), Variable::new(arg)); + .insert(param.name.as_str(), Variable::new(arg)); } func_lower.lower_block(body); @@ -458,11 +450,16 @@ impl<'a> ASTLower<'a> { let mut field_map = fields .into_iter() - .map(|StructExprField { name, value }| (name, self.lower_expression(value))) + .map(|StructExprField { name, value }| { + (name.node().clone(), self.lower_expression(value)) + }) .collect::>(); - let decl = self.type_cx.get_type_decl(&name).expect("struct not found"); - if let Declaration::Struct(StructDeclaration { + let decl = self + .type_cx + .get_type_def(name.node()) + .expect("struct not found"); + if let TypeDef::Struct(StructDef { fields: decl_fields, .. }) = decl @@ -508,19 +505,20 @@ impl<'a> ASTLower<'a> { .collect(); match func.node { - Expression::Identifier(IdentifierExpression(ref ident)) => { - match self.builder.module().find_function(ident) { + Expression::Identifier(ident) => { + match self.builder.module().find_function(ident.name()) { Some(func) => self.builder.call_function(func.id, args), - None => match self.symbols.get(ident) { + None => match self.symbols.lookup(ident.name()) { Some(var) => self.builder.make_call(var.0, args), - None => match self.env.get(ident) { + None => match self.env.get(ident.name()) { Some(EnvVariable::Function(_)) => { - let callable = - self.builder.load_external_variable(ident.to_string()); + let callable = self + .builder + .load_external_variable(ident.name().to_string()); self.builder.make_call_native(callable, args) } _ => { - panic!("unknown identifier: {ident}"); + panic!("unknown identifier: {}", ident.name()); } }, }, @@ -620,15 +618,15 @@ impl<'a> ASTLower<'a> { } fn lower_identifier(&mut self, identifier: IdentifierExpression) -> Value { - match self.symbols.get(&identifier.0) { - Some(Variable(addr)) => addr, + match self.symbols.lookup(identifier.name()) { + Some(Variable(addr)) => *addr, None => { - if let Some(_env) = self.env.get(&identifier.0) { + if let Some(_env) = self.env.get(identifier.name()) { return self .builder - .load_external_variable(identifier.0.to_string()); + .load_external_variable(identifier.name().to_string()); } - panic!("Undefined identifier: {}", identifier.0) + panic!("Undefined identifier: {}", identifier.name()) } } } @@ -677,8 +675,8 @@ impl<'a> ASTLower<'a> { } } - fn declare_function(&mut self, func: &FunctionDeclaration) -> FunctionId { - let FunctionDeclaration { name, params, .. } = func; + fn declare_function(&mut self, func: &FunctionDef) -> FunctionId { + let FunctionDef { name, params, .. } = func; let func_sig = FuncSignature::new( name.clone(), @@ -703,7 +701,7 @@ impl<'a> ASTLower<'a> { .push(LoopContext::new(break_point, continue_point)); } - fn level_loop_context(&mut self) { + fn leave_loop_context(&mut self) { self.loop_contexts.pop().expect("not in loop context"); } } @@ -717,41 +715,41 @@ impl Variable { } } -#[derive(Debug, Clone)] -pub struct SymbolNode { - parent: Option, - symbols: BTreeMap, -} - -#[derive(Debug, Clone)] -pub struct SymbolTable(Rc>); - -impl SymbolTable { - pub fn new() -> Self { - SymbolTable(Rc::new(RefCell::new(SymbolNode { - parent: None, - symbols: BTreeMap::new(), - }))) - } - - fn get(&self, name: &str) -> Option { - if let Some(value) = self.0.borrow().symbols.get(name) { - return Some(*value); - } - if let Some(parent) = &self.0.borrow().parent { - return parent.get(name); - } - None - } - - fn define(&mut self, name: impl Into, value: Variable) { - self.0.borrow_mut().symbols.insert(name.into(), value); - } - - fn new_scope(&self) -> SymbolTable { - SymbolTable(Rc::new(RefCell::new(SymbolNode { - parent: Some(self.clone()), - symbols: BTreeMap::new(), - }))) - } -} +// #[derive(Debug, Clone)] +// pub struct SymbolNode { +// parent: Option, +// symbols: BTreeMap, +// } + +// #[derive(Debug, Clone)] +// pub struct SymbolTable(Rc>); + +// impl SymbolTable { +// pub fn new() -> Self { +// SymbolTable(Rc::new(RefCell::new(SymbolNode { +// parent: None, +// symbols: BTreeMap::new(), +// }))) +// } + +// fn get(&self, name: &str) -> Option { +// if let Some(value) = self.0.borrow().symbols.get(name) { +// return Some(*value); +// } +// if let Some(parent) = &self.0.borrow().parent { +// return parent.get(name); +// } +// None +// } + +// fn insert(&mut self, name: impl Into, value: Variable) { +// self.0.borrow_mut().symbols.insert(name.into(), value); +// } + +// fn new_scope(&self) -> SymbolTable { +// SymbolTable(Rc::new(RefCell::new(SymbolNode { +// parent: Some(self.clone()), +// symbols: BTreeMap::new(), +// }))) +// } +// } diff --git a/src/compiler/parser.rs b/src/compiler/parser.rs index 0b7cac3..04ca663 100644 --- a/src/compiler/parser.rs +++ b/src/compiler/parser.rs @@ -8,29 +8,44 @@ use pest::{ use super::ast::syntax::*; -#[derive(Debug)] -pub struct ParseError(Box>); +#[derive(Debug, Clone)] +pub struct ParseError { + pub error: Box>, + pub span: Span, +} impl ParseError { pub fn with_message(span: pest::Span<'_>, message: impl ToString) -> Self { - Self(Box::new(pest::error::Error::new_from_span( + let error = pest::error::Error::::new_from_span( pest::error::ErrorVariant::CustomError { message: message.to_string(), }, span, - ))) + ); + ParseError { + error: Box::new(error), + span: Span::new(span.start(), span.end()), + } } } impl std::fmt::Display for ParseError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) + write!(f, "{}", self.error) } } impl From> for ParseError { fn from(e: pest::error::Error) -> Self { - Self(Box::new(e)) + let span = match e.location { + pest::error::InputLocation::Pos(pos) => Span::new(pos, pos + 1), + pest::error::InputLocation::Span(span) => Span::new(span.0, span.1), + }; + + Self { + error: Box::new(e), + span, + } } } @@ -38,6 +53,12 @@ impl std::error::Error for ParseError {} type Result = std::result::Result; +impl Span { + fn from_pair(pair: &Pair) -> Self { + Span::new(pair.as_span().start(), pair.as_span().end()) + } +} + #[derive(pest_derive::Parser)] #[grammar = "compiler/ast/grammar.pest"] struct PestParser; @@ -164,7 +185,7 @@ fn parse_statement(pair: Pair) -> Result { _ => unreachable!("unknown statement: {pair:?}"), }; - Ok(AstNode::new(stmt, span, Type::Unknown)) + Ok(AstNode::new(stmt, span)) } fn parse_item_statement(pair: Pair) -> Result { @@ -173,7 +194,7 @@ fn parse_item_statement(pair: Pair) -> Result { match stat.as_rule() { Rule::enum_item => Ok(ItemStatement::Enum(parse_enum_item(stat))), Rule::struct_item => Ok(ItemStatement::Struct(parse_struct_item(stat))), - Rule::fn_item => parse_function_item(stat).map(ItemStatement::Fn), + Rule::fn_item => parse_function_item(stat).map(ItemStatement::Function), _ => unreachable!("unknown item statement: {stat:?}"), } } @@ -191,21 +212,15 @@ fn parse_enum_item(pair: Pair) -> EnumItem { } fn parse_enum_field(pair: Pair) -> EnumVariant { - let field = pair.into_inner().next().unwrap(); - match field.as_rule() { - Rule::simple_enum_field => EnumVariant::Simple(field.as_str().to_string()), - Rule::tuple_enum_field => parse_tuple_enum_field(field), - _ => unreachable!("{field:?}"), - } -} - -fn parse_tuple_enum_field(pair: Pair) -> EnumVariant { let mut pairs = pair.into_inner(); - let name = pairs.next().unwrap().as_str().to_string(); - let tuple = pairs.map(|item| parse_type_expression(item)).collect(); + let name = pairs.next().unwrap(); + let ty = pairs.next().map(|item| parse_type_expression(item)); - EnumVariant::Tuple(name, tuple) + EnumVariant { + name: name.as_str().to_string(), + variant: ty, + } } fn parse_struct_item(pair: Pair) -> StructItem { @@ -327,9 +342,8 @@ fn parse_if_statement(pair: Pair) -> Result { let span = Span::from_pair(&pair); match pair.as_rule() { Rule::block => parse_block(pair), - Rule::if_statement => parse_if_statement(pair).map(|item| { - BlockStatement(vec![AstNode::new(Statement::If(item), span, Type::Unknown)]) - }), + Rule::if_statement => parse_if_statement(pair) + .map(|item| BlockStatement(vec![AstNode::new(Statement::If(item), span)])), _ => unreachable!("unknown else_branch: {:?}", pair), } }) @@ -358,7 +372,7 @@ fn parse_pattern(pair: Pair) -> Result { match pat.as_rule() { Rule::wildcard_pattern => Ok(Pattern::Wildcard), - Rule::identifier => Ok(Pattern::Identifier(pat.as_str().to_string())), + Rule::identifier => Ok(Pattern::Identifier(parse_identifier(pat)?)), Rule::literal => Ok(Pattern::Literal(parse_literal(pat)?)), Rule::tuple_pattern => { let mut tuple = Vec::new(); @@ -451,7 +465,7 @@ fn parse_prefix(op: Pair, rhs: Result) -> Result, op: Pair) -> Result, op: Pair) -> Result, op: Pair) -> Result, op: Pair) -> Result unreachable!("unknown postfix: {:?}", op), }; - Ok(AstNode::new(expr, span, Type::Unknown)) + Ok(AstNode::new(expr, span)) } fn parse_struct_expression(pair: Pair) -> Result { @@ -693,6 +707,7 @@ fn parse_struct_expression(pair: Pair) -> Result { let mut pairs = pair.into_inner(); let name = pairs.next().unwrap(); + let name_span = Span::from_pair(&name); assert_eq!(name.as_rule(), Rule::identifier); let name = name.as_str().to_string(); @@ -703,15 +718,11 @@ fn parse_struct_expression(pair: Pair) -> Result { .collect::>>()?; let expr = StructExpression { - name: name.clone(), + name: AstNode::new(name, name_span), fields, }; - Ok(AstNode::new( - Expression::StructExpr(expr), - span, - Type::Unknown, - )) + Ok(AstNode::new(Expression::StructExpr(expr), span)) } fn parse_struct_expression_field(pair: Pair) -> Result { @@ -721,8 +732,7 @@ fn parse_struct_expression_field(pair: Pair) -> Result { let name = pairs.next().unwrap(); assert_eq!(name.as_rule(), Rule::identifier); - let name = name.as_str().to_string(); - + let name = AstNode::new(name.as_str().to_string(), Span::from_pair(&name)); let value = parse_expression(pairs.next().unwrap())?; Ok(StructExprField { name, value }) @@ -744,7 +754,7 @@ fn parse_atom(pair: Pair) -> Result { _ => unreachable!("unknown atom: {:?}", pair), }?; - Ok(AstNode::new(expr, span, Type::Unknown)) + Ok(AstNode::new(expr, span)) } fn parse_closure(pair: Pair) -> Result { @@ -769,7 +779,10 @@ fn parse_identifier(pair: Pair) -> Result { let mut pairs = pair.into_inner(); let name = pairs.next().unwrap(); - Ok(IdentifierExpression(name.as_str().to_string())) + Ok(IdentifierExpression(AstNode::new( + name.as_str().to_string(), + Span::from_pair(&name), + ))) } fn unescape_string(s: &str) -> String { @@ -932,6 +945,14 @@ fn parse_path_segment(pair: Pair) -> Result { mod test { use super::*; + fn check_identifier_expression(input: &ExpressionNode, expected: &str) { + if let Expression::Identifier(identifier) = input.node() { + assert_eq!(identifier.name(), expected); + } else { + panic!("Expected Identifier expression"); + } + } + #[test] fn test_literal_expression() { let input = r#"1"#; @@ -956,34 +977,22 @@ mod test { let input = r#"foo"#; let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); - assert_eq!( - expression.node, - Expression::Identifier(IdentifierExpression("foo".to_string())) - ); + check_identifier_expression(&expression, "foo"); let input = r#"foo_bar"#; let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); - assert_eq!( - expression.node, - Expression::Identifier(IdentifierExpression("foo_bar".to_string())) - ); + check_identifier_expression(&expression, "foo_bar"); let input = r#"_foo"#; let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); - assert_eq!( - expression.node, - Expression::Identifier(IdentifierExpression("_foo".to_string())) - ); + check_identifier_expression(&expression, "_foo"); let input = r#"_"#; let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); - assert_eq!( - expression.node, - Expression::Identifier(IdentifierExpression("_".to_string())) - ); + check_identifier_expression(&expression, "_"); } #[test] @@ -1090,10 +1099,7 @@ mod test { if let Some(value) = &return_stmt.value { if let Expression::Binary(binary) = &value.node { assert_eq!(binary.op, BinOp::Add); - assert_eq!( - binary.lhs.node, - Expression::Identifier(IdentifierExpression("x".to_string())) - ); + check_identifier_expression(&binary.lhs, "x"); assert_eq!( binary.rhs.node, Expression::Literal(LiteralExpression::Integer(1)) @@ -1118,10 +1124,7 @@ mod test { let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::IndexGet(index) = expression.node { - assert_eq!( - index.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&index.object, "a"); assert_eq!( index.index.node, Expression::Literal(LiteralExpression::Integer(1)) @@ -1138,10 +1141,7 @@ mod test { let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::Slice(slice) = expression.node { - assert_eq!( - slice.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&slice.object, "a"); let range = &slice.range.node; assert_eq!(range.op, BinOp::Range); assert_eq!( @@ -1161,10 +1161,7 @@ mod test { let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::Slice(slice) = expression.node { - assert_eq!( - slice.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&slice.object, "a"); let range = &slice.range.node; assert_eq!(range.op, BinOp::RangeInclusive); assert_eq!( @@ -1184,10 +1181,7 @@ mod test { let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::Slice(slice) = expression.node { - assert_eq!( - slice.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&slice.object, "a"); let range = &slice.range.node; assert_eq!(range.op, BinOp::Range); assert!(range.begin.is_none()); @@ -1201,10 +1195,7 @@ mod test { let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::Slice(slice) = expression.node { - assert_eq!( - slice.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&slice.object, "a"); let range = &slice.range.node; assert_eq!(range.op, BinOp::Range); assert_eq!( @@ -1221,10 +1212,7 @@ mod test { let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::Slice(slice) = expression.node { - assert_eq!( - slice.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&slice.object, "a"); let range = &slice.range.node; assert_eq!(range.op, BinOp::RangeInclusive); assert!(range.begin.is_none()); @@ -1243,10 +1231,7 @@ mod test { let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::Call(call) = expression.node { - assert_eq!( - call.func.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&call.func, "a"); assert_eq!(call.args.len(), 3); assert_eq!( call.args[0].node, @@ -1268,10 +1253,7 @@ mod test { let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::Call(call) = expression.node { - assert_eq!( - call.func.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&call.func, "a"); assert_eq!(call.args.len(), 0); } else { panic!("Expected call expression"); @@ -1284,10 +1266,7 @@ mod test { let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::Try(expr) = expression.node { - assert_eq!( - expr.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&expr, "a"); } else { panic!("Expected try expression"); } @@ -1299,10 +1278,7 @@ mod test { let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::Await(expr) = expression.node { - assert_eq!( - expr.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&expr, "a"); } else { panic!("Expected await expression"); } @@ -1328,10 +1304,7 @@ mod test { let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::Prefix(prefix) = expression.node { assert_eq!(prefix.op, PrefixOp::Not); - assert_eq!( - prefix.rhs.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&prefix.rhs, "a"); } else { panic!("Expected prefix expression"); } @@ -1525,7 +1498,7 @@ mod test { #[test] fn test_enum_item() { - let input = r#"enum A { AA, BB(int, float) }"#; + let input = r#"enum A { AA, BB(int) }"#; let item_statement = parse_statement_input(input).unwrap(); if let Statement::Item(ItemStatement::Enum(enum_item)) = item_statement.node { // 检查枚举名称 @@ -1535,17 +1508,22 @@ mod test { assert_eq!(enum_item.variants.len(), 2); // 检查第一个变体(简单变体) - assert_eq!(enum_item.variants[0], EnumVariant::Simple("AA".to_string())); - - // 检查第二个变体(元组变体) - if let EnumVariant::Tuple(name, types) = &enum_item.variants[1] { - assert_eq!(name, "BB"); - assert_eq!(types.len(), 2); - assert_eq!(types[0], TypeExpression::Integer); - assert_eq!(types[1], TypeExpression::Float); - } else { - panic!("Expected tuple variant"); - } + assert_eq!( + enum_item.variants[0], + EnumVariant { + name: "AA".to_string(), + variant: None + } + ); + + // 检查第二个变体(值变体) + assert_eq!( + enum_item.variants[1], + EnumVariant { + name: "BB".to_string(), + variant: Some(TypeExpression::Integer) + } + ); } else { panic!("Expected enum item statement"); } @@ -1578,7 +1556,7 @@ mod test { fn test_function_item() { let input = r#"fn A(a: int, b: float) -> int { return a + b; }"#; let item_statement = parse_statement_input(input).unwrap(); - if let Statement::Item(ItemStatement::Fn(function)) = item_statement.node { + if let Statement::Item(ItemStatement::Function(function)) = item_statement.node { // 检查函数名和参数 assert_eq!(function.name, "A"); assert_eq!(function.params.len(), 2); @@ -1596,14 +1574,8 @@ mod test { if let Some(value) = &return_stmt.value { if let Expression::Binary(binary) = &value.node { assert_eq!(binary.op, BinOp::Add); - assert_eq!( - binary.lhs.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); - assert_eq!( - binary.rhs.node, - Expression::Identifier(IdentifierExpression("b".to_string())) - ); + check_identifier_expression(&binary.lhs, "a"); + check_identifier_expression(&binary.rhs, "b"); } else { panic!("Expected binary expression in return value"); } @@ -1678,10 +1650,7 @@ mod test { // 检查条件 if let Expression::Binary(binary) = while_stmt.condition.node { assert_eq!(binary.op, BinOp::Equal); - assert_eq!( - binary.lhs.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&binary.lhs, "a"); assert_eq!( binary.rhs.node, Expression::Literal(LiteralExpression::Integer(1)) @@ -1704,7 +1673,12 @@ mod test { let statement = parse_statement_input(input).unwrap(); if let Statement::For(for_stmt) = statement.node { // 检查模式 - assert_eq!(for_stmt.pat, Pattern::Identifier("i".to_string())); + + if let Pattern::Identifier(ident) = for_stmt.pat { + assert_eq!(ident.name(), "i"); + } else { + panic!("Expected identifier pattern in for statement"); + } // 检查迭代器表达式 if let Expression::Binary(binary) = for_stmt.iterable.node { @@ -1737,10 +1711,7 @@ mod test { // 检查条件 if let Expression::Binary(binary) = if_stmt.condition.node { assert_eq!(binary.op, BinOp::Equal); - assert_eq!( - binary.lhs.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&binary.lhs, "a"); assert_eq!( binary.rhs.node, Expression::Literal(LiteralExpression::Integer(1)) @@ -1777,10 +1748,7 @@ mod test { // 检查条件 if let Expression::Binary(binary) = if_stmt.condition.node { assert_eq!(binary.op, BinOp::Equal); - assert_eq!( - binary.lhs.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&binary.lhs, "a"); assert_eq!( binary.rhs.node, Expression::Literal(LiteralExpression::Integer(2)) @@ -1837,10 +1805,7 @@ mod test { // 检查左侧条件 (a > 0) if let Expression::Binary(left_binary) = binary.lhs.node { assert_eq!(left_binary.op, BinOp::Greater); - assert_eq!( - left_binary.lhs.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&left_binary.lhs, "a"); assert_eq!( left_binary.rhs.node, Expression::Literal(LiteralExpression::Integer(0)) @@ -1852,10 +1817,7 @@ mod test { // 检查右侧条件 (b < 10) if let Expression::Binary(right_binary) = binary.rhs.node { assert_eq!(right_binary.op, BinOp::Less); - assert_eq!( - right_binary.lhs.node, - Expression::Identifier(IdentifierExpression("b".to_string())) - ); + check_identifier_expression(&right_binary.lhs, "b"); assert_eq!( right_binary.rhs.node, Expression::Literal(LiteralExpression::Integer(10)) @@ -1945,10 +1907,7 @@ mod test { let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::PropertyGet(property_get) = expression.node { assert_eq!(property_get.property, "b"); - assert_eq!( - property_get.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&property_get.object, "a"); } else { panic!("Expected property get expression"); } @@ -1960,10 +1919,7 @@ mod test { assert_eq!(property_get.property, "c"); if let Expression::PropertyGet(inner_property_get) = property_get.object.node { assert_eq!(inner_property_get.property, "b"); - assert_eq!( - inner_property_get.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&inner_property_get.object, "a"); } else { panic!("Expected inner member expression"); } @@ -1979,10 +1935,7 @@ mod test { let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::PropertySet(property_set) = expression.node { assert_eq!(property_set.property, "b"); - assert_eq!( - property_set.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&property_set.object, "a"); assert_eq!( property_set.value.node, Expression::Literal(LiteralExpression::Integer(1)) @@ -2004,10 +1957,7 @@ mod test { value, }) = expression.node { - assert_eq!( - object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&object, "a"); assert_eq!( index.node, Expression::Literal(LiteralExpression::Integer(1)) @@ -2029,10 +1979,7 @@ mod test { let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::CallMethod(call_method) = expression.node { assert_eq!(call_method.method, "b"); - assert_eq!( - call_method.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&call_method.object, "a"); assert_eq!(call_method.args.len(), 2); assert_eq!( call_method.args[0].node, @@ -2055,18 +2002,12 @@ mod test { let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::PropertySet(property_set) = expression.node { assert_eq!(property_set.property, "b"); - assert_eq!( - property_set.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&property_set.object, "a"); if let Expression::Binary(binary) = property_set.value.node { assert_eq!(binary.op, BinOp::Add); if let Expression::PropertyGet(property_get) = binary.lhs.node { assert_eq!(property_get.property, "b"); - assert_eq!( - property_get.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&property_get.object, "a"); } else { panic!("Expected PropertyGet on lhs"); } @@ -2086,16 +2027,10 @@ mod test { let pairs = PestParser::parse(Rule::expression, input).unwrap(); let expression = parse_expression_pairs(pairs).unwrap(); if let Expression::Assign(assign) = expression.node { - assert_eq!( - assign.object.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&assign.object, "a"); if let Expression::Binary(binary) = assign.value.node { assert_eq!(binary.op, BinOp::Add); - assert_eq!( - binary.lhs.node, - Expression::Identifier(IdentifierExpression("a".to_string())) - ); + check_identifier_expression(&binary.lhs, "a"); assert_eq!( binary.rhs.node, Expression::Literal(LiteralExpression::Integer(1)) diff --git a/src/compiler/semantic.rs b/src/compiler/semantic.rs index 8d69be5..2d0f8b1 100644 --- a/src/compiler/semantic.rs +++ b/src/compiler/semantic.rs @@ -1,51 +1,85 @@ -use super::CompileError; use super::ast::syntax::*; +use super::symbol::SymbolTable; use super::typing::TypeContext; use crate::Environment; +#[derive(Debug, Clone)] +pub struct SemanticError { + pub span: Span, + pub kind: ErrKind, +} + +impl SemanticError { + pub fn new(span: Span, kind: ErrKind) -> SemanticError { + SemanticError { span, kind } + } + + pub fn with_span(mut self, span: Span) -> Self { + self.span = span; + self + } +} + +impl From for SemanticError { + fn from(value: ErrKind) -> Self { + SemanticError { + span: Span::new(0, 0), + kind: value, + } + } +} + +#[derive(Debug, Clone)] +pub enum ErrKind { + UndefinedVariable(String), + BreakOutsideLoop, + ContinueOutsideLoop, +} + +impl ErrKind { + pub fn with_span(self, span: Span) -> SemanticError { + SemanticError { span, kind: self } + } +} + pub struct SemanticAnalyzer<'a> { - type_cx: &'a mut TypeContext, - // 当前函数的返回类型,用于检查return语句 - current_function_return_type: Option, - // 循环嵌套计数,用于检查break和continue语句 + type_cx: &'a TypeContext, loop_depth: usize, + symbol_table: SymbolTable<()>, } impl<'a> SemanticAnalyzer<'a> { - pub fn new(type_cx: &'a mut TypeContext) -> Self { + pub fn new(type_cx: &'a TypeContext) -> Self { SemanticAnalyzer { type_cx, - current_function_return_type: None, loop_depth: 0, + symbol_table: SymbolTable::new(), } } /// 对Program进行语义检查 pub fn analyze_program( &mut self, - program: &mut Program, + program: &Program, env: &Environment, - ) -> Result<(), CompileError> { + ) -> Result<(), SemanticError> { // 第一阶段:收集环境变量 for name in env.symbols.keys() { - self.type_cx.set_type(name.to_string(), Type::Any); + self.symbol_table.insert(name.clone(), ()); } - // 第二阶段:收集声明 - self.type_cx.analyze_type_decl(&program.stmts); - - // 第三阶段:分析所有语句 - for stmt in &mut program.stmts { + // 第二阶段:分析所有语句 + for stmt in &program.stmts { self.analyze_statement(stmt)?; } Ok(()) } /// 分析语句并推断类型 - fn analyze_statement(&mut self, stmt: &mut StatementNode) -> Result<(), CompileError> { + fn analyze_statement(&mut self, stmt: &StatementNode) -> Result<(), SemanticError> { let span = stmt.span; - match &mut stmt.node { + match &stmt.node { Statement::Expression(expr) => self.analyze_expression(expr).map(|_| ()), Statement::Let(let_stmt) => self.analyze_let_statement(let_stmt), Statement::If(if_stmt) => self.analyze_if_statement(if_stmt), @@ -57,7 +91,7 @@ impl<'a> SemanticAnalyzer<'a> { Statement::Break => self.analyze_break_statement(span), Statement::Continue => self.analyze_continue_statement(span), Statement::Empty => Ok(()), - Statement::Item(ItemStatement::Fn(func)) => self.analyze_function_item(func), + Statement::Item(ItemStatement::Function(func)) => self.analyze_function_item(func), Statement::Item(ItemStatement::Struct(item)) => self.analyze_struct_item(item), Statement::Item(ItemStatement::Enum(EnumItem { .. })) => { unimplemented!("EnumItem not implemented") @@ -65,84 +99,39 @@ impl<'a> SemanticAnalyzer<'a> { } } - fn analyze_let_statement(&mut self, let_stmt: &mut LetStatement) -> Result<(), CompileError> { - let LetStatement { name, ty, value } = let_stmt; - - let decl_ty = match ty { - Some(type_expr) => { - let decl_type = self.type_from_type_expr(type_expr)?; - Some(decl_type) - } - None => None, - }; - - let value_ty = match value { - Some(value) => { - let value_ty = self.analyze_expression(value)?; - Some(value_ty) - } - None => None, - }; + fn analyze_let_statement(&mut self, let_stmt: &LetStatement) -> Result<(), SemanticError> { + let LetStatement { name, value, .. } = let_stmt; - match (decl_ty, value_ty) { - (Some(decl_type), Some(value_ty)) => { - if decl_type != value_ty { - return Err(CompileError::type_mismatch( - decl_type, - value_ty, - value.clone().unwrap().span(), - )); - } - self.type_cx.set_type(name.clone(), decl_type); - } - (Some(decl_type), None) => { - self.type_cx.set_type(name.clone(), decl_type); - } - (None, Some(value_ty)) => { - self.type_cx.set_type(name.clone(), value_ty); - } - (None, None) => { - self.type_cx.set_type(name.clone(), Type::Any); - } + self.symbol_table.insert(name, ()); + if let Some(value) = value { + self.analyze_expression(value)?; } Ok(()) } /// 分析If语句 - fn analyze_if_statement(&mut self, if_stmt: &mut IfStatement) -> Result<(), CompileError> { + fn analyze_if_statement(&mut self, if_stmt: &IfStatement) -> Result<(), SemanticError> { // 分析条件表达式 - let condition_ty = self.analyze_expression(&mut if_stmt.condition)?; - - // TODO: 条件必须是布尔类型 - if condition_ty != Type::Boolean && condition_ty != Type::Any { - return Err(CompileError::type_mismatch( - Type::Boolean, - condition_ty.clone(), - if_stmt.condition.span, - )); - } + self.analyze_expression(&if_stmt.condition)?; // 分析then分支 - self.analyze_block(&mut if_stmt.then_branch)?; + self.analyze_block(&if_stmt.then_branch)?; // 分析else分支(如果有) - if let Some(else_branch) = &mut if_stmt.else_branch { + if let Some(else_branch) = &if_stmt.else_branch { self.analyze_block(else_branch)?; } Ok(()) } - fn analyze_loop_statement( - &mut self, - loop_stmt: &mut LoopStatement, - ) -> Result<(), CompileError> { + fn analyze_loop_statement(&mut self, loop_stmt: &LoopStatement) -> Result<(), SemanticError> { // 增加循环深度 self.loop_depth += 1; // 分析循环体 - self.analyze_block(&mut loop_stmt.body)?; + self.analyze_block(&loop_stmt.body)?; // 减少循环深度 self.loop_depth -= 1; @@ -153,25 +142,16 @@ impl<'a> SemanticAnalyzer<'a> { /// 分析While语句 fn analyze_while_statement( &mut self, - while_stmt: &mut WhileStatement, - ) -> Result<(), CompileError> { + while_stmt: &WhileStatement, + ) -> Result<(), SemanticError> { // 分析条件表达式 - let condition_ty = self.analyze_expression(&mut while_stmt.condition)?; - - // 条件必须是布尔类型 - if !condition_ty.is_boolean() { - return Err(CompileError::type_mismatch( - Type::Boolean, - condition_ty.clone(), - while_stmt.condition.span, - )); - } + self.analyze_expression(&while_stmt.condition)?; // 增加循环深度 self.loop_depth += 1; // 分析循环体 - self.analyze_block(&mut while_stmt.body)?; + self.analyze_block(&while_stmt.body)?; // 减少循环深度 self.loop_depth -= 1; @@ -180,66 +160,65 @@ impl<'a> SemanticAnalyzer<'a> { } /// 分析For语句 - fn analyze_for_statement(&mut self, for_stmt: &mut ForStatement) -> Result<(), CompileError> { + fn analyze_for_statement(&mut self, for_stmt: &ForStatement) -> Result<(), SemanticError> { // 创建新的作用域 - let old_env = std::mem::replace(self.type_cx, self.type_cx.clone()); + self.symbol_table.enter_scope(); // 分析迭代器表达式 // self.analyze_expression(&mut for_stmt.iterable)?; - self.analyze_pattern(&mut for_stmt.pat, &mut for_stmt.iterable)?; + self.analyze_pattern(&for_stmt.pat, &for_stmt.iterable)?; // 增加循环深度 self.loop_depth += 1; // 分析循环体 - self.analyze_block(&mut for_stmt.body)?; + self.analyze_block(&for_stmt.body)?; // 减少循环深度 self.loop_depth -= 1; // 恢复作用域 - let _ = std::mem::replace(self.type_cx, old_env); + self.symbol_table.leave_scope(); Ok(()) } fn analyze_pattern( &mut self, - pattern: &mut Pattern, - expr: &mut ExpressionNode, - ) -> Result<(), CompileError> { + pattern: &Pattern, + expr: &ExpressionNode, + ) -> Result<(), SemanticError> { self.analyze_expression(expr)?; - // 检查模式是否匹配表达式 match pattern { - Pattern::Wildcard => {} - Pattern::Identifier(id) => { - self.type_cx.set_type(id.clone(), Type::Any); + Pattern::Identifier(ident) => { + // 将变量添加到符号表中 + self.symbol_table.insert(ident.name(), ()); } - Pattern::Tuple(tuple) => { - for pat in tuple.iter_mut() { - self.analyze_pattern(pat, expr)?; + Pattern::Tuple(patterns) => { + for pattern in patterns { + self.analyze_pattern(pattern, expr)?; } } - Pattern::Literal(_literal) => {} + _ => {} } Ok(()) } /// 分析Break语句 - fn analyze_break_statement(&mut self, span: Span) -> Result<(), CompileError> { + fn analyze_break_statement(&mut self, span: Span) -> Result<(), SemanticError> { if self.loop_depth == 0 { - return Err(CompileError::BreakOutsideLoop { span }); + return Err(ErrKind::BreakOutsideLoop.with_span(span)); } Ok(()) } /// 分析Continue语句 - fn analyze_continue_statement(&mut self, span: Span) -> Result<(), CompileError> { + fn analyze_continue_statement(&mut self, span: Span) -> Result<(), SemanticError> { if self.loop_depth == 0 { - return Err(CompileError::ContinueOutsideLoop { span }); + return Err(ErrKind::ContinueOutsideLoop.with_span(span)); } Ok(()) } @@ -247,489 +226,229 @@ impl<'a> SemanticAnalyzer<'a> { /// 分析Return语句 fn analyze_return_statement( &mut self, - return_stmt: &mut ReturnStatement, + return_stmt: &ReturnStatement, _span: Span, - ) -> Result<(), CompileError> { - if let Some(expr) = &mut return_stmt.value { + ) -> Result<(), SemanticError> { + if let Some(expr) = &return_stmt.value { self.analyze_expression(expr)?; } Ok(()) } /// 分析代码块 - fn analyze_block(&mut self, block: &mut BlockStatement) -> Result<(), CompileError> { + fn analyze_block(&mut self, block: &BlockStatement) -> Result<(), SemanticError> { // 创建新的作用域 - let old_env = std::mem::replace(self.type_cx, self.type_cx.clone()); + self.symbol_table.enter_scope(); // 分析块中的所有语句 - for stmt in &mut block.0 { + for stmt in &block.0 { self.analyze_statement(stmt)?; } // 恢复作用域 - let _ = std::mem::replace(self.type_cx, old_env); + self.symbol_table.leave_scope(); Ok(()) } /// 分析函数定义 - fn analyze_function_item(&mut self, func: &mut FunctionItem) -> Result<(), CompileError> { + fn analyze_function_item(&mut self, func: &FunctionItem) -> Result<(), SemanticError> { // 创建新的作用域 - let old_env = std::mem::replace(self.type_cx, self.type_cx.clone()); - - // 获取函数返回类型 - let return_ty = if let Some(ty_expr) = &func.return_ty { - self.type_from_type_expr(ty_expr)? - } else { - Type::Any - }; + self.symbol_table.enter_scope(); // 添加参数到环境 for param in &func.params { - let param_ty = match ¶m.ty { - Some(ty) => self.type_from_type_expr(ty)?, - None => Type::Any, - }; - - self.type_cx.set_type(param.name.clone(), param_ty); + self.symbol_table.insert(¶m.name, ()); } - // 保存当前函数的返回类型,用于检查return语句 - let old_return_type = self.current_function_return_type.replace(return_ty.clone()); - // 分析函数体 - self.analyze_block(&mut func.body)?; + self.analyze_block(&func.body)?; // 恢复环境 - let _ = std::mem::replace(self.type_cx, old_env); - self.current_function_return_type = old_return_type; + self.symbol_table.leave_scope(); Ok(()) } - fn analyze_struct_item(&mut self, _item: &mut StructItem) -> Result<(), CompileError> { + fn analyze_struct_item(&mut self, _item: &StructItem) -> Result<(), SemanticError> { Ok(()) } /// 分析表达式并推断类型 - fn analyze_expression(&mut self, expr: &mut ExpressionNode) -> Result { - let ty = match &mut expr.node { - Expression::Literal(lit) => self.analyze_literal_expression(lit)?, - Expression::Identifier(ident) => self.anlyze_identifier_expression(ident)?, - Expression::Binary(expr) => self.analyze_binary_expression(expr)?, - Expression::Prefix(expr) => self.analyze_prefix_expression(expr)?, - Expression::Call(expr) => self.analyze_call_expression(expr)?, - Expression::Array(expr) => self.analyze_array_expression(expr)?, - Expression::Map(expr) => self.analyze_map_expression(expr)?, - Expression::IndexGet(expr) => self.analyze_index_get_expression(expr)?, - Expression::IndexSet(expr) => self.analyze_index_set_expression(expr)?, - Expression::PropertyGet(expr) => self.analyze_property_get_expression(expr)?, - Expression::PropertySet(expr) => self.analyze_property_set_expression(expr)?, - Expression::Assign(expr) => self.analyze_assign_expression(expr)?, - Expression::Range(expr) => self.analyze_range_expression(expr)?, - Expression::Slice(expr) => self.analyze_slice_expression(expr)?, - Expression::Try(expr) => self.analyze_try_expression(expr)?, - Expression::Await(expr) => self.analyze_await_expression(expr)?, - Expression::CallMethod(expr) => self.analyze_call_method_expression(expr)?, + fn analyze_expression(&mut self, expr: &ExpressionNode) -> Result<(), SemanticError> { + let ret = match &expr.node { + Expression::Identifier(ident) => self.anlyze_identifier_expression(ident), + Expression::Binary(expr) => self.analyze_binary_expression(expr), + Expression::Prefix(expr) => self.analyze_prefix_expression(expr), + Expression::Call(expr) => self.analyze_call_expression(expr), + Expression::Array(expr) => self.analyze_array_expression(expr), + Expression::Map(expr) => self.analyze_map_expression(expr), + Expression::IndexGet(expr) => self.analyze_index_get_expression(expr), + Expression::IndexSet(expr) => self.analyze_index_set_expression(expr), + Expression::PropertyGet(expr) => self.analyze_property_get_expression(expr), + Expression::PropertySet(expr) => self.analyze_property_set_expression(expr), + Expression::Assign(expr) => self.analyze_assign_expression(expr), + Expression::Range(expr) => self.analyze_range_expression(expr), + Expression::Slice(expr) => self.analyze_slice_expression(expr), + Expression::Try(expr) => self.analyze_try_expression(expr), + Expression::Await(expr) => self.analyze_await_expression(expr), + Expression::CallMethod(expr) => self.analyze_call_method_expression(expr), _ => { // 处理其他未实现的表达式类型 - Type::Any + Ok(()) } }; - expr.ty = ty.clone(); - - Ok(ty) - } - - fn analyze_literal_expression( - &mut self, - lit: &mut LiteralExpression, - ) -> Result { - Ok(match lit { - LiteralExpression::Null => Type::Any, - LiteralExpression::Boolean(_) => Type::Boolean, - LiteralExpression::Integer(_) => Type::Integer, - LiteralExpression::Float(_) => Type::Float, - LiteralExpression::Char(_) => Type::Char, - LiteralExpression::String(_) => Type::String, + ret.map_err(|err| { + if err.span.is_zero() { + err.with_span(expr.span) + } else { + err + } }) } fn anlyze_identifier_expression( &mut self, - ident: &mut IdentifierExpression, - ) -> Result { - if let Some(ty) = self.type_cx.get_type(&ident.0) { - return Ok(ty.clone()); + ident: &IdentifierExpression, + ) -> Result<(), SemanticError> { + if self.symbol_table.lookup(ident.name()).is_none() + && self.type_cx.get_function_def(ident.name()).is_none() + { + return Err(ErrKind::UndefinedVariable(ident.name().to_string()).into()); } - Err(CompileError::UndefinedVariable { - name: ident.0.clone(), - }) + Ok(()) } - fn analyze_binary_expression( - &mut self, - expr: &mut BinaryExpression, - ) -> Result { - let lhs_ty = self.analyze_expression(expr.lhs.as_mut())?; - let rhs_ty = self.analyze_expression(expr.rhs.as_mut())?; - - if lhs_ty == Type::Any || rhs_ty == Type::Any { - return Ok(Type::Any); // Object类型可以和任何类型比较 - } + fn analyze_binary_expression(&mut self, expr: &BinaryExpression) -> Result<(), SemanticError> { + self.analyze_expression(&expr.lhs)?; + self.analyze_expression(&expr.rhs)?; - match expr.op { - BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div | BinOp::Rem => { - if (!lhs_ty.is_numeric() || rhs_ty != lhs_ty) - && lhs_ty != Type::String - && lhs_ty != Type::Char - { - return Err(CompileError::type_mismatch( - Type::Integer, - lhs_ty, - expr.lhs.span(), - )); - } - } - BinOp::LogicAnd | BinOp::LogicOr => { - if lhs_ty != Type::Boolean || rhs_ty != Type::Boolean { - return Err(CompileError::type_mismatch( - Type::Boolean, - lhs_ty, - expr.lhs.span(), - )); - } - } - _ => {} - } - - Ok(match expr.op { - BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div | BinOp::Rem => lhs_ty, - BinOp::LogicAnd | BinOp::LogicOr => Type::Boolean, - BinOp::Less - | BinOp::LessEqual - | BinOp::Greater - | BinOp::GreaterEqual - | BinOp::Equal - | BinOp::NotEqual => Type::Boolean, - _ => Type::Any, - }) + Ok(()) } - fn analyze_prefix_expression( - &mut self, - expr: &mut PrefixExpression, - ) -> Result { - let rhs_ty = self.analyze_expression(expr.rhs.as_mut())?; - - match expr.op { - PrefixOp::Neg => { - if !(rhs_ty.is_numeric() || rhs_ty == Type::Any) { - return Err(CompileError::type_mismatch( - Type::Integer, - rhs_ty, - expr.rhs.span(), - )); - } - } - PrefixOp::Not => { - if !(rhs_ty.is_boolean() || rhs_ty == Type::Any) { - return Err(CompileError::type_mismatch( - Type::Boolean, - rhs_ty, - expr.rhs.span(), - )); - } - } - } + fn analyze_prefix_expression(&mut self, expr: &PrefixExpression) -> Result<(), SemanticError> { + self.analyze_expression(&expr.rhs)?; - Ok(match expr.op { - PrefixOp::Neg => rhs_ty, - PrefixOp::Not => Type::Boolean, - }) + Ok(()) } - fn analyze_call_expression(&mut self, expr: &mut CallExpression) -> Result { - let func_ty = self.analyze_expression(expr.func.as_mut())?; + fn analyze_call_expression(&mut self, expr: &CallExpression) -> Result<(), SemanticError> { + self.analyze_expression(&expr.func)?; - if func_ty == Type::Any { - return Ok(Type::Any); + for arg in expr.args.iter() { + self.analyze_expression(arg)?; } - if let Type::Decl(Declaration::Function(FunctionDeclaration { - name, - params, - return_type, - })) = &func_ty - { - if params.len() != expr.args.len() { - return Err(CompileError::ArgumentCountMismatch { - expected: params.len(), - actual: expr.args.len(), - }); - } - - for (arg, param_ty) in expr.args.iter_mut().zip(params.iter()) { - let arg_ty = self.analyze_expression(arg)?; - if let Some(ty) = ¶m_ty.1 { - if arg_ty != *ty && arg_ty != Type::Any { - return Err(CompileError::type_mismatch(ty.clone(), arg_ty, arg.span())); - } - } - } - - Ok(return_type - .as_ref() - .map(|t| *t.clone()) - .unwrap_or(Type::Any)) - } else { - Err(CompileError::NotCallable { - ty: func_ty, - span: expr.func.span(), - }) - } + Ok(()) } - fn analyze_array_expression( - &mut self, - expr: &mut ArrayExpression, - ) -> Result { - if expr.0.is_empty() { - return Ok(Type::Array(Box::new(Type::Any))); - } - - let mut elem_types = Vec::with_capacity(expr.0.len()); - + fn analyze_array_expression(&mut self, expr: &ArrayExpression) -> Result<(), SemanticError> { // 为每个元素创建临时变量并分析类型 - for elem in expr.0.iter_mut() { - let elem_ty = self.analyze_expression(elem)?; - elem_types.push(elem_ty); - } - - // 检查数组元素是否类型一致 - let first_ty = &elem_types[0]; - for (i, elem) in expr.0.iter().enumerate() { - if &elem_types[i] != first_ty && elem_types[i] != Type::Any { - return Err(CompileError::type_mismatch( - first_ty.clone(), - elem_types[i].clone(), - elem.span(), - )); - } + for elem in expr.0.iter() { + self.analyze_expression(elem)?; } - Ok(Type::Array(Box::new(first_ty.clone()))) + Ok(()) } - fn analyze_map_expression(&mut self, _expr: &mut MapExpression) -> Result { - Ok(Type::Map(Box::new(Type::Any))) + fn analyze_map_expression(&mut self, expr: &MapExpression) -> Result<(), SemanticError> { + for (key, value) in expr.0.iter() { + self.analyze_expression(key)?; + self.analyze_expression(value)?; + } + + Ok(()) } fn analyze_index_get_expression( &mut self, - expr: &mut IndexGetExpression, - ) -> Result { - let object_ty = self.analyze_expression(expr.object.as_mut())?; - let _index_ty = self.analyze_expression(expr.index.as_mut())?; - - // 根据集合类型确定返回类型 - match object_ty { - Type::Array(ty) => Ok(*ty.clone()), - Type::Map(value_ty) => Ok(*value_ty.clone()), - Type::Any => Ok(Type::Any), - _ => { - // 其他不支持的集合类型 - Err(CompileError::InvalidOperation { - message: format!("Cannot index non-array and non-map type({object_ty:?})"), - }) - } - } + expr: &IndexGetExpression, + ) -> Result<(), SemanticError> { + self.analyze_expression(&expr.object)?; + self.analyze_expression(&expr.index)?; + + Ok(()) } fn analyze_index_set_expression( &mut self, - expr: &mut IndexSetExpression, - ) -> Result { - let object_ty = self.analyze_expression(expr.object.as_mut())?; - let value_ty = self.analyze_expression(expr.value.as_mut())?; - // index 不需要分析 - - match object_ty { - Type::Array(ty) => { - if *ty != value_ty && value_ty != Type::Any { - return Err(CompileError::type_mismatch( - *ty, - value_ty, - expr.value.span(), - )); - } - } - Type::Map(_) | Type::Any => { - // 对于Map,Any对象,允许设置任何属性 - } - _ => { - // 其他不支持的集合类型 - return Err(CompileError::InvalidOperation { - message: "Cannot index non-array and non-map type".to_string(), - }); - } - } + expr: &IndexSetExpression, + ) -> Result<(), SemanticError> { + self.analyze_expression(&expr.object)?; + self.analyze_expression(&expr.value)?; - Ok(Type::Any) + Ok(()) } fn analyze_property_get_expression( &mut self, - expr: &mut PropertyGetExpression, - ) -> Result { - let object_ty = self.analyze_expression(expr.object.as_mut())?; - - match object_ty { - Type::Map(value_ty) if *value_ty == Type::String => Ok(*value_ty), - Type::Any => { - // 对于通用对象,允许访问任何属性 - Ok(Type::Any) - } - _ => { - // 其他不支持的集合类型 - Err(CompileError::InvalidOperation { - message: "Cannot access property on non-object type".to_string(), - }) - } - } + expr: &PropertyGetExpression, + ) -> Result<(), SemanticError> { + self.analyze_expression(&expr.object)?; + Ok(()) } fn analyze_property_set_expression( &mut self, - expr: &mut PropertySetExpression, - ) -> Result { - let object_ty = self.analyze_expression(expr.object.as_mut())?; - let _property_ty = self.analyze_expression(expr.value.as_mut())?; - - match object_ty { - Type::Map(value_ty) if *value_ty == Type::String => Ok(Type::Any), - Type::Any => Ok(Type::Any), // 对于Any对象,允许设置任何属性 - _ => { - // 其他不支持的集合类型 - Err(CompileError::InvalidOperation { - message: "Cannot access property on non-object type".to_string(), - }) - } - } + expr: &PropertySetExpression, + ) -> Result<(), SemanticError> { + self.analyze_expression(&expr.object)?; + self.analyze_expression(&expr.value)?; + + Ok(()) } - fn analyze_assign_expression( - &mut self, - expr: &mut AssignExpression, - ) -> Result { - let object_ty = self.analyze_expression(expr.object.as_mut())?; - let value_ty = self.analyze_expression(expr.value.as_mut())?; - - // 检查赋值左右类型是否兼容 - if object_ty != Type::Any && value_ty != Type::Any && object_ty != value_ty { - return Err(CompileError::type_mismatch( - object_ty, - value_ty, - expr.value.span(), - )); - } + fn analyze_assign_expression(&mut self, expr: &AssignExpression) -> Result<(), SemanticError> { + self.analyze_expression(&expr.object)?; + self.analyze_expression(&expr.value)?; - Ok(object_ty) + Ok(()) } - fn analyze_range_expression( - &mut self, - expr: &mut RangeExpression, - ) -> Result { - if let Some(ref mut begin_expr) = expr.begin { - let begin_ty = self.analyze_expression(begin_expr)?; - if begin_ty != Type::Integer && begin_ty != Type::Any { - return Err(CompileError::type_mismatch( - Type::Integer, - begin_ty, - begin_expr.span(), - )); - } + fn analyze_range_expression(&mut self, expr: &RangeExpression) -> Result<(), SemanticError> { + if let Some(ref begin_expr) = expr.begin { + self.analyze_expression(begin_expr)?; } - if let Some(ref mut end_expr) = expr.end { - let end_ty = self.analyze_expression(end_expr)?; - if end_ty != Type::Integer && end_ty != Type::Any { - return Err(CompileError::type_mismatch( - Type::Integer, - end_ty, - end_expr.span(), - )); - } + if let Some(ref end_expr) = expr.end { + self.analyze_expression(end_expr)? } - Ok(Type::Range) + Ok(()) } - fn analyze_slice_expression( - &mut self, - expr: &mut SliceExpression, - ) -> Result { - let object_ty = self.analyze_expression(expr.object.as_mut())?; - self.analyze_range_expression(&mut expr.range.node)?; + fn analyze_slice_expression(&mut self, expr: &SliceExpression) -> Result<(), SemanticError> { + self.analyze_expression(&expr.object)?; + self.analyze_range_expression(&expr.range.node)?; - // 切片表达式的类型与原始对象类型相同 - Ok(object_ty) + Ok(()) } - fn analyze_try_expression(&mut self, expr: &mut ExpressionNode) -> Result { - let expr_ty = self.analyze_expression(expr)?; - if expr_ty != Type::Any { - return Err(CompileError::InvalidOperation { - message: "Cannot try non-result type".to_string(), - }); - } + fn analyze_try_expression(&mut self, expr: &ExpressionNode) -> Result<(), SemanticError> { + self.analyze_expression(expr)?; - // Try表达式的类型与内部表达式类型相同 - Ok(expr_ty) + Ok(()) } - fn analyze_await_expression( - &mut self, - expr: &mut ExpressionNode, - ) -> Result { - let expr_ty = self.analyze_expression(expr)?; - if expr_ty != Type::Any { - return Err(CompileError::InvalidOperation { - message: "Cannot await non-promise type".to_string(), - }); - } + fn analyze_await_expression(&mut self, expr: &ExpressionNode) -> Result<(), SemanticError> { + self.analyze_expression(expr)?; - // Await表达式的类型与内部表达式类型相同 - Ok(expr_ty) + Ok(()) } fn analyze_call_method_expression( &mut self, - expr: &mut CallMethodExpression, - ) -> Result { - let object_ty = self.analyze_expression(expr.object.as_mut())?; + expr: &CallMethodExpression, + ) -> Result<(), SemanticError> { + self.analyze_expression(&expr.object)?; // 分析方法参数 - for arg in &mut expr.args { + for arg in &expr.args { self.analyze_expression(arg)?; } - // 方法调用的结果类型取决于方法实现,这里简单处理为Any类型 - Ok(Type::Any) - } - - /// 从类型表达式转换为Type - fn type_from_type_expr(&mut self, type_expr: &TypeExpression) -> Result { - let ty = self.type_cx.resolve_type_decl(type_expr); - if ty == Type::Unknown { - Err(CompileError::UnknownType { - name: format!("{type_expr:?}"), - }) - } else { - Ok(ty) - } + Ok(()) } } diff --git a/src/compiler/symbol.rs b/src/compiler/symbol.rs new file mode 100644 index 0000000..8d47fa8 --- /dev/null +++ b/src/compiler/symbol.rs @@ -0,0 +1,68 @@ +use std::collections::HashMap; + +#[derive(Debug)] +pub struct SymbolTable { + scopes: Vec>, +} + +impl SymbolTable { + pub fn new() -> Self { + SymbolTable:: { + scopes: vec![Scope::::new()], + } + } + + pub fn lookup(&self, name: impl AsRef) -> Option<&T> { + for scope in self.scopes.iter().rev() { + if let Some(ty) = scope.variables.get(name.as_ref()) { + return Some(ty); + } + } + None + } + + pub fn insert(&mut self, name: impl Into, value: T) { + self.scopes + .last_mut() + .unwrap() + .variables + .insert(name.into(), value); + } + + pub fn enter_scope(&mut self) { + self.scopes.push(Scope::::new()); + } + + pub fn leave_scope(&mut self) { + self.scopes.pop(); + } +} + +impl Clone for SymbolTable { + fn clone(&self) -> Self { + SymbolTable { + scopes: self.scopes.clone(), + } + } +} + +#[derive(Debug)] +struct Scope { + variables: HashMap, +} + +impl Scope { + fn new() -> Self { + Scope { + variables: HashMap::new(), + } + } +} + +impl Clone for Scope { + fn clone(&self) -> Self { + Scope { + variables: self.variables.clone(), + } + } +} diff --git a/src/compiler/typing.rs b/src/compiler/typing.rs index a7ccdd3..5e18a20 100644 --- a/src/compiler/typing.rs +++ b/src/compiler/typing.rs @@ -1,107 +1,330 @@ use std::collections::HashMap; -use crate::Environment; + +use crate::{Environment, compiler::symbol::SymbolTable}; use super::ast::syntax::*; +#[derive(Debug, Clone)] +pub struct TypeError { + pub span: Span, + pub kind: ErrKind, +} + +impl TypeError { + pub fn new(span: Span, kind: ErrKind) -> TypeError { + TypeError { span, kind } + } + + pub fn with_span(mut self, span: Span) -> Self { + if !span.is_zero() { + self.span = span; + } + self + } +} + +impl From for TypeError { + fn from(value: ErrKind) -> Self { + TypeError { + span: Span::new(0, 0), + kind: value, + } + } +} + +#[derive(Debug, Clone)] +pub enum ErrKind { + Message(String), + UnresovledType(String), + DuplicateName(String), + TypeMismatch { expected: Type, actual: Type }, +} + +impl ErrKind { + pub fn with_span(self, span: Span) -> TypeError { + TypeError { span, kind: self } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct TypeId(usize); + +#[derive(Debug, Clone, PartialEq, Default)] +pub enum Type { + Boolean, + Byte, + Integer, + Float, + Char, + String, + Array, + Tuple, + Range, + Enum(TypeId), + Struct(TypeId), + Function(Box), + Any, + #[default] + Unknown, +} + +impl Type { + pub fn is_any(&self) -> bool { + matches!(self, Type::Any) + } + pub fn is_boolean(&self) -> bool { + matches!(self, Type::Boolean) + } + + pub fn is_numeric(&self) -> bool { + matches!(self, Type::Byte | Type::Integer | Type::Float) + } + + pub fn is_string(&self) -> bool { + matches!(self, Type::String) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum TypeDef { + Struct(StructDef), + Enum(EnumDef), +} + +#[derive(Debug, Clone, PartialEq)] +pub struct FunctionDef { + pub name: String, + pub params: Vec<(String, Option)>, + pub return_type: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct EnumDef { + pub name: String, + pub variants: Vec<(String, Option)>, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct StructDef { + pub name: String, + pub fields: HashMap, +} + #[derive(Debug, Clone)] pub struct TypeContext { - type_decls: HashMap, - type_env: HashMap, - // 新增:用于缓存已解析的类型声明 - resolved_types: HashMap, + type_defs: HashMap, + name_to_id: HashMap, + + functions: HashMap>, + + next_id: usize, } impl TypeContext { pub fn new() -> Self { - TypeContext { - type_decls: HashMap::new(), - type_env: HashMap::new(), - resolved_types: HashMap::new(), // 初始化缓存 + Self { + type_defs: HashMap::new(), + name_to_id: HashMap::new(), + functions: HashMap::new(), + next_id: 0, } } - pub fn add_type_decl(&mut self, name: String, decl: Declaration) { - self.type_decls.insert(name.clone(), decl.clone()); + pub fn functions(&self) -> impl Iterator { + self.functions.values().map(|func| &**func) + } - if matches!(&decl, &Declaration::Function(_)) { - self.type_env.insert(name.to_string(), Type::Decl(decl)); - } + pub fn type_is_defined(&self, name: &str) -> bool { + self.name_to_id.contains_key(name) || self.functions.contains_key(name) } - pub fn get_type_decl(&self, name: &str) -> Option<&Declaration> { - self.type_decls.get(name) + pub fn get_type_id(&self, name: &str) -> Option { + self.name_to_id.get(name).copied() } - pub fn get_type(&self, name: &str) -> Option<&Type> { - self.type_env.get(name) + pub fn get_type(&self, name: &str) -> Option { + match self.name_to_id.get(name) { + Some(id) => match self.type_defs.get(id) { + Some(TypeDef::Struct(_)) => Some(Type::Struct(*id)), + Some(TypeDef::Enum(_)) => Some(Type::Enum(*id)), + None => panic!("Type not found"), + }, + None => self + .functions + .get(name) + .map(|func| Type::Function(func.clone())), + } } - pub fn set_type(&mut self, name: String, ty: Type) { - self.type_env.insert(name, ty); + pub fn get_type_def(&self, name: &str) -> Option<&TypeDef> { + match self.name_to_id.get(name) { + Some(id) => self.type_defs.get(id), + None => None, + } } - pub fn function_decls(&self) -> impl Iterator { - self.type_decls.values().filter(|decl| decl.is_function()) + pub fn get_function_def(&self, name: &str) -> Option<&FunctionDef> { + self.functions.get(name).map(|v| &**v) } - pub fn process_env(&mut self, env: &Environment) { - for (name, _value) in env.symbols.iter() { - self.set_type(name.clone(), Type::Any); + pub fn check_type_def(&mut self, stmts: &[StatementNode]) -> Result<(), TypeError> { + // round 1, decl types + for stmt in stmts { + match &stmt.node { + Statement::Item(ItemStatement::Struct(StructItem { name, .. })) => { + self.decl_type( + name.clone(), + TypeDef::Struct(StructDef { + name: name.clone(), + fields: HashMap::new(), + }), + ) + .map_err(|err| err.with_span(stmt.span))?; + } + Statement::Item(ItemStatement::Enum(item)) => { + self.decl_type( + item.name.clone(), + TypeDef::Enum(EnumDef { + name: item.name.clone(), + variants: Vec::new(), + }), + ) + .map_err(|err| err.with_span(stmt.span()))?; + } + _ => {} + } } - } - pub fn analyze_type_decl(&mut self, stmts: &[StatementNode]) { + // round 2, resolve types for stmt in stmts { match &stmt.node { - Statement::Item(ItemStatement::Fn(func)) => { - let FunctionItem { - name, - params, - return_ty, - body, - } = func; + Statement::Item(ItemStatement::Function(func)) => { + self.check_function_item(func)?; + } + Statement::Item(ItemStatement::Struct(item)) => { + self.check_struct_item(item)?; + } + Statement::Item(ItemStatement::Enum(item)) => { + self.check_enum_item(item)?; + } + _ => { + continue; + } + } + } + + Ok(()) + } - let mut param_types = Vec::new(); + fn check_function_item(&mut self, item: &FunctionItem) -> Result<(), TypeError> { + let FunctionItem { + name, + params, + return_ty, + .. + } = item; - for param in params { - let ty = param.ty.clone().map(|t| self.resolve_type_decl(&t)); + let mut func = FunctionDef { + name: name.clone(), + params: Vec::new(), + return_type: return_ty + .as_ref() + .map(|ty| self.resolve_type(ty)) + .transpose()?, + }; - param_types.push((param.name.clone(), ty)); - } + for param in params { + let param_type = param + .ty + .as_ref() + .map(|ty| self.resolve_type(ty)) + .transpose()?; + func.params.push((param.name.clone(), param_type)); + } + + self.functions.insert(name.clone(), Box::new(func)); - let return_ty = return_ty - .as_ref() - .map(|t| Box::new(self.resolve_type_decl(t))); + Ok(()) + } - let func_decl = FunctionDeclaration { - name: name.clone(), - params: param_types, - return_type: return_ty, - }; + fn check_struct_item(&mut self, item: &StructItem) -> Result { + let StructItem { name, fields } = item; - self.add_type_decl(name.clone(), Declaration::Function(func_decl)); - } - Statement::Item(ItemStatement::Struct(item)) => { - let StructItem { name, fields } = item; - let struct_decl = StructDeclaration { - name: name.clone(), - fields: fields - .iter() - .map(|field| (field.name.clone(), self.resolve_type_decl(&field.ty))) - .collect(), - }; - - self.add_type_decl(name.clone(), Declaration::Struct(struct_decl)); - } - Statement::Item(ItemStatement::Enum(EnumItem { .. })) => {} - _ => {} + let type_id = self.name_to_id.get(name).unwrap(); + + let fields: HashMap = fields + .iter() + .map(|field| { + self.resolve_type(&field.ty) + .map(|ty| (field.name.clone(), ty)) + }) + .collect::>()?; + + match self.type_defs.get_mut(type_id) { + Some(TypeDef::Struct(struct_def)) => { + // update the struct definition + struct_def.fields = fields; + } + _ => { + return Err(ErrKind::Message(format!("Type {} is not a struct", name)).into()); + } + } + + Ok(*type_id) + } + + fn check_enum_item(&mut self, item: &EnumItem) -> Result { + let EnumItem { name, variants } = item; + + let type_id = self.name_to_id.get(name).unwrap(); + + let mut enum_variants = Vec::new(); + + for variant in variants { + let EnumVariant { name, variant } = variant; + + let variant_type = variant + .as_ref() + .map(|ty| self.resolve_type(ty)) + .transpose()?; + + enum_variants.push((name.to_string(), variant_type)); + } + + match self.type_defs.get_mut(type_id) { + Some(TypeDef::Enum(enum_def)) => { + // update the enum definition + enum_def.variants = enum_variants; + } + _ => { + return Err(ErrKind::Message(format!("Type {} is not an enum", name)).into()); } } + + Ok(*type_id) + } + + fn decl_type(&mut self, name: String, ty: TypeDef) -> Result { + if self.name_to_id.contains_key(&name) { + return Err(ErrKind::DuplicateName(name).into()); + } + + let id = self.next_id(); + self.type_defs.insert(id, ty); + self.name_to_id.insert(name, id); + + Ok(id) } - // 新增:递归解析类型声明 - fn resolve_type_decl_recursive(&mut self, type_expr: &TypeExpression) -> Type { + fn next_id(&mut self) -> TypeId { + let id = self.next_id; + self.next_id += 1; + TypeId(id) + } + + /// Try to analyze type expressions when possible, otherwise return Type::Unknown. + fn try_resolve_type(&self, type_expr: &TypeExpression) -> Type { match type_expr { TypeExpression::Any => Type::Any, TypeExpression::Boolean => Type::Boolean, @@ -110,55 +333,523 @@ impl TypeContext { TypeExpression::Float => Type::Float, TypeExpression::Char => Type::Char, TypeExpression::String => Type::String, - TypeExpression::Tuple(types) => { - let types = types.iter().map(|ty| self.resolve_type_decl(ty)).collect(); - Type::Tuple(types) + TypeExpression::Array(_) => Type::Array, + TypeExpression::Tuple(_) => Type::Tuple, + TypeExpression::UserDefined(name) => match self.name_to_id.get(name).cloned() { + Some(id) => match self.type_defs.get(&id) { + Some(TypeDef::Struct(_)) => Type::Struct(id), + Some(TypeDef::Enum(_)) => Type::Enum(id), + _ => panic!("Invalid type"), + }, + None => Type::Unknown, + }, + _ => Type::Unknown, + } + } + + pub fn resolve_type(&self, type_expr: &TypeExpression) -> Result { + match self.try_resolve_type(type_expr) { + Type::Unknown => Err(ErrKind::UnresovledType(format!("{:?}", type_expr)).into()), + ty => Ok(ty), + } + } +} + +pub struct TypeChecker<'a> { + type_cx: &'a TypeContext, + current_function_return_type: Option, + symbols: SymbolTable, +} + +impl<'a> TypeChecker<'a> { + pub fn new(type_cx: &'a TypeContext) -> Self { + TypeChecker { + type_cx, + current_function_return_type: None, + symbols: SymbolTable::new(), + } + } + + pub fn check_program(&mut self, program: &Program, env: &Environment) -> Result<(), TypeError> { + // all env is any + for name in env.keys() { + self.symbols.insert(name.to_string(), Type::Any); + } + // insert function type + for (name, func) in self.type_cx.functions.iter() { + self.symbols + .insert(func.name.clone(), Type::Function(func.clone())); + } + + for item in &program.stmts { + self.check_statement(item)?; + } + Ok(()) + } + + fn check_statement(&mut self, stmt: &StatementNode) -> Result<(), TypeError> { + match &stmt.node { + Statement::Let(let_stmt) => self.check_let_statement(let_stmt), + Statement::Block(block) => self.check_block_statement(block), + Statement::If(if_stmt) => self.check_if_statement(if_stmt), + Statement::While(while_stmt) => self.check_while_statement(while_stmt), + Statement::For(for_stmt) => self.check_for_statement(for_stmt), + Statement::Loop(loop_stmt) => self.check_loop_statement(loop_stmt), + Statement::Return(return_stmt) => self.check_return_statement(return_stmt), + Statement::Expression(expr) => self.check_expression(expr).map(|_| ()), + Statement::Empty => Ok(()), + Statement::Break => Ok(()), + Statement::Continue => Ok(()), + Statement::Item(item_stmt) => self.check_item_statement(item_stmt), + } + } + + // 新增方法:检查块语句 + fn check_block_statement(&mut self, block: &BlockStatement) -> Result<(), TypeError> { + self.symbols.enter_scope(); + + for stmt in &block.0 { + self.check_statement(stmt)?; + } + + self.symbols.leave_scope(); + Ok(()) + } + + // 新增方法:检查条件语句 + fn check_if_statement(&mut self, if_stmt: &IfStatement) -> Result<(), TypeError> { + let condition_type = self.check_expression(&if_stmt.condition)?; + + if condition_type != Type::Boolean && condition_type != Type::Any { + return Err(ErrKind::TypeMismatch { + expected: Type::Boolean, + actual: condition_type, + } + .with_span(if_stmt.condition.span)); + } + + self.check_block_statement(&if_stmt.then_branch)?; + if let Some(else_branch) = &if_stmt.else_branch { + self.check_block_statement(else_branch)?; + } + Ok(()) + } + + // 新增方法:检查循环语句 + fn check_while_statement(&mut self, while_stmt: &WhileStatement) -> Result<(), TypeError> { + let condition_type = self.check_expression(&while_stmt.condition)?; + + if condition_type != Type::Boolean && condition_type != Type::Any { + return Err(ErrKind::TypeMismatch { + expected: Type::Boolean, + actual: condition_type, } - TypeExpression::Array(ty) => { - let ty = self.resolve_type_decl(ty); - Type::Array(Box::new(ty)) + .with_span(while_stmt.condition.span)); + } + + self.check_block_statement(&while_stmt.body)?; + Ok(()) + } + + // 新增方法:检查 for 循环语句 + fn check_for_statement(&mut self, for_stmt: &ForStatement) -> Result<(), TypeError> { + self.check_expression(&for_stmt.iterable)?; + self.check_pattern(&for_stmt.pat)?; + self.check_block_statement(&for_stmt.body)?; + Ok(()) + } + + fn check_pattern(&mut self, pattern: &Pattern) -> Result<(), TypeError> { + match pattern { + Pattern::Identifier(identifier) => { + self.symbols.insert(identifier.name(), Type::Any); + Ok(()) } - TypeExpression::UserDefined(ty) => { - // 检查缓存中是否存在已解析的类型 - if let Some(cached_type) = self.resolved_types.get(ty) { - return cached_type.clone(); + Pattern::Tuple(tuple) => { + for pattern in tuple { + self.check_pattern(pattern)?; } + Ok(()) + } + Pattern::Wildcard => Ok(()), + Pattern::Literal(literal) => Ok(()), + } + } - // 检查是否已经存在于 type_env 中 - if let Some(ty) = self.get_type(ty) { - return ty.clone(); + // 新增方法:检查无限循环语句 + fn check_loop_statement(&mut self, loop_stmt: &LoopStatement) -> Result<(), TypeError> { + self.check_block_statement(&loop_stmt.body)?; + Ok(()) + } + + // 新增方法:检查返回语句 + fn check_return_statement(&mut self, return_stmt: &ReturnStatement) -> Result<(), TypeError> { + if let Some(expr) = &return_stmt.value { + let return_ty = self.check_expression(expr)?; + + if let Some(expected_ty) = &self.current_function_return_type { + if return_ty != *expected_ty { + return Err(ErrKind::TypeMismatch { + expected: expected_ty.clone(), + actual: return_ty, + } + .with_span(expr.span())); } + } + } + Ok(()) + } - // 解析类型声明 - if let Some(decl) = self.get_type_decl(ty) { - let resolved_type = match decl { - Declaration::Function(func_decl) => { - Type::Decl(Declaration::Function(func_decl.clone())) - } - Declaration::Struct(struct_decl) => { - Type::Decl(Declaration::Struct(struct_decl.clone())) - } - Declaration::Enum(enum_decl) => { - Type::Decl(Declaration::Enum(enum_decl.clone())) + // 新增方法:检查项语句 + fn check_item_statement(&mut self, item_stmt: &ItemStatement) -> Result<(), TypeError> { + if let ItemStatement::Function(func) = item_stmt { + self.check_function_item(func)?; + } + + Ok(()) + } + + fn check_function_item(&mut self, func_item: &FunctionItem) -> Result<(), TypeError> { + // new scope + let old_return_type = self.current_function_return_type.clone(); + self.symbols.enter_scope(); + + for param in &func_item.params { + let param_type = match param.ty.as_ref() { + Some(ty) => self.type_cx.resolve_type(ty)?, + None => Type::Any, + }; + + self.symbols.insert(param.name.clone(), param_type); + } + + self.current_function_return_type = func_item + .return_ty + .as_ref() + .map(|ty| self.type_cx.resolve_type(ty)) + .transpose()?; + + self.check_block_statement(&func_item.body)?; + + // restore + self.symbols.leave_scope(); + self.current_function_return_type = old_return_type; + + Ok(()) + } + + fn check_let_statement(&mut self, let_stmt: &LetStatement) -> Result<(), TypeError> { + let decl_ty = let_stmt + .ty + .as_ref() + .map(|ty| self.type_cx.resolve_type(ty)) + .transpose()?; + + let value_ty = let_stmt + .value + .as_ref() + .map(|expr| self.check_expression(expr)) + .transpose()?; + + let ty = match (decl_ty, value_ty) { + (Some(decl_ty), Some(value_ty)) => { + if decl_ty != value_ty { + return Err(ErrKind::TypeMismatch { + expected: decl_ty, + actual: value_ty, + } + .with_span(let_stmt.value.as_ref().unwrap().span)); + } else { + decl_ty + } + } + (Some(decl_ty), None) => decl_ty, + (None, Some(value_ty)) => value_ty, + (None, None) => Type::Any, + }; + + self.symbols.insert(&let_stmt.name, ty); + + Ok(()) + } + + fn check_expression(&mut self, expr: &ExpressionNode) -> Result { + let ret = match &expr.node { + Expression::Literal(lit) => self.check_literal(lit), + Expression::Identifier(id) => self.check_identifier(id), + Expression::Binary(bin) => self.check_binary(bin), + Expression::Prefix(prefix) => self.check_prefix(prefix), + Expression::Call(call) => self.check_call(call), + Expression::Environment(env) => Ok(Type::String), + Expression::Path(path) => self.check_path(path), + Expression::Tuple(tuple) => self.check_tuple(tuple), + Expression::Array(arr) => self.check_array(arr), + Expression::Map(map) => Ok(Type::Any), // 暂定Map类型为Any + Expression::Closure(closure) => self.check_closure(closure), + Expression::Range(range) => Ok(Type::Any), // 暂定Range类型为Any + Expression::Slice(slice) => self.check_slice(slice), + Expression::Assign(assign) => self.check_assign(assign), + Expression::IndexGet(index) => self.check_index_get(index), + Expression::IndexSet(index) => self.check_index_set(index), + Expression::PropertyGet(prop) => self.check_property_get(prop), + Expression::PropertySet(prop) => self.check_property_set(prop), + Expression::CallMethod(call) => self.check_call_method(call), + Expression::StructExpr(struct_) => self.check_struct_expr(struct_), + Expression::Await(expr) => self.check_expression(expr), + Expression::Try(expr) => self.check_expression(expr), + // _ => Err(ErrKind::Message(format!("Unsupported expression: {:?}", expr.node)).into()), + }; + + ret.map_err(|err| err.with_span(expr.span)) + } + + fn check_path(&self, path: &PathExpression) -> Result { + // 路径表达式类型解析逻辑 + Ok(Type::Any) // 暂定返回Any类型 + } + + fn check_tuple(&mut self, tuple: &TupleExpression) -> Result { + // 元组类型解析逻辑 + Ok(Type::Tuple) + } + + fn check_array(&mut self, arr: &ArrayExpression) -> Result { + // 数组类型解析逻辑 + Ok(Type::Array) + } + + fn check_closure(&mut self, closure: &ClosureExpression) -> Result { + // // 闭包类型解析逻辑 + // Ok(Type::Function(Box::new(FunctionDef { + // name: "".to_string(), + // params: vec![], + // return_type: None, + // }))) + + Ok(Type::Any) // 临时返回 Any 类型 + } + + fn check_slice(&mut self, slice: &SliceExpression) -> Result { + // 切片类型解析逻辑 + Ok(Type::Array) + } + + fn check_assign(&mut self, assign: &AssignExpression) -> Result { + // 赋值表达式类型解析逻辑 + self.check_expression(&assign.value) + } + + fn check_index_get(&mut self, index: &IndexGetExpression) -> Result { + // 索引获取类型解析逻辑 + Ok(Type::Any) // 暂定返回Any类型 + } + + fn check_index_set(&mut self, index: &IndexSetExpression) -> Result { + // 索引设置类型解析逻辑 + self.check_expression(&index.value) + } + + fn check_property_get(&mut self, prop: &PropertyGetExpression) -> Result { + // 属性获取类型解析逻辑 + Ok(Type::Any) // 暂定返回Any类型 + } + + fn check_property_set(&mut self, prop: &PropertySetExpression) -> Result { + // 属性设置类型解析逻辑 + self.check_expression(&prop.value) + } + + fn check_call_method(&mut self, call: &CallMethodExpression) -> Result { + // 调用方法类型解析逻辑 + Ok(Type::Any) // 暂定返回Any类型 + } + + fn check_struct_expr(&mut self, struct_expr: &StructExpression) -> Result { + match self.type_cx.get_type_def(struct_expr.name.node()) { + Some(TypeDef::Struct(struct_def)) => { + for field in &struct_expr.fields { + let field_type = self.check_expression(&field.value)?; + if let Some(expected_type) = struct_def.fields.get(&field.name.node) { + if field_type != *expected_type && field_type != Type::Any { + return Err(TypeError::new( + field.value.span(), + ErrKind::Message(format!( + "Expected type {:?} for field {:?}, found {:?}", + expected_type, field.name, field_type + )), + )); } - }; + } + } + + Ok(self.type_cx.get_type(struct_expr.name.node()).unwrap()) + } + Some(ty) => Err(TypeError::new( + struct_expr.name.span(), + ErrKind::Message(format!("Expected struct type, found {:?}", ty)), + )), + None => Err(ErrKind::UnresovledType(struct_expr.name.node().clone()).into()), + } + } + + fn check_literal(&self, lit: &LiteralExpression) -> Result { + match lit { + LiteralExpression::Null => Ok(Type::Any), + LiteralExpression::Boolean(_) => Ok(Type::Boolean), + LiteralExpression::Integer(_) => Ok(Type::Integer), + LiteralExpression::Float(_) => Ok(Type::Float), + LiteralExpression::Char(_) => Ok(Type::Char), + LiteralExpression::String(_) => Ok(Type::String), + } + } + + fn check_identifier(&self, ident: &IdentifierExpression) -> Result { + match self.symbols.lookup(ident.name()) { + Some(ty) => Ok(ty.clone()), + None => Err(ErrKind::Message(format!("Undefined identifier: {}", ident.name())).into()), + } + } + + fn check_binary(&mut self, bin: &BinaryExpression) -> Result { + let lhs_type = self.check_expression(&bin.lhs)?; + let rhs_type = self.check_expression(&bin.rhs)?; + + // Handle Type::Any as compatible with any type + if lhs_type == Type::Any || rhs_type == Type::Any { + return Ok(Type::Any); + } + + match bin.op { + BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div | BinOp::Rem => { + if lhs_type.is_numeric() && lhs_type == rhs_type { + Ok(lhs_type) // Simplified for demonstration + } else { + Err(ErrKind::Message(format!( + "Type mismatch in binary operation: {:?} and {:?}", + lhs_type, rhs_type + )) + .into()) + } + } + BinOp::Equal + | BinOp::NotEqual + | BinOp::Less + | BinOp::LessEqual + | BinOp::Greater + | BinOp::GreaterEqual => { + if lhs_type == rhs_type { + Ok(Type::Boolean) + } else { + Err(ErrKind::Message(format!( + "Type mismatch in comparison: {:?} and {:?}", + lhs_type, rhs_type + )) + .into()) + } + } + BinOp::LogicAnd | BinOp::LogicOr => { + if lhs_type == Type::Boolean && rhs_type == Type::Boolean { + Ok(Type::Boolean) + } else { + Err(ErrKind::Message(format!( + "Type mismatch in logical operation: {:?} and {:?}", + lhs_type, rhs_type + )) + .into()) + } + } + BinOp::Range | BinOp::RangeInclusive => { + if lhs_type == Type::Integer && rhs_type == Type::Integer { + Ok(Type::Range) + } else { + Err(ErrKind::Message(format!( + "Type mismatch in range operation: {:?} and {:?}", + lhs_type, rhs_type + )) + .into()) + } + } + _ => Err(ErrKind::Message(format!("Unsupported binary operator: {:?}", bin.op)).into()), + } + } + + fn check_prefix(&mut self, prefix: &PrefixExpression) -> Result { + let rhs_type = self.check_expression(&prefix.rhs)?; + + // Handle Type::Any as compatible with any type + if rhs_type == Type::Any { + return Ok(Type::Any); + } - // 更新缓存 - self.resolved_types - .insert(ty.clone(), resolved_type.clone()); - self.set_type(ty.clone(), resolved_type.clone()); - resolved_type + match prefix.op { + PrefixOp::Neg => { + if rhs_type.is_numeric() { + Ok(rhs_type) + } else { + Err(ErrKind::Message(format!( + "Cannot apply negation to non-numeric type: {:?}", + rhs_type + )) + .into()) + } + } + PrefixOp::Not => { + if rhs_type == Type::Boolean { + Ok(Type::Boolean) } else { - // 如果无法解析,则返回 UserDefined - Type::UserDefined(ty.clone()) + Err(ErrKind::Message(format!( + "Cannot apply logical NOT to non-boolean type: {:?}", + rhs_type + )) + .into()) } } - _ => Type::Any, } } - // 新增:对外暴露的解析函数 - pub fn resolve_type_decl(&mut self, type_expr: &TypeExpression) -> Type { - self.resolve_type_decl_recursive(type_expr) + fn check_call(&mut self, call: &CallExpression) -> Result { + let func_type = self.check_expression(&call.func)?; + + // Handle Type::Any as compatible with any type + if func_type == Type::Any { + return Ok(Type::Any); + } + + match func_type { + Type::Function(func_def) => { + // Check the number of arguments + if func_def.params.len() != call.args.len() { + return Err(ErrKind::Message(format!( + "Expected {} arguments, but got {}", + func_def.params.len(), + call.args.len() + )) + .into()); + } + + // Check each argument type + for (i, (param_name, param_type)) in func_def.params.iter().enumerate() { + let arg_type = self.check_expression(&call.args[i])?; + if let Some(expected_type) = param_type { + // Handle Type::Any as compatible with any type + if *expected_type != Type::Any && arg_type != *expected_type { + return Err(ErrKind::Message(format!( + "Argument {} has type {:?}, but expected {:?}", + i + 1, + arg_type, + expected_type + )) + .into()); + } + } + } + + // Return the function's return type + Ok(func_def.return_type.unwrap_or(Type::Any)) + } + _ => Err( + ErrKind::Message(format!("Cannot call non-function type: {:?}", func_type)).into(), + ), + } } } diff --git a/src/error.rs b/src/error.rs index bdcdd6f..d58d5cb 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,18 +1,18 @@ use crate::{compiler::CompileError, runtime::RuntimeError}; #[derive(Debug)] -pub enum Error { - Compile(CompileError), +pub enum Error<'i> { + Compile(CompileError<'i>), Runtime(RuntimeError), } -impl From for Error { - fn from(error: CompileError) -> Self { +impl<'i> From> for Error<'i> { + fn from(error: CompileError<'i>) -> Self { Error::Compile(error) } } -impl From for Error { +impl From for Error<'_> { fn from(error: RuntimeError) -> Self { Error::Runtime(error) } diff --git a/src/runtime/environment.rs b/src/runtime/environment.rs index e133e51..675dad1 100644 --- a/src/runtime/environment.rs +++ b/src/runtime/environment.rs @@ -26,6 +26,14 @@ impl Environment { } } + pub fn keys(&self) -> impl Iterator { + self.symbols.keys() + } + + pub fn values(&self) -> impl Iterator { + self.symbols.values() + } + pub fn with_variable(mut self, name: impl ToString, value: T) -> Self { self.insert(name, value); self diff --git a/src/runtime/object/map.rs b/src/runtime/object/map.rs index d8f57f1..207cf3a 100644 --- a/src/runtime/object/map.rs +++ b/src/runtime/object/map.rs @@ -27,7 +27,9 @@ where } fn index_set(&mut self, index: &Value, value: ValueRef) -> Result<(), RuntimeError> { - if let (Some(key), Some(value)) = (index.downcast_ref::(), value.value().downcast_ref::()) { + if let (Some(key), Some(value)) = + (index.downcast_ref::(), value.value().downcast_ref::()) + { self.insert(key.clone(), value.clone()); return Ok(()); } diff --git a/tests/test_embed.rs b/tests/test_embed.rs index ae4f864..acc148e 100644 --- a/tests/test_embed.rs +++ b/tests/test_embed.rs @@ -1,10 +1,10 @@ mod utils; use std::sync::Arc; -use evalit::{Environment, Error, Module, Object, RuntimeError, VM, Value, ValueRef, compile}; +use evalit::{Environment, Module, Object, RuntimeError, VM, Value, ValueRef, compile}; use utils::init_logger; -fn run_vm(program: Arc, env: Environment) -> Result { +fn run_vm(program: Arc, env: Environment) -> Result { let mut vm = VM::new(program, env); #[cfg(not(feature = "async"))] let ret = vm.run().unwrap();