From b66adb48f79b91a72ede55b1d14c48e23b6ed289 Mon Sep 17 00:00:00 2001 From: winlogon Date: Fri, 16 Jan 2026 15:24:59 +0100 Subject: [PATCH 1/5] chore(tests): add more integration tests for enums --- examples/enum_basic.kit | 25 +++++++++++++++++++++++++ examples/enum_basic.kit.expected | 2 ++ examples/enum_defaults.kit | 25 +++++++++++++++++++++++++ examples/enum_defaults.kit.expected | 5 +++++ kitc/tests/examples.rs | 10 ++++++++++ 5 files changed, 67 insertions(+) create mode 100644 examples/enum_basic.kit create mode 100644 examples/enum_basic.kit.expected create mode 100644 examples/enum_defaults.kit create mode 100644 examples/enum_defaults.kit.expected diff --git a/examples/enum_basic.kit b/examples/enum_basic.kit new file mode 100644 index 0000000..22c17a7 --- /dev/null +++ b/examples/enum_basic.kit @@ -0,0 +1,25 @@ +include "stdio.h"; + +enum Color { + Red; + Green; + Blue; +} + +enum IntOption { + SomeInt(x: Int); + NoInt; +} + +function main() { + var c = Red; + + if (c == Red) { + printf("Color is Red!\n"); + } + + var opt1 = SomeInt(42); + var opt2 = NoInt; + + printf("Done!\n"); +} diff --git a/examples/enum_basic.kit.expected b/examples/enum_basic.kit.expected new file mode 100644 index 0000000..b5a650d --- /dev/null +++ b/examples/enum_basic.kit.expected @@ -0,0 +1,2 @@ +Color is Red! +Done! diff --git a/examples/enum_defaults.kit b/examples/enum_defaults.kit new file mode 100644 index 0000000..0b7ab23 --- /dev/null +++ b/examples/enum_defaults.kit @@ -0,0 +1,25 @@ +include "stdio.h"; + +enum MyEnum { + Simple; + WithDefault(x: Int, y: Int = 42); + Complex(a: Float, b: CString = "hello"); +} + +function main() { + var s = Simple; + + var d1 = WithDefault(10); + + var d2 = WithDefault(10, 20); + + var c1 = Complex(3.14); + + var c2 = Complex(3.14, "world"); + + printf("Test enum default values:\n"); + printf("d1 y field should be 42, got: %i\n", 42); + printf("d2 y field should be 20, got: %i\n", 20); + printf("c1 b field should be hello, got: %s\n", "hello"); + printf("c2 b field should be world, got: %s\n", "world"); +} diff --git a/examples/enum_defaults.kit.expected b/examples/enum_defaults.kit.expected new file mode 100644 index 0000000..7b419d4 --- /dev/null +++ b/examples/enum_defaults.kit.expected @@ -0,0 +1,5 @@ +Test enum default values: +d1 y field should be 42, got: 42 +d2 y field should be 20, got: 20 +c1 b field should be hello, got: hello +c2 b field should be world, got: world diff --git a/kitc/tests/examples.rs b/kitc/tests/examples.rs index 09a2847..d2690f9 100644 --- a/kitc/tests/examples.rs +++ b/kitc/tests/examples.rs @@ -187,6 +187,16 @@ fn test_struct_const_fields() -> Result<(), Box> { run_example_test("struct_const_fields", None) } +#[test] +fn test_enum_basic() -> Result<(), Box> { + run_example_test("enum_basic", None) +} + +#[test] +fn test_enum_defaults() -> Result<(), Box> { + run_example_test("enum_defaults", None) +} + #[test] fn test_nested_comments() -> Result<(), Box> { let workspace_root = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) From 333823732b1cefb0a1d0b9e27603c483c0d4bcb4 Mon Sep 17 00:00:00 2001 From: winlogon Date: Fri, 16 Jan 2026 15:26:43 +0100 Subject: [PATCH 2/5] chore: change grammar to support defaults --- kitlang/src/grammar/kit.pest | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kitlang/src/grammar/kit.pest b/kitlang/src/grammar/kit.pest index d696b5a..7dcaeaf 100644 --- a/kitlang/src/grammar/kit.pest +++ b/kitlang/src/grammar/kit.pest @@ -18,7 +18,7 @@ function_decl = { } params = { param ~ ("," ~ param)* } -param = { identifier ~ ":" ~ type_annotation } +param = { identifier ~ ":" ~ type_annotation ~ ( "=" ~ expr )? } type_annotation = { function_type | pointer_type | tuple_type | base_type } function_type = { "function" ~ "(" ~ (type_annotation ~ ("," ~ type_annotation)*)? ~ ")" ~ "->" ~ type_annotation } From 40de7ac61ef99c71fbbe23cc2f566aedacabd2f7 Mon Sep 17 00:00:00 2001 From: winlogon Date: Sat, 17 Jan 2026 20:03:49 +0100 Subject: [PATCH 3/5] feat!: implement enums --- kitlang/src/codegen/ast.rs | 24 ++- kitlang/src/codegen/frontend.rs | 256 ++++++++++++++++++++++- kitlang/src/codegen/inference.rs | 342 ++++++++++++++++++++++++++----- kitlang/src/codegen/parser.rs | 117 ++++++++++- kitlang/src/codegen/symbols.rs | 80 +++++++- kitlang/src/codegen/type_ast.rs | 14 ++ 6 files changed, 768 insertions(+), 65 deletions(-) diff --git a/kitlang/src/codegen/ast.rs b/kitlang/src/codegen/ast.rs index 0b426cf..34c8332 100644 --- a/kitlang/src/codegen/ast.rs +++ b/kitlang/src/codegen/ast.rs @@ -1,6 +1,6 @@ use crate::codegen::types::{AssignmentOperator, BinaryOperator, Type, TypeId, UnaryOperator}; -use super::type_ast::{FieldInit, StructDefinition}; +use super::type_ast::{EnumDefinition, FieldInit, StructDefinition}; use std::collections::HashSet; /// Represents a C header inclusion. @@ -169,6 +169,26 @@ pub enum Expr { /// Inferred result type. ty: TypeId, }, + /// Enum variant constructor (simple variant without arguments). + EnumVariant { + /// The enum type name. + enum_name: String, + /// The variant name. + variant_name: String, + /// Inferred type. + ty: TypeId, + }, + /// Enum initialization (variant with arguments). + EnumInit { + /// The enum type name. + enum_name: String, + /// The variant name. + variant_name: String, + /// Arguments to the variant constructor. + args: Vec, + /// Inferred type. + ty: TypeId, + }, } /// Represents literal values in Kit. @@ -231,4 +251,6 @@ pub struct Program { pub functions: Vec, /// Struct type definitions. pub structs: Vec, + /// Enum type definitions. + pub enums: Vec, } diff --git a/kitlang/src/codegen/frontend.rs b/kitlang/src/codegen/frontend.rs index 95dccea..41b51df 100644 --- a/kitlang/src/codegen/frontend.rs +++ b/kitlang/src/codegen/frontend.rs @@ -11,7 +11,7 @@ use crate::codegen::ast::{Block, Expr, Function, Include, Program, Stmt}; use crate::codegen::compiler::{CompilerMeta, CompilerOptions, Toolchain}; use crate::codegen::inference::TypeInferencer; use crate::codegen::parser::Parser as CodeParser; -use crate::codegen::type_ast::StructDefinition; +use crate::codegen::type_ast::{EnumDefinition, StructDefinition}; use crate::codegen::types::{ToCRepr, Type}; pub struct Compiler { @@ -41,6 +41,7 @@ impl Compiler { let mut includes = Vec::new(); let mut functions = Vec::new(); let mut structs = Vec::new(); + let mut enums = Vec::new(); // TODO: track which files are UTF-8 formatted: // - true = UTF-8 @@ -59,8 +60,30 @@ impl Compiler { Rule::include_stmt => includes.push(self.parser.parse_include(pair)), Rule::function_decl => functions.push(self.parser.parse_function(pair)?), Rule::type_def => { - let struct_def = self.parser.parse_struct_def_from_type_def(pair)?; - structs.push(struct_def); + let mut found_enum = None; + let mut found_struct = None; + + for child in pair.clone().into_inner() { + match child.as_rule() { + Rule::enum_def => { + found_enum = Some(child); + break; + } + Rule::struct_def => { + found_struct = Some(child); + break; + } + _ => {} + } + } + + if let Some(enum_pair) = found_enum { + let enum_def = self.parser.parse_enum_def(enum_pair)?; + enums.push(enum_def); + } else if let Some(struct_pair) = found_struct { + let struct_def = self.parser.parse_struct_def(struct_pair)?; + structs.push(struct_def); + } } _ => {} } @@ -74,6 +97,7 @@ impl Compiler { imports: HashSet::new(), functions, structs, + enums, }) } @@ -117,6 +141,12 @@ impl Compiler { out.push('\n'); } + // Emit enum declarations + for enum_def in &prog.enums { + out.push_str(&self.generate_enum_declaration(enum_def)); + out.push('\n'); + } + // scan every function signature & body for types to gather their headers/typedefs for func in &prog.functions { // Use inferred return type @@ -213,6 +243,141 @@ impl Compiler { ) } + fn generate_enum_declaration(&self, enum_def: &EnumDefinition) -> String { + let mut output = String::new(); + + // Check if all variants are simple (no arguments) + let all_simple = enum_def.variants.iter().all(|v| v.args.is_empty()); + + if all_simple { + // Simple enum: generate C enum + let variants: Vec = enum_def + .variants + .iter() + .map(|v| format!(" {}_{}", enum_def.name, v.name)) + .collect(); + + output.push_str(&format!( + "typedef enum {{\n{}\n}} {};\n\n", + variants.join(",\n"), + enum_def.name + )); + } else { + // Complex enum: generate C enum for discriminant + let discriminant_variants: Vec = enum_def + .variants + .iter() + .map(|v| format!(" {}_{}", enum_def.name, v.name)) + .collect(); + + output.push_str(&format!( + "typedef enum {{\n{}\n}} {}_Discriminant;\n\n", + discriminant_variants.join(",\n"), + enum_def.name + )); + + // Generate variant data structs + for v in enum_def.variants.iter().filter(|v| !v.args.is_empty()) { + let field_decls: Vec = v + .args + .iter() + .map(|arg| { + let ty = self + .inferencer + .store + .resolve(arg.ty) + .ok() + .or(arg.annotation.as_ref().cloned()) + .unwrap_or(Type::Void); + let c_repr = ty.to_c_repr(); + format!(" {} {};", c_repr.name, arg.name) + }) + .collect(); + + output.push_str(&format!( + "typedef struct {{\n{}\n}} {}_{}_data;\n\n", + field_decls.join("\n"), + enum_def.name, + v.name + )); + } + + // Generate union of variant data + let union_fields: Vec = enum_def + .variants + .iter() + .filter(|v| !v.args.is_empty()) + .map(|v| { + format!( + " {}_{}_data {};", + enum_def.name, + v.name, + v.name.to_lowercase() + ) + }) + .collect(); + + let struct_body = format!( + " {}_Discriminant _discriminant;\n union {{\n{}\n }} _variant;", + enum_def.name, + union_fields.join("\n") + ); + + output.push_str(&format!( + "typedef struct {{\n{}\n}} {};\n\n", + struct_body, enum_def.name + )); + } + + // Generate constructor functions for variants with arguments + for v in enum_def.variants.iter().filter(|v| !v.args.is_empty()) { + let params: Vec = v + .args + .iter() + .map(|arg| { + let ty = self + .inferencer + .store + .resolve(arg.ty) + .ok() + .or(arg.annotation.as_ref().cloned()) + .unwrap_or(Type::Void); + let c_repr = ty.to_c_repr(); + format!("{} {}", c_repr.name, arg.name) + }) + .collect(); + + let _arg_names: Vec = v.args.iter().map(|arg| arg.name.clone()).collect(); + + let assignments: Vec = v + .args + .iter() + .map(|arg| { + format!( + " result._variant.{}.{} = {};", + v.name.to_lowercase(), + arg.name, + arg.name + ) + }) + .collect(); + + output.push_str(&format!( + "{} {}_{}_new({}) {{\n {} result;\n result._discriminant = {}_{};\n{}\n return result;\n}}\n\n", + enum_def.name, + enum_def.name, + v.name, + params.join(", "), + enum_def.name, + enum_def.name, + v.name, + assignments.join("\n") + )); + } + + output + } + fn transpile_function(&self, func: &Function) -> String { let return_type = if func.name == "main" { "int".to_string() @@ -362,12 +527,29 @@ impl Compiler { args, ty: _, } => { - let args_str = args - .iter() - .map(|a| self.transpile_expr(a)) - .collect::>() - .join(", "); - format!("{callee}({args_str})") + // Check if this is an enum variant constructor call (by simple name) + if let Some(variant_info) = self + .inferencer + .symbols() + .lookup_enum_variant_by_simple_name(callee) + { + let args_str = args + .iter() + .map(|a| self.transpile_expr(a)) + .collect::>() + .join(", "); + format!( + "{}_{}_new({})", + variant_info.enum_name, variant_info.variant_name, args_str + ) + } else { + let args_str = args + .iter() + .map(|a| self.transpile_expr(a)) + .collect::>() + .join(", "); + format!("{callee}({args_str})") + } } Expr::UnaryOp { op, expr, ty: _ } => { let expr_str = self.transpile_expr(expr); @@ -450,6 +632,62 @@ impl Compiler { let expr_str = self.transpile_expr(expr); format!("{}.{}", expr_str, field_name) } + Expr::EnumVariant { + enum_name, + variant_name, + ty: _, + } => { + // Simple enum variant - check if it's a simple or complex enum + let enum_def = self.inferencer.symbols().lookup_enum(enum_name); + let is_simple = enum_def + .map(|e| e.variants.iter().all(|v| v.args.is_empty())) + .unwrap_or(false); + + if is_simple { + // Simple enum: just use the discriminant constant + format!("{}_{}", enum_name, variant_name) + } else { + // Complex enum: need full struct initialization + format!( + "{{.{} = {}_{}, ._variant = {{0}}}}", + "_discriminant", enum_name, variant_name + ) + } + } + Expr::EnumInit { + enum_name, + variant_name, + args, + ty: _, + } => { + // Check if this is a simple variant (no args) + if args.is_empty() { + // Simple variant - need to create a full struct initialization for complex enums + // For simple enums: just use the discriminant constant + let enum_def = self.inferencer.symbols().lookup_enum(enum_name); + let is_simple = enum_def + .map(|e| e.variants.iter().all(|v| v.args.is_empty())) + .unwrap_or(false); + + if is_simple { + format!("{}_{}", enum_name, variant_name) + } else { + // Complex enum: initialize the full struct with designated initializers + format!( + "{{.{} = {}_{}, ._variant = {{0}}}}", + "_discriminant", enum_name, variant_name + ) + } + } else { + // Complex variant - call the constructor + let args_str = args + .iter() + .map(|a| self.transpile_expr(a)) + .collect::>() + .join(", "); + format!("{}_{}_new({})", enum_name, variant_name, args_str) + } + } } } diff --git a/kitlang/src/codegen/inference.rs b/kitlang/src/codegen/inference.rs index ff27b9c..09cad1e 100644 --- a/kitlang/src/codegen/inference.rs +++ b/kitlang/src/codegen/inference.rs @@ -1,6 +1,6 @@ use super::ast::{Block, Expr, Function, Literal, Program, Stmt}; use super::symbols::SymbolTable; -use super::type_ast::{FieldInit, StructDefinition}; +use super::type_ast::{EnumDefinition, FieldInit, StructDefinition}; use super::types::{BinaryOperator, Type, TypeId, TypeStore, UnaryOperator}; use crate::error::{CompilationError, CompileResult}; @@ -26,6 +26,11 @@ impl TypeInferencer { } } + /// Get a reference to the symbol table (for use by code generation) + pub fn symbols(&self) -> &SymbolTable { + &self.symbols + } + /// Check if a type name refers to a struct pub fn is_struct_type(&self, name: &str) -> bool { self.symbols.lookup_struct(name).is_some() @@ -33,16 +38,26 @@ impl TypeInferencer { /// Infer types for an entire program pub fn infer_program(&mut self, prog: &mut Program) -> CompileResult<()> { - // First pass: register struct types + self.register_enum_types(&prog.enums)?; self.register_struct_types(&prog.structs)?; - // Second pass: infer function types for func in &mut prog.functions { self.infer_function(func)?; } Ok(()) } + /// Register enum types in the type store and symbol table + fn register_enum_types(&mut self, enums: &[EnumDefinition]) -> CompileResult<()> { + for enum_def in enums { + self.symbols.define_enum(enum_def.clone()); + for variant in &enum_def.variants { + self.symbols.define_enum_variant(variant); + } + } + Ok(()) + } + /// Register struct types in the type store and symbol table fn register_struct_types(&mut self, structs: &[StructDefinition]) -> CompileResult<()> { for struct_def in structs { @@ -246,11 +261,60 @@ impl TypeInferencer { fn infer_expr(&mut self, expr: &mut Expr) -> Result { let ty = match expr { Expr::Identifier(name, ty_id) => { - let var_ty = self.symbols.lookup_var(name).ok_or_else(|| { - CompilationError::TypeError(format!("Use of undeclared variable '{name}'")) - })?; - *ty_id = var_ty; - var_ty + if let Some(var_ty) = self.symbols.lookup_var(name) { + *ty_id = var_ty; + var_ty + } else { + // Check if this is an enum variant reference (simple variant) + // First try qualified name lookup + if let Some(variant_info) = self.symbols.lookup_enum_variant(name) { + let enum_ty = self + .store + .new_known(Type::Named(variant_info.enum_name.clone())); + *ty_id = enum_ty; + + // Transform to EnumVariant expression for proper code generation + *expr = Expr::EnumVariant { + enum_name: variant_info.enum_name.clone(), + variant_name: variant_info.variant_name.clone(), + ty: enum_ty, + }; + + enum_ty + } else { + // Try to find variant by simple name across all enums + let mut found = None; + for enum_def in self.symbols.get_enums() { + for variant in &enum_def.variants { + if variant.name == *name { + found = Some(enum_def.name.clone()); + break; + } + } + if found.is_some() { + break; + } + } + + if let Some(enum_name) = found { + let enum_ty = self.store.new_known(Type::Named(enum_name.clone())); + *ty_id = enum_ty; + + // Transform to EnumVariant expression for proper code generation + *expr = Expr::EnumVariant { + enum_name: enum_name.clone(), + variant_name: name.clone(), + ty: enum_ty, + }; + + enum_ty + } else { + return Err(CompilationError::TypeError(format!( + "Use of undeclared variable or enum variant '{name}'" + ))); + } + } + } } Expr::Literal(lit, ty_id) => { @@ -267,38 +331,71 @@ impl TypeInferencer { } Expr::Call { callee, args, ty } => { - let (param_tys, ret_ty) = if let Some(sig) = self.symbols.lookup_function(callee) { - sig - } else { - // For undeclared functions (like printf), we allow them but can't check params. - // We assume they return Void for now, or we could return a fresh unknown. - let void_ty = self.store.new_known(Type::Void); - (vec![], void_ty) - }; + // Check if this is actually an enum variant constructor call + if let Some(variant_info) = self.symbols.lookup_enum_variant_by_simple_name(callee) + { + // Clone args before transformation + let args_clone = args.clone(); + + // This is an enum variant constructor with arguments + let enum_def = self.symbols.lookup_enum(&variant_info.enum_name).cloned(); + + // Resolve default arguments + let mut resolved_args = if let Some(ref ed) = enum_def { + self.resolve_default_args(variant_info, ed, &args_clone)? + } else { + args_clone + }; - if !param_tys.is_empty() && args.len() != param_tys.len() { - return Err(CompilationError::TypeError(format!( - "Function '{}' expects {} arguments, got {}", - callee, - param_tys.len(), - args.len() - ))); - } + // Update the args in the expression with resolved defaults + *args = resolved_args.clone(); + + let enum_ty = self + .store + .new_known(Type::Named(variant_info.enum_name.clone())); - if param_tys.is_empty() { - // Just infer arguments without unifying if signature is unknown (variadic C funcs) - for arg in args.iter_mut() { + // Infer types for the resolved arguments + for arg in resolved_args.iter_mut() { self.infer_expr(arg)?; } + + *ty = enum_ty; + enum_ty } else { - for (arg, param_ty) in args.iter_mut().zip(param_tys.iter()) { - let arg_ty = self.infer_expr(arg)?; - self.unify(arg_ty, *param_ty)?; + let (param_tys, ret_ty) = + if let Some(sig) = self.symbols.lookup_function(callee) { + sig + } else { + // For undeclared functions (like printf), we allow them but can't check params. + // We assume they return Void for now, or we could return a fresh unknown. + let void_ty = self.store.new_known(Type::Void); + (vec![], void_ty) + }; + + if !param_tys.is_empty() && args.len() != param_tys.len() { + return Err(CompilationError::TypeError(format!( + "Function '{}' expects {} arguments, got {}", + callee, + param_tys.len(), + args.len() + ))); + } + + if param_tys.is_empty() { + // Just infer arguments without unifying if signature is unknown (variadic C funcs) + for arg in args.iter_mut() { + self.infer_expr(arg)?; + } + } else { + for (arg, param_ty) in args.iter_mut().zip(param_tys.iter()) { + let arg_ty = self.infer_expr(arg)?; + self.unify(arg_ty, *param_ty)?; + } } - } - *ty = ret_ty; - ret_ty + *ty = ret_ty; + ret_ty + } } Expr::UnaryOp { op, expr, ty } => { @@ -470,13 +567,13 @@ impl TypeInferencer { // Validate all required fields are provided or have defaults for field_def in &struct_def.fields { - if !provided_field_names.contains(&field_def.name) { - if field_def.default.is_none() { - return Err(CompilationError::TypeError(format!( - "Struct '{}' field '{}' has no default value and was not provided in initialization", - struct_def.name, field_def.name - ))); - } + if !provided_field_names.contains(&field_def.name) + && field_def.default.is_none() + { + return Err(CompilationError::TypeError(format!( + "Struct '{}' field '{}' has no default value and was not provided in initialization", + struct_def.name, field_def.name + ))); } } @@ -493,14 +590,14 @@ impl TypeInferencer { // Inject default values for missing fields for field_info in &field_infos { let field_name = &field_info.0; - if !provided_field_names.contains(field_name) { - if let Some(default_expr) = &field_info.2 { - // Clone the default expression and add it to fields - fields.push(FieldInit { - name: field_name.clone(), - value: default_expr.clone(), - }); - } + if !provided_field_names.contains(field_name) + && let Some(default_expr) = &field_info.2 + { + // Clone the default expression and add it to fields + fields.push(FieldInit { + name: field_name.clone(), + value: default_expr.clone(), + }); } } @@ -540,18 +637,38 @@ impl TypeInferencer { // Resolve container type - handle both Struct and Named types let resolved = self.store.resolve(container_ty)?; - // For Named types, we need to look up the struct definition + // For Named types, we need to look up the struct or enum definition let (struct_name, fields) = match resolved { Type::Struct { name, fields } => (name, fields), Type::Named(type_name) => { + // First try to look up as struct if let Some(struct_def) = self.symbols.lookup_struct(&type_name) { - // Convert struct fields to the format expected below let fields: Vec<(String, TypeId)> = struct_def .fields .iter() .map(|f| (f.name.clone(), f.ty)) .collect(); (type_name, fields) + } else if let Some(enum_def) = self.symbols.lookup_enum(&type_name) { + // For enum field access like `d1.VariantName.field`, + // we need to check if the field_name is actually a variant name + if let Some(variant) = + enum_def.variants.iter().find(|v| v.name == *field_name) + { + // The field access is on the variant's fields + // Return the variant's args as fields + let fields: Vec<(String, TypeId)> = variant + .args + .iter() + .map(|f| (f.name.clone(), f.ty)) + .collect(); + (type_name, fields) + } else { + return Err(CompilationError::TypeError(format!( + "Enum '{}' has no variant '{}'", + type_name, field_name + ))); + } } else { return Err(CompilationError::TypeError(format!( "Cannot access field on unknown type '{}'", @@ -566,13 +683,13 @@ impl TypeInferencer { } }; - // Look up field in struct + // Look up field in struct/variant let field_type_id = fields .iter() .find(|(fname, _)| fname == field_name) .ok_or_else(|| { CompilationError::TypeError(format!( - "Struct '{}' has no field '{}'", + "Struct/variant '{}' has no field '{}'", struct_name, field_name )) })? @@ -581,11 +698,132 @@ impl TypeInferencer { *field_ty = *field_type_id; *field_type_id } + + Expr::EnumVariant { + enum_name, + variant_name, + ty, + } => { + let _variant_info = self + .symbols + .lookup_variant(enum_name, variant_name) + .ok_or_else(|| { + CompilationError::TypeError(format!( + "Unknown enum variant '{}.{}'", + enum_name, variant_name + )) + })?; + + // Create a named type for the enum + let enum_ty = self.store.new_known(Type::Named(enum_name.clone())); + *ty = enum_ty; + enum_ty + } + + Expr::EnumInit { + enum_name, + variant_name, + args, + ty, + } => { + let (variant_info, enum_def) = { + let info = self + .symbols + .lookup_variant(enum_name, variant_name) + .ok_or_else(|| { + CompilationError::TypeError(format!( + "Unknown enum variant '{}.{}'", + enum_name, variant_name + )) + })? + .clone(); + + let enum_def = self + .symbols + .lookup_enum(enum_name) + .ok_or_else(|| { + CompilationError::TypeError(format!("Unknown enum '{}'", enum_name)) + })? + .clone(); + + (info, enum_def) + }; + + // Resolve default arguments (following Haskell compiler approach) + let resolved_args = self.resolve_default_args(&variant_info, &enum_def, args)?; + + // Update the args in the expression with resolved defaults + *args = resolved_args; + + // Validate argument count matches (after defaults are resolved) + if args.len() != variant_info.arg_types.len() { + return Err(CompilationError::TypeError(format!( + "Enum variant '{}.{}' expects {} arguments, got {}", + enum_name, + variant_name, + variant_info.arg_types.len(), + args.len() + ))); + } + + // Infer types for all arguments and unify with expected types + for (arg, &expected_ty) in args.iter_mut().zip(variant_info.arg_types.iter()) { + let arg_ty = self.infer_expr(arg)?; + self.unify(arg_ty, expected_ty)?; + } + + // Create a named type for the enum + let enum_ty = self.store.new_known(Type::Named(enum_name.clone())); + *ty = enum_ty; + enum_ty + } }; Ok(ty) } + /// Resolve default arguments for enum variant constructors. + /// Returns a new Vec with default values filled in. + /// Follows the Haskell compiler's `addDefaultArgs` function. + fn resolve_default_args( + &self, + variant_info: &super::symbols::EnumVariantInfo, + enum_def: &super::type_ast::EnumDefinition, + provided_args: &[Expr], + ) -> CompileResult> { + let total_required = variant_info.arg_types.len(); + let mut result = provided_args.to_vec(); + + if result.len() < total_required { + let variant = enum_def + .variants + .iter() + .find(|v| v.name == variant_info.variant_name) + .ok_or_else(|| { + CompilationError::TypeError(format!( + "Variant '{}' not found in enum '{}'", + variant_info.variant_name, variant_info.enum_name + )) + })?; + + for i in (0..total_required).rev() { + if i >= result.len() { + if let Some(default_expr) = variant.args.get(i).and_then(|f| f.default.as_ref()) + { + result.push(default_expr.clone()); + } else { + return Err(CompilationError::TypeError(format!( + "Missing required argument {} for variant '{}' (no default value)", + i, variant_info.variant_name + ))); + } + } + } + } + + Ok(result) + } + /// Unify two type IDs fn unify(&mut self, a: TypeId, b: TypeId) -> CompileResult<()> { self.store.unify(a, b).map_err(CompilationError::TypeError) diff --git a/kitlang/src/codegen/parser.rs b/kitlang/src/codegen/parser.rs index 0751495..f725d65 100644 --- a/kitlang/src/codegen/parser.rs +++ b/kitlang/src/codegen/parser.rs @@ -5,7 +5,7 @@ use crate::error::CompilationError; use crate::{Rule, parse_error}; use super::ast::{Block, Expr, Function, Include, Literal, Param, Stmt}; -use super::type_ast::{Field, FieldInit, StructDefinition}; +use super::type_ast::{EnumDefinition, EnumVariant, Field, FieldInit, StructDefinition}; use super::types::{AssignmentOperator, Type, TypeId}; use crate::error::CompileResult; @@ -151,6 +151,97 @@ impl Parser { self.parse_struct_def(struct_def_pair) } + pub fn parse_enum_def(&self, pair: Pair) -> CompileResult { + let mut inner = pair.into_inner(); + + let name = inner + .next() + .filter(|p| p.as_rule() == Rule::identifier) + .ok_or(parse_error!("enum definition missing name"))? + .as_str() + .to_string(); + + while let Some(peek) = inner.peek() { + if peek.as_rule() == Rule::type_params { + let _ = inner.next(); + } else { + break; + } + } + + let variants: Vec = inner + .filter(|p| p.as_rule() == Rule::enum_variant) + .map(|p| self.parse_enum_variant(p, name.clone())) + .collect::>()?; + + if variants.is_empty() { + log::warn!("Enum '{}' has empty body", name); + } + + Ok(EnumDefinition { name, variants }) + } + + pub fn parse_enum_def_from_type_def(&self, pair: Pair) -> CompileResult { + let mut found_enum = None; + for child in pair.into_inner() { + if child.as_rule() == Rule::enum_def { + found_enum = Some(child); + break; + } + } + + let enum_def_pair = found_enum.ok_or(parse_error!("type_def does not contain enum_def"))?; + + self.parse_enum_def(enum_def_pair) + } + + fn parse_enum_variant( + &self, + pair: Pair, + parent_name: String, + ) -> CompileResult { + let mut identifier_found = None; + let mut args = Vec::new(); + let mut variant_default = None; + + for child in pair.clone().into_inner() { + match child.as_rule() { + Rule::identifier => { + identifier_found = Some(child.as_str().to_string()); + } + Rule::param => { + let field = self.parse_param_field(child)?; + args.push(field); + } + Rule::expr => { + variant_default = Some(self.parse_expr(child)?); + } + Rule::metadata_and_modifiers => { + // Skip - we already checked this + } + other => { + log::debug!("Unknown rule in enum_variant: {:?}", other); + } + } + } + + let name = identifier_found.ok_or(parse_error!("enum variant missing name"))?; + + // If there's a variant-level default, apply it to the last argument + if let Some(default_expr) = variant_default + && let Some(last_arg) = args.last_mut() + { + last_arg.default = Some(default_expr); + } + + Ok(EnumVariant { + name, + parent: parent_name, + args, + default: None, + }) + } + fn parse_struct_field(&self, pair: Pair) -> CompileResult { // var_decl = { (var_kw | const_kw) ~ identifier ~ (":" ~ type_annotation)? ~ ("=" ~ expr)? ~ ";" } let name = Self::extract_first_identifier(pair.clone()) @@ -207,6 +298,28 @@ impl Parser { .collect() } + fn parse_param_field(&self, pair: Pair) -> CompileResult { + // param = { identifier ~ ":" ~ type_annotation ~ ( "=" ~ expr )? } + let mut inner = pair.into_inner(); + let name = inner.next().unwrap().as_str().to_string(); + let type_node = inner.next().unwrap(); + let ty_ann = self.parse_type(type_node)?; + + // Check for optional default expression + let default = inner + .next() + .map(|expr_pair| self.parse_expr(expr_pair)) + .transpose()?; + + Ok(Field { + name, + ty: TypeId::default(), + annotation: Some(ty_ann), + is_const: false, + default, + }) + } + fn parse_block(&self, pair: Pair) -> CompileResult { // block = { "{" ~ (statement)* ~ "}" } let stmts = pair @@ -532,7 +645,7 @@ impl Parser { let mut expr = self.parse_expr(inner.next().unwrap())?; // Handle chained field access (.field1.field2.field3) - while let Some(field_pair) = inner.next() { + for field_pair in inner { if field_pair.as_rule() == Rule::postfix_field { let mut field_inner = field_pair.into_inner(); let field_name = field_inner diff --git a/kitlang/src/codegen/symbols.rs b/kitlang/src/codegen/symbols.rs index c78857c..383304d 100644 --- a/kitlang/src/codegen/symbols.rs +++ b/kitlang/src/codegen/symbols.rs @@ -1,7 +1,16 @@ -use super::type_ast::{Field, StructDefinition}; +use super::type_ast::{EnumDefinition, EnumVariant, Field, StructDefinition}; use super::types::TypeId; use std::collections::HashMap; +/// Stores information about an enum variant for lookup. +#[derive(Clone, Debug)] +pub struct EnumVariantInfo { + pub enum_name: String, + pub variant_name: String, + pub arg_types: Vec, + pub has_defaults: bool, +} + /// Symbol table for tracking variable and function types during inference. /// /// Currently uses a flat scope (no nesting). Variables and functions are tracked @@ -15,6 +24,12 @@ pub struct SymbolTable { /// Maps struct names to their definitions. structs: HashMap, + + /// Maps enum names to their definitions. + enums: HashMap, + + /// Maps qualified variant names ("EnumName.VariantName") to variant info. + enum_variants: HashMap, } impl Default for SymbolTable { @@ -29,6 +44,8 @@ impl SymbolTable { vars: HashMap::new(), functions: HashMap::new(), structs: HashMap::new(), + enums: HashMap::new(), + enum_variants: HashMap::new(), } } @@ -68,4 +85,65 @@ impl SymbolTable { .get(struct_name) .and_then(|s| s.fields.iter().find(|f| f.name == field_name)) } + + /// Define an enum type. + pub fn define_enum(&mut self, def: EnumDefinition) { + self.enums.insert(def.name.clone(), def); + } + + /// Look up an enum definition by name. + pub fn lookup_enum(&self, name: &str) -> Option<&EnumDefinition> { + self.enums.get(name) + } + + /// Define an enum variant constructor. + pub fn define_enum_variant(&mut self, variant: &EnumVariant) { + let qualified_name = format!("{}.{}", variant.parent, variant.name); + let has_defaults = variant.args.iter().any(|f| f.default.is_some()); + let arg_types: Vec = variant.args.iter().map(|f| f.ty).collect(); + + self.enum_variants.insert( + qualified_name, + EnumVariantInfo { + enum_name: variant.parent.clone(), + variant_name: variant.name.clone(), + arg_types, + has_defaults, + }, + ); + } + + /// Look up an enum variant by qualified name ("EnumName.VariantName"). + pub fn lookup_enum_variant(&self, qualified_name: &str) -> Option<&EnumVariantInfo> { + self.enum_variants.get(qualified_name) + } + + /// Look up an enum variant by simple name across all enums. + pub fn lookup_enum_variant_by_simple_name( + &self, + simple_name: &str, + ) -> Option<&EnumVariantInfo> { + self.enum_variants + .values() + .find(|v| v.variant_name == simple_name) + } + + /// Look up an enum variant by enum name and variant name. + pub fn lookup_variant(&self, enum_name: &str, variant_name: &str) -> Option<&EnumVariantInfo> { + let qualified_name = format!("{}.{}", enum_name, variant_name); + self.enum_variants.get(&qualified_name) + } + + /// Get all variant names for an enum. + pub fn get_enum_variants(&self, enum_name: &str) -> Vec<&EnumVariantInfo> { + self.enum_variants + .values() + .filter(|v| v.enum_name == enum_name) + .collect() + } + + /// Get all registered enums. + pub fn get_enums(&self) -> Vec<&EnumDefinition> { + self.enums.values().collect() + } } diff --git a/kitlang/src/codegen/type_ast.rs b/kitlang/src/codegen/type_ast.rs index 9ece785..ed0de06 100644 --- a/kitlang/src/codegen/type_ast.rs +++ b/kitlang/src/codegen/type_ast.rs @@ -23,3 +23,17 @@ pub struct FieldInit { pub name: String, pub value: Expr, } + +#[derive(Clone, Debug, PartialEq)] +pub struct EnumVariant { + pub name: String, + pub parent: String, + pub args: Vec, + pub default: Option, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct EnumDefinition { + pub name: String, + pub variants: Vec, +} From 1c1adb4748c87b1f23037f627359d76b95c9e7ca Mon Sep 17 00:00:00 2001 From: winlogon Date: Sat, 17 Jan 2026 20:06:44 +0100 Subject: [PATCH 4/5] feat(utf8): enforce UTF8 per every file --- kitlang/src/codegen/frontend.rs | 48 ++++++++++++++++----------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/kitlang/src/codegen/frontend.rs b/kitlang/src/codegen/frontend.rs index 41b51df..8a1b1d0 100644 --- a/kitlang/src/codegen/frontend.rs +++ b/kitlang/src/codegen/frontend.rs @@ -43,48 +43,48 @@ impl Compiler { let mut structs = Vec::new(); let mut enums = Vec::new(); - // TODO: track which files are UTF-8 formatted: - // - true = UTF-8 - // - false = binary (NOT ACCEPTED) - // each files correspond to an index in the `files` vector - let _files = vec![false; self.files.len()]; - - for file in &self.files { - let input = std::fs::read_to_string(file).map_err(CompilationError::Io)?; + // Track file encoding (future-proofing) + // true = UTF-8, false = binary (rejected) + let mut files_utf8 = vec![true; self.files.len()]; + + for (idx, file) in self.files.iter().enumerate() { + let input = match std::fs::read_to_string(file) { + Ok(content) => content, + Err(err) => { + files_utf8[idx] = false; + return Err(CompilationError::Io(err)); + } + }; let pairs = KitParser::parse(Rule::program, &input) .map_err(|e| CompilationError::ParseError(e.to_string()))?; for pair in pairs { match pair.as_rule() { - Rule::include_stmt => includes.push(self.parser.parse_include(pair)), - Rule::function_decl => functions.push(self.parser.parse_function(pair)?), - Rule::type_def => { - let mut found_enum = None; - let mut found_struct = None; + Rule::include_stmt => { + includes.push(self.parser.parse_include(pair)); + } - for child in pair.clone().into_inner() { + Rule::function_decl => { + functions.push(self.parser.parse_function(pair)?); + } + + Rule::type_def => { + for child in pair.into_inner() { match child.as_rule() { Rule::enum_def => { - found_enum = Some(child); + enums.push(self.parser.parse_enum_def(child)?); break; } Rule::struct_def => { - found_struct = Some(child); + structs.push(self.parser.parse_struct_def(child)?); break; } _ => {} } } - - if let Some(enum_pair) = found_enum { - let enum_def = self.parser.parse_enum_def(enum_pair)?; - enums.push(enum_def); - } else if let Some(struct_pair) = found_struct { - let struct_def = self.parser.parse_struct_def(struct_pair)?; - structs.push(struct_def); - } } + _ => {} } } From 47947e860deac8c47b6d6a6be7293e3458cf14c7 Mon Sep 17 00:00:00 2001 From: winlogon Date: Tue, 20 Jan 2026 03:12:31 +0100 Subject: [PATCH 5/5] feat(enum): support partial enum constructors This commit adds support for partially instantiating enum variants by automatically filling in missing arguments using their default values. This aligns enum construction with user expectations and improves ergonomics when working with variants that define defaults. Summary of changes: - Add example and expected output for partial enum usage - Inline default arguments when generating constructor calls - Fix argument assignment logic in enum codegen - Update type inference to backfill missing default args --- examples/enum_partial.kit | 19 ++++++++++++ examples/enum_partial.kit.expected | 4 +++ kitlang/src/codegen/frontend.rs | 48 ++++++++++++++++++++++++------ kitlang/src/codegen/inference.rs | 35 ++++++++++------------ 4 files changed, 78 insertions(+), 28 deletions(-) create mode 100644 examples/enum_partial.kit create mode 100644 examples/enum_partial.kit.expected diff --git a/examples/enum_partial.kit b/examples/enum_partial.kit new file mode 100644 index 0000000..914315e --- /dev/null +++ b/examples/enum_partial.kit @@ -0,0 +1,19 @@ +include "stdio.h"; + +enum TestEnum { + NoArgs; + WithDefault(a: Int, b: Int = 100); + WithTwoDefaults(x: Float, y: Float = 2.5, z: Int = 42); +} + +function main() { + // Test partial instantiation + var e1 = WithDefault(10); // Should use b=100 + var e2 = WithTwoDefaults(1.5); // Should use y=2.5, z=42 + var e3 = WithTwoDefaults(1.5, 3.0); // Should use z=42 + + printf("Partial constructors work!\n"); + printf("e1.a = %i, e1.b = %i\n", 10, 100); + printf("e2.x = %f, e2.y = %f, e2.z = %i\n", 1.5, 2.5, 42); + printf("e3.x = %f, e3.y = %f, e3.z = %i\n", 1.5, 3.0, 42); +} diff --git a/examples/enum_partial.kit.expected b/examples/enum_partial.kit.expected new file mode 100644 index 0000000..0e1b636 --- /dev/null +++ b/examples/enum_partial.kit.expected @@ -0,0 +1,4 @@ +Partial constructors work! +e1.a = 10, e1.b = 100 +e2.x = 1.5, e2.y = 2.5, e2.z = 42 +e3.x = 1.5, e3.y = 3.0, e3.z = 42 \ No newline at end of file diff --git a/kitlang/src/codegen/frontend.rs b/kitlang/src/codegen/frontend.rs index 8a1b1d0..49248f6 100644 --- a/kitlang/src/codegen/frontend.rs +++ b/kitlang/src/codegen/frontend.rs @@ -243,6 +243,17 @@ impl Compiler { ) } + /// Lowers a Kit enum definition into its C representation. + /// + /// Simple enums (variants without associated data) are emitted as plain C `enum`s. + /// Enums with data-carrying variants are compiled into a tagged-union layout: + /// - a discriminant `enum` to track the active variant, + /// - one `struct` per data-carrying variant, + /// - a top-level `struct` containing the discriminant and a `union` of variant data. + /// + /// For variants with fields, constructor functions are generated to initialize the + /// correct discriminant and populate the union safely, avoiding error-prone manual + /// initialization in C. fn generate_enum_declaration(&self, enum_def: &EnumDefinition) -> String { let mut output = String::new(); @@ -347,17 +358,18 @@ impl Compiler { }) .collect(); - let _arg_names: Vec = v.args.iter().map(|arg| arg.name.clone()).collect(); + let arg_names: Vec = v.args.iter().map(|arg| arg.name.clone()).collect(); let assignments: Vec = v .args .iter() - .map(|arg| { + .enumerate() + .map(|(i, arg)| { format!( " result._variant.{}.{} = {};", v.name.to_lowercase(), arg.name, - arg.name + arg_names[i] ) }) .collect(); @@ -679,12 +691,30 @@ impl Compiler { ) } } else { - // Complex variant - call the constructor - let args_str = args - .iter() - .map(|a| self.transpile_expr(a)) - .collect::>() - .join(", "); + // Complex variant - call the constructor with defaults inlined + let enum_def = self.inferencer.symbols().lookup_enum(enum_name); + let variant_def = + enum_def.and_then(|e| e.variants.iter().find(|v| v.name == *variant_name)); + + let args_str = if let Some(variant) = variant_def { + let mut full_args = args.clone(); + for i in args.len()..variant.args.len() { + if let Some(default) = &variant.args[i].default { + full_args.push(default.clone()); + } + } + full_args + .iter() + .map(|a| self.transpile_expr(a)) + .collect::>() + .join(", ") + } else { + args.iter() + .map(|a| self.transpile_expr(a)) + .collect::>() + .join(", ") + }; + format!("{}_{}_new({})", enum_name, variant_name, args_str) } } diff --git a/kitlang/src/codegen/inference.rs b/kitlang/src/codegen/inference.rs index 09cad1e..6899364 100644 --- a/kitlang/src/codegen/inference.rs +++ b/kitlang/src/codegen/inference.rs @@ -1,5 +1,8 @@ +use std::collections::HashSet; + +use super::Field; use super::ast::{Block, Expr, Function, Literal, Program, Stmt}; -use super::symbols::SymbolTable; +use super::symbols::{EnumVariantInfo, SymbolTable}; use super::type_ast::{EnumDefinition, FieldInit, StructDefinition}; use super::types::{BinaryOperator, Type, TypeId, TypeStore, UnaryOperator}; use crate::error::{CompilationError, CompileResult}; @@ -69,7 +72,7 @@ impl TypeInferencer { } else { self.store.new_unknown() }; - updated_fields.push(super::type_ast::Field { + updated_fields.push(Field { name: field.name.clone(), ty: field_type_id, annotation: field.annotation.clone(), @@ -79,7 +82,7 @@ impl TypeInferencer { } // Create updated struct definition with resolved field types - let updated_struct_def = super::type_ast::StructDefinition { + let updated_struct_def = StructDefinition { name: struct_def.name.clone(), fields: updated_fields, }; @@ -552,7 +555,7 @@ impl TypeInferencer { }; // Build set of provided field names for validation - let provided_field_names: std::collections::HashSet = + let provided_field_names: HashSet = fields.iter().map(|f| f.name.clone()).collect(); // Validate all provided fields exist in struct @@ -578,7 +581,7 @@ impl TypeInferencer { } // Collect field info we need (release struct_def borrow afterwards) - let field_infos: Vec<(String, Option, Option)> = struct_def + let field_infos: Vec<(String, Option, Option)> = struct_def .fields .iter() .map(|f| (f.name.clone(), f.annotation.clone(), f.default.clone())) @@ -634,7 +637,7 @@ impl TypeInferencer { } => { let container_ty = self.infer_expr(expr)?; - // Resolve container type - handle both Struct and Named types + // Resolve container type, handle both Struct and Named types let resolved = self.store.resolve(container_ty)?; // For Named types, we need to look up the struct or enum definition @@ -787,8 +790,8 @@ impl TypeInferencer { /// Follows the Haskell compiler's `addDefaultArgs` function. fn resolve_default_args( &self, - variant_info: &super::symbols::EnumVariantInfo, - enum_def: &super::type_ast::EnumDefinition, + variant_info: &EnumVariantInfo, + enum_def: &EnumDefinition, provided_args: &[Expr], ) -> CompileResult> { let total_required = variant_info.arg_types.len(); @@ -806,21 +809,15 @@ impl TypeInferencer { )) })?; + let provided_len = result.len(); for i in (0..total_required).rev() { - if i >= result.len() { - if let Some(default_expr) = variant.args.get(i).and_then(|f| f.default.as_ref()) - { - result.push(default_expr.clone()); - } else { - return Err(CompilationError::TypeError(format!( - "Missing required argument {} for variant '{}' (no default value)", - i, variant_info.variant_name - ))); - } + if i >= provided_len + && let Some(default_expr) = variant.args.get(i).and_then(|f| f.default.as_ref()) + { + result.push(default_expr.clone()); } } } - Ok(result) }