diff --git a/.gitignore b/.gitignore index 896acd8f..f9bd3b06 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ Cargo.lock *.wat out test_data/**/*.json +.* \ No newline at end of file diff --git a/bin/linux/inf-llc b/bin/linux/inf-llc new file mode 100755 index 00000000..c7ff86c0 Binary files /dev/null and b/bin/linux/inf-llc differ diff --git a/bin/linux/rust-lld b/bin/linux/rust-lld new file mode 100755 index 00000000..0bc981a7 Binary files /dev/null and b/bin/linux/rust-lld differ diff --git a/core/ast/src/builder.rs b/core/ast/src/builder.rs index 591042ff..ce83aeb4 100644 --- a/core/ast/src/builder.rs +++ b/core/ast/src/builder.rs @@ -5,6 +5,7 @@ use crate::nodes::{ TypeMemberAccessExpression, }; use crate::type_infer::TypeChecker; +use crate::type_info::TypeInfo; use crate::{ arena::Arena, nodes::{ @@ -111,10 +112,12 @@ impl<'a> Builder<'a, InitState> { let mut type_checker = TypeChecker::new(); let _ = type_checker.infer_types(&mut res); - // let symbol_table = SymbolTable::build(&res, &self.types, &self.arena); + // let mut type_checker = TypeChecker::new(); let t_ast = TypedAst::new(res, self.arena.clone()); + t_ast.infer_expression_types(); // run type inference over all expressions - // type_infer::traverse_source_files(&t_ast.source_files, &t_ast.symbol_table) + // type_checker + // .infer_types(&t_ast.source_files) // .map_err(|e| anyhow::Error::msg(format!("Type error: {e:?}")))?; Ok(Builder { arena: Arena::default(), @@ -1257,6 +1260,9 @@ impl<'a> Builder<'a, InitState> { node.utf8_text(code).unwrap().to_string() }; let node = Rc::new(SimpleType::new(id, location, name)); + node.type_info + .borrow_mut() + .replace(TypeInfo::new(&Type::Simple(node.clone()))); self.arena.add_node( AstNode::Expression(Expression::Type(Type::Simple(node.clone()))), parent_id, diff --git a/core/ast/src/nodes_impl.rs b/core/ast/src/nodes_impl.rs index b5e9339c..6655fa37 100644 --- a/core/ast/src/nodes_impl.rs +++ b/core/ast/src/nodes_impl.rs @@ -4,7 +4,7 @@ use crate::{ nodes::{ ArgumentType, IgnoreArgument, SelfReference, StructExpression, TypeMemberAccessExpression, }, - type_info::{TypeInfo, TypeInfoKind}, + type_info::{NumberTypeKindNumberType, TypeInfo, TypeInfoKind}, }; use super::nodes::{ @@ -120,6 +120,63 @@ impl BlockType { | BlockType::Unique(block) => block.statements.clone(), } } + #[must_use] + pub fn is_non_det(&self) -> bool { + match self { + BlockType::Block(block) => block + .statements + .iter() + .any(super::nodes::Statement::is_non_det), + _ => true, + } + } + #[must_use] + pub fn is_void(&self) -> bool { + let fn_find_ret_stmt = |statements: &Vec| -> bool { + for stmt in statements { + match stmt { + Statement::Return(_) => return true, + Statement::Block(block_type) => { + if block_type.is_void() { + return true; + } + } + _ => {} + } + } + false + }; + !fn_find_ret_stmt(&self.statements()) + } +} + +impl Statement { + #[must_use] + pub fn is_non_det(&self) -> bool { + match self { + Statement::Block(block_type) => !matches!(block_type, BlockType::Block(_)), + Statement::Expression(expr_stmt) => expr_stmt.is_non_det(), + Statement::Return(ret_stmt) => ret_stmt.expression.borrow().is_non_det(), + Statement::Loop(loop_stmt) => loop_stmt + .condition + .borrow() + .as_ref() + .is_some_and(super::nodes::Expression::is_non_det), + Statement::If(if_stmt) => { + if_stmt.condition.borrow().is_non_det() + || if_stmt.if_arm.is_non_det() + || if_stmt + .else_arm + .as_ref() + .is_some_and(super::nodes::BlockType::is_non_det) + } + Statement::VariableDefinition(var_def) => var_def + .value + .as_ref() + .is_some_and(|value| value.borrow().is_non_det()), + _ => false, + } + } } impl Expression { @@ -140,6 +197,10 @@ impl Expression { Expression::Uzumaki(e) => e.type_info.borrow().clone(), } } + #[must_use] + pub fn is_non_det(&self) -> bool { + matches!(self, Expression::Uzumaki(_)) + } } impl Literal { @@ -345,6 +406,11 @@ impl FunctionDefinition { .as_ref() .is_none_or(super::nodes::Type::is_unit_type) } + + #[must_use] + pub fn is_non_det(&self) -> bool { + self.body.is_non_det() + } } impl ExternalFunctionDefinition { @@ -690,6 +756,26 @@ impl UzumakiExpression { type_info: RefCell::new(None), } } + #[must_use] + pub fn is_i32(&self) -> bool { + if let Some(type_info) = self.type_info.borrow().as_ref() { + return matches!( + type_info.kind, + TypeInfoKind::Number(NumberTypeKindNumberType::I32) + ); + } + false + } + #[must_use] + pub fn is_i64(&self) -> bool { + if let Some(type_info) = self.type_info.borrow().as_ref() { + return matches!( + type_info.kind, + TypeInfoKind::Number(NumberTypeKindNumberType::I64) + ); + } + false + } } impl AssertStatement { diff --git a/core/ast/src/t_ast.rs b/core/ast/src/t_ast.rs index c471a1ae..0caab902 100644 --- a/core/ast/src/t_ast.rs +++ b/core/ast/src/t_ast.rs @@ -1,6 +1,7 @@ use crate::{ arena::Arena, - nodes::{AstNode, SourceFile}, + nodes::{AstNode, Definition, Expression, SourceFile, Statement}, + type_info::TypeInfo, }; #[derive(Clone, Default)] @@ -26,4 +27,35 @@ impl TypedAst { .cloned() .collect() } + + pub fn infer_expression_types(&self) { + //FIXME: very hacky way to infer Uzumaki expression types in return statements + for function_def_node in + self.filter_nodes(|node| matches!(node, AstNode::Definition(Definition::Function(_)))) + { + let AstNode::Definition(Definition::Function(function_def)) = function_def_node else { + unreachable!() + }; + if function_def.is_void() { + continue; + } + if let Some(Statement::Return(last_stmt)) = function_def.body.statements().last() { + if !matches!(*last_stmt.expression.borrow(), Expression::Uzumaki(_)) { + continue; + } + + match &*last_stmt.expression.borrow() { + Expression::Uzumaki(expr) => { + if expr.type_info.borrow().is_some() { + continue; + } + if let Some(return_type) = &function_def.returns { + expr.type_info.replace(Some(TypeInfo::new(return_type))); + } + } + _ => unreachable!(), + } + } + } + } } diff --git a/core/wasm-codegen/Cargo.toml b/core/wasm-codegen/Cargo.toml index 0842c811..4e9ec557 100644 --- a/core/wasm-codegen/Cargo.toml +++ b/core/wasm-codegen/Cargo.toml @@ -7,9 +7,8 @@ homepage = { workspace = true } repository = { workspace = true } [dependencies] -wasm-encoder="0.240.0" +inkwell = { version = "0.7.1", features = ["llvm21-1"] } +tempfile = "3.3.0" +which = "8.0.0" inference-ast.workspace = true anyhow.workspace = true - -[dev-dependencies] -wasmtime="38.0.4" diff --git a/core/wasm-codegen/build.rs b/core/wasm-codegen/build.rs new file mode 100644 index 00000000..68e51f92 --- /dev/null +++ b/core/wasm-codegen/build.rs @@ -0,0 +1,109 @@ +use std::env; +use std::fs; +use std::path::PathBuf; + +fn main() { + let platform = if cfg!(target_os = "linux") { + "linux" + } else if cfg!(target_os = "macos") { + "macos" + } else if cfg!(target_os = "windows") { + "windows" + } else { + panic!("Unsupported platform"); + }; + + let exe_suffix = std::env::consts::EXE_SUFFIX; + let llc_binary = format!("inf-llc{exe_suffix}"); + let rust_lld_binary = format!("rust-lld{exe_suffix}"); + + let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); + let workspace_root = manifest_dir + .parent() // core/ + .and_then(|p| p.parent()) // workspace root + .expect("Failed to determine workspace root"); + + let source_llc = workspace_root.join("bin").join(platform).join(&llc_binary); + let source_rust_lld = workspace_root + .join("bin") + .join(platform) + .join(&rust_lld_binary); + + let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); + let target_profile_dir = out_dir + .parent() // build/- + .and_then(|p| p.parent()) // build/ + .and_then(|p| p.parent()) // target// + .expect("Failed to determine target profile directory"); + + let bin_dir = target_profile_dir.join("bin"); + let dest_llc = bin_dir.join(&llc_binary); + let dest_rust_lld = bin_dir.join(&rust_lld_binary); + + if source_llc.exists() { + if !bin_dir.exists() { + fs::create_dir_all(&bin_dir).expect("Failed to create bin directory"); + } + + fs::copy(&source_llc, &dest_llc).unwrap_or_else(|e| { + panic!( + "Failed to copy inf-llc from {} to {}: {}", + source_llc.display(), + dest_llc.display(), + e + ) + }); + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let mut perms = fs::metadata(&dest_llc) + .expect("Failed to read inf-llc metadata") + .permissions(); + perms.set_mode(0o755); + fs::set_permissions(&dest_llc, perms).expect("Failed to set executable permissions"); + } + + println!("cargo:info=Copied inf-llc to {}", dest_llc.display()); + } else { + println!( + "cargo:info=inf-llc not found at {}, skipping copy", + source_llc.display() + ); + } + + if source_rust_lld.exists() { + if !bin_dir.exists() { + fs::create_dir_all(&bin_dir).expect("Failed to create bin directory"); + } + + fs::copy(&source_rust_lld, &dest_rust_lld).unwrap_or_else(|e| { + panic!( + "Failed to copy rust-lld from {} to {}: {}", + source_rust_lld.display(), + dest_rust_lld.display(), + e + ) + }); + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let mut perms = fs::metadata(&dest_rust_lld) + .expect("Failed to read rust-lld metadata") + .permissions(); + perms.set_mode(0o755); + fs::set_permissions(&dest_rust_lld, perms) + .expect("Failed to set executable permissions"); + } + + println!("cargo:info=Copied rust-lld to {}", dest_rust_lld.display()); + } else { + println!( + "cargo:info=rust-lld not found at {}, skipping copy", + source_rust_lld.display() + ); + } + + println!("cargo:rerun-if-changed={}", source_llc.display()); +} diff --git a/core/wasm-codegen/src/compiler.rs b/core/wasm-codegen/src/compiler.rs new file mode 100644 index 00000000..8869340b --- /dev/null +++ b/core/wasm-codegen/src/compiler.rs @@ -0,0 +1,468 @@ +//TODO: don't forget to remove +#![allow(dead_code)] +use crate::utils; +use inference_ast::{ + nodes::{BlockType, Expression, FunctionDefinition, Literal, Statement, Type}, + type_info::{NumberTypeKindNumberType, TypeInfoKind}, +}; +use inkwell::{ + attributes::{Attribute, AttributeLoc}, + builder::Builder, + context::Context, + module::Module, + types::BasicTypeEnum, + values::{FunctionValue, PointerValue}, +}; +use std::{cell::RefCell, collections::HashMap, iter::Peekable, rc::Rc}; + +const UZUMAKI_I32_INTRINSIC: &str = "llvm.wasm.uzumaki.i32"; +const UZUMAKI_I64_INTRINSIC: &str = "llvm.wasm.uzumaki.i64"; +const FORALL_START_INTRINSIC: &str = "llvm.wasm.forall.start"; +const FORALL_END_INTRINSIC: &str = "llvm.wasm.forall.end"; +const EXISTS_START_INTRINSIC: &str = "llvm.wasm.exists.start"; +const EXISTS_END_INTRINSIC: &str = "llvm.wasm.exists.end"; +const ASSUME_START_INTRINSIC: &str = "llvm.wasm.assume.start"; +const ASSUME_END_INTRINSIC: &str = "llvm.wasm.assume.end"; +const UNIQUE_START_INTRINSIC: &str = "llvm.wasm.unique.start"; +const UNIQUE_END_INTRINSIC: &str = "llvm.wasm.unique.end"; + +pub(crate) struct Compiler<'ctx> { + context: &'ctx Context, + module: Module<'ctx>, + builder: Builder<'ctx>, + variables: RefCell, BasicTypeEnum<'ctx>)>>, +} + +impl<'ctx> Compiler<'ctx> { + pub(crate) fn new(context: &'ctx Context, module_name: &str) -> Self { + let module = context.create_module(module_name); + let builder = context.create_builder(); + + Self { + context, + module, + builder, + variables: RefCell::new(HashMap::new()), + } + } + + fn add_optimization_barriers(&self, function: FunctionValue<'ctx>) { + let attr_kind_optnone = Attribute::get_named_enum_kind_id("optnone"); + let attr_kind_noinline = Attribute::get_named_enum_kind_id("noinline"); + + let optnone = self.context.create_enum_attribute(attr_kind_optnone, 0); + let noinline = self.context.create_enum_attribute(attr_kind_noinline, 0); + + function.add_attribute(AttributeLoc::Function, optnone); + function.add_attribute(AttributeLoc::Function, noinline); + } + + pub(crate) fn visit_function_definition(&self, function_definition: &Rc) { + let fn_name = function_definition.name(); + let fn_type = match &function_definition.returns { + Some(ret_type) => match ret_type { + Type::Array(_array_type) => todo!(), + Type::Simple(simple_type) => match simple_type.name.to_lowercase().as_str() { + "i32" => self.context.i32_type().fn_type(&[], false), + "i64" => self.context.i64_type().fn_type(&[], false), + "u32" => todo!(), + "u64" => todo!(), + _ => panic!("Unsupported return type: {}", simple_type.name), + }, + Type::Generic(_generic_type) => todo!(), + Type::Function(_function_type) => todo!(), + Type::QualifiedName(_qualified_name) => todo!(), + Type::Qualified(_type_qualified_name) => todo!(), + Type::Custom(_identifier) => todo!(), + }, + None => self.context.void_type().fn_type(&[], false), + }; + let function = self.module.add_function(fn_name.as_str(), fn_type, None); + + let export_name_attr = self + .context + .create_string_attribute("wasm-export-name", fn_name.as_str()); + function.add_attribute(AttributeLoc::Function, export_name_attr); + if function_definition.is_non_det() { + self.add_optimization_barriers(function); + } + let entry = self.context.append_basic_block(function, "entry"); + self.builder.position_at_end(entry); + self.lower_statement( + std::iter::once(Statement::Block(function_definition.body.clone())).peekable(), + &mut vec![function_definition.body.clone()], + ); + if function_definition.is_void() { + self.builder.build_return(None).unwrap(); + } + } + + #[allow(clippy::too_many_lines)] + fn lower_statement>( + &self, + mut statements_iterator: Peekable, + parent_blocks_stack: &mut Vec, + ) { + let statement = statements_iterator.next().unwrap(); + match statement { + Statement::Block(block_type) => match block_type { + BlockType::Block(block) => { + parent_blocks_stack.push(BlockType::Block(block.clone())); + for stmt in block.statements.clone() { + self.lower_statement(std::iter::once(stmt).peekable(), parent_blocks_stack); + } + parent_blocks_stack.pop(); + } + BlockType::Forall(forall_block) => { + let forall_start = self.forall_start_intrinsic(); + self.builder + .build_call(forall_start, &[], "") + .expect("Failed to build forall intrinsic call"); + parent_blocks_stack.push(BlockType::Forall(forall_block.clone())); + for stmt in forall_block.statements.clone() { + self.lower_statement(std::iter::once(stmt).peekable(), parent_blocks_stack); + } + let forall_end = self.forall_end_intrinsic(); + self.builder + .build_call(forall_end, &[], "") + .expect("Failed to build forall end intrinsic call"); + parent_blocks_stack.pop(); + } + BlockType::Assume(assume_block) => { + let assume_start = self.assume_start_intrinsic(); + self.builder + .build_call(assume_start, &[], "") + .expect("Failed to build assume intrinsic call"); + parent_blocks_stack.push(BlockType::Assume(assume_block.clone())); + for stmt in assume_block.statements.clone() { + self.lower_statement(std::iter::once(stmt).peekable(), parent_blocks_stack); + } + let assume_end = self.assume_end_intrinsic(); + self.builder + .build_call(assume_end, &[], "") + .expect("Failed to build assume end intrinsic call"); + parent_blocks_stack.pop(); + } + BlockType::Exists(exists_block) => { + let exists_start = self.exists_start_intrinsic(); + self.builder + .build_call(exists_start, &[], "") + .expect("Failed to build exists intrinsic call"); + parent_blocks_stack.push(BlockType::Exists(exists_block.clone())); + for stmt in exists_block.statements.clone() { + self.lower_statement(std::iter::once(stmt).peekable(), parent_blocks_stack); + } + let exists_end = self.exists_end_intrinsic(); + self.builder + .build_call(exists_end, &[], "") + .expect("Failed to build exists end intrinsic call"); + parent_blocks_stack.pop(); + } + BlockType::Unique(unique_block) => { + let unique_start = self.unique_start_intrinsic(); + self.builder + .build_call(unique_start, &[], "") + .expect("Failed to build unique intrinsic call"); + parent_blocks_stack.push(BlockType::Unique(unique_block.clone())); + for stmt in unique_block.statements.clone() { + self.lower_statement(std::iter::once(stmt).peekable(), parent_blocks_stack); + } + let unique_end = self.unique_end_intrinsic(); + self.builder + .build_call(unique_end, &[], "") + .expect("Failed to build unique end intrinsic call"); + parent_blocks_stack.pop(); + } + }, + Statement::Expression(expression) => { + let expr = self.lower_expression(&expression); + //FIXME: revisit this logic #45 + if statements_iterator.peek().is_none() + && parent_blocks_stack.first().unwrap().is_non_det() + && parent_blocks_stack.first().unwrap().is_void() + { + let local = self.builder.build_alloca(expr.get_type(), "temp").unwrap(); + self.builder.build_store(local, expr).unwrap(); + } + } + Statement::Assign(_assign_statement) => todo!(), + Statement::Return(return_statement) => { + let ret = self.lower_expression(&return_statement.expression.borrow()); + self.builder.build_return(Some(&ret)).unwrap(); + } + Statement::Loop(_loop_statement) => todo!(), + Statement::Break(_break_statement) => todo!(), + Statement::If(_if_statement) => todo!(), + Statement::VariableDefinition(_variable_definition_statement) => { + // let ctx_type = self.context.i32_type(); //TODO: support other types + // if let Some(value) = &variable_definition_statement.value { + // if matches!(*value.borrow(), Expression::Uzumaki(_)) + // || matches!(*value.borrow(), Expression::Literal(_)) + // { + // } else { + // todo!() + // } + // } + } + Statement::TypeDefinition(_type_definition_statement) => todo!(), + Statement::Assert(_assert_statement) => todo!(), + Statement::ConstantDefinition(constant_definition) => match &constant_definition.ty { + Type::Array(_type_array) => todo!(), + Type::Simple(simple_type) => { + match &simple_type + .type_info + .borrow() + .as_ref() + .expect("SimpleType should have type_info set") + .kind + { + TypeInfoKind::Unit => todo!(), + TypeInfoKind::Bool => todo!(), + TypeInfoKind::String => todo!(), + TypeInfoKind::Number(number_type_kind_number_type) => { + match number_type_kind_number_type { + NumberTypeKindNumberType::I8 => todo!(), + NumberTypeKindNumberType::I16 => todo!(), + NumberTypeKindNumberType::I32 => { + let ctx_type = self.context.i32_type(); + match &constant_definition.value { + Literal::Number(number_literal) => { + let val = ctx_type.const_int( + number_literal.value.parse::().unwrap_or(0), + false, + ); + let local = self + .builder + .build_alloca(ctx_type, &constant_definition.name()) + .unwrap(); + self.builder.build_store(local, val).unwrap(); + self.variables.borrow_mut().insert( + constant_definition.name(), + (local, ctx_type.into()), + ); + } + _ => panic!( + "Constant value for i32 should be a number literal. Found: {:?}", + constant_definition.value + ), + } + } + NumberTypeKindNumberType::I64 => todo!(), + NumberTypeKindNumberType::U8 => todo!(), + NumberTypeKindNumberType::U16 => todo!(), + NumberTypeKindNumberType::U32 => todo!(), + NumberTypeKindNumberType::U64 => todo!(), + } + } + TypeInfoKind::Custom(_) => todo!(), + TypeInfoKind::Array(_type_info, _) => todo!(), + TypeInfoKind::Generic(_) => todo!(), + TypeInfoKind::QualifiedName(_) => todo!(), + TypeInfoKind::Qualified(_) => todo!(), + TypeInfoKind::Function(_) => todo!(), + TypeInfoKind::Struct(_) => todo!(), + TypeInfoKind::Enum(_) => todo!(), + TypeInfoKind::Spec(_) => todo!(), + } + } + Type::Generic(_generic_type) => todo!(), + Type::Function(_function_type) => todo!(), + Type::QualifiedName(_qualified_name) => todo!(), + Type::Qualified(_type_qualified_name) => todo!(), + Type::Custom(_identifier) => todo!(), + }, + } + } + + fn lower_expression(&self, expression: &Expression) -> inkwell::values::IntValue<'ctx> { + match expression { + Expression::ArrayIndexAccess(_array_index_access_expression) => todo!(), + Expression::Binary(_binary_expression) => todo!(), + Expression::MemberAccess(_member_access_expression) => todo!(), + Expression::TypeMemberAccess(_type_member_access_expression) => todo!(), + Expression::FunctionCall(_function_call_expression) => todo!(), + Expression::Struct(_struct_expression) => todo!(), + Expression::PrefixUnary(_prefix_unary_expression) => todo!(), + Expression::Parenthesized(_parenthesized_expression) => todo!(), + Expression::Literal(literal) => self.lower_literal(literal), + Expression::Identifier(identifier) => { + let (ptr, ty) = self + .variables + .borrow() + .get(&identifier.name) + .copied() + .expect("Variable not found"); + self.builder + .build_load(ty, ptr, &identifier.name) + .unwrap() + .into_int_value() + } + Expression::Type(_) => todo!(), + Expression::Uzumaki(uzumaki_expression) => { + if uzumaki_expression.is_i32() { + return self.lower_uzumaki_i32_expression(); + } + if uzumaki_expression.is_i64() { + return self.lower_uzumaki_i64_expression(); + } + panic!("Unsupported Uzumaki expression type: {uzumaki_expression:?}"); + } + } + } + + fn lower_literal(&self, literal: &Literal) -> inkwell::values::IntValue<'ctx> { + match literal { + Literal::Array(_array_literal) => todo!(), + Literal::Bool(bool_literal) => self + .context + .i32_type() + .const_int(u64::from(bool_literal.value), false), + Literal::String(_string_literal) => todo!(), + Literal::Number(number_literal) => self + .context + .i32_type() + .const_int(number_literal.value.parse::().unwrap_or(0), false), + Literal::Unit(_unit_literal) => todo!(), + } + } + + fn lower_uzumaki_i32_expression(&self) -> inkwell::values::IntValue<'ctx> { + let uzumaki_i32_intr = self.uzumaki_i32_intrinsic(); + let call = self + .builder + .build_call(uzumaki_i32_intr, &[], "uz_i32") + .expect("Failed to build uzumaki_i32_intrinsic call"); + let call_kind = call.try_as_basic_value(); + let basic = call_kind.unwrap_basic(); + basic.into_int_value() + } + + fn lower_uzumaki_i64_expression(&self) -> inkwell::values::IntValue<'ctx> { + let uzumaki_i64_intr = self.uzumaki_i64_intrinsic(); + let call = self + .builder + .build_call(uzumaki_i64_intr, &[], "uz_i64") + .expect("Failed to build uzumaki_i64_intrinsic call"); + let call_kind = call.try_as_basic_value(); + let basic = call_kind.unwrap_basic(); + basic.into_int_value() + } + + fn uzumaki_i32_intrinsic(&self) -> FunctionValue<'ctx> { + let i32_type = self.context.i32_type(); + let fn_type = i32_type.fn_type(&[], false); + self.module + .get_function(UZUMAKI_I32_INTRINSIC) + .unwrap_or_else(|| { + self.module + .add_function(UZUMAKI_I32_INTRINSIC, fn_type, None) + }) + } + + fn uzumaki_i64_intrinsic(&self) -> FunctionValue<'ctx> { + let i64_type = self.context.i64_type(); + let fn_type = i64_type.fn_type(&[], false); + self.module + .get_function(UZUMAKI_I64_INTRINSIC) + .unwrap_or_else(|| { + self.module + .add_function(UZUMAKI_I64_INTRINSIC, fn_type, None) + }) + } + + fn forall_start_intrinsic(&self) -> FunctionValue<'ctx> { + let void_type = self.context.void_type(); + let fn_type = void_type.fn_type(&[], false); + self.module + .get_function(FORALL_START_INTRINSIC) + .unwrap_or_else(|| { + self.module + .add_function(FORALL_START_INTRINSIC, fn_type, None) + }) + } + + fn forall_end_intrinsic(&self) -> FunctionValue<'ctx> { + let void_type = self.context.void_type(); + let fn_type = void_type.fn_type(&[], false); + self.module + .get_function(FORALL_END_INTRINSIC) + .unwrap_or_else(|| { + self.module + .add_function(FORALL_END_INTRINSIC, fn_type, None) + }) + } + + fn exists_start_intrinsic(&self) -> FunctionValue<'ctx> { + let void_type = self.context.void_type(); + let fn_type = void_type.fn_type(&[], false); + self.module + .get_function(EXISTS_START_INTRINSIC) + .unwrap_or_else(|| { + self.module + .add_function(EXISTS_START_INTRINSIC, fn_type, None) + }) + } + + fn exists_end_intrinsic(&self) -> FunctionValue<'ctx> { + let void_type = self.context.void_type(); + let fn_type = void_type.fn_type(&[], false); + self.module + .get_function(EXISTS_END_INTRINSIC) + .unwrap_or_else(|| { + self.module + .add_function(EXISTS_END_INTRINSIC, fn_type, None) + }) + } + + fn assume_start_intrinsic(&self) -> FunctionValue<'ctx> { + let void_type = self.context.void_type(); + let fn_type = void_type.fn_type(&[], false); + self.module + .get_function(ASSUME_START_INTRINSIC) + .unwrap_or_else(|| { + self.module + .add_function(ASSUME_START_INTRINSIC, fn_type, None) + }) + } + + fn assume_end_intrinsic(&self) -> FunctionValue<'ctx> { + let void_type = self.context.void_type(); + let fn_type = void_type.fn_type(&[], false); + self.module + .get_function(ASSUME_END_INTRINSIC) + .unwrap_or_else(|| { + self.module + .add_function(ASSUME_END_INTRINSIC, fn_type, None) + }) + } + + fn unique_start_intrinsic(&self) -> FunctionValue<'ctx> { + let void_type = self.context.void_type(); + let fn_type = void_type.fn_type(&[], false); + self.module + .get_function(UNIQUE_START_INTRINSIC) + .unwrap_or_else(|| { + self.module + .add_function(UNIQUE_START_INTRINSIC, fn_type, None) + }) + } + + fn unique_end_intrinsic(&self) -> FunctionValue<'ctx> { + let void_type = self.context.void_type(); + let fn_type = void_type.fn_type(&[], false); + self.module + .get_function(UNIQUE_END_INTRINSIC) + .unwrap_or_else(|| { + self.module + .add_function(UNIQUE_END_INTRINSIC, fn_type, None) + }) + } + + pub(crate) fn compile_to_wasm( + &self, + output_fname: &str, + optimization_level: u32, + ) -> anyhow::Result> { + utils::compile_to_wasm(&self.module, output_fname, optimization_level) + } +} diff --git a/core/wasm-codegen/src/lib.rs b/core/wasm-codegen/src/lib.rs index 2b89b864..7e931603 100644 --- a/core/wasm-codegen/src/lib.rs +++ b/core/wasm-codegen/src/lib.rs @@ -1,7 +1,15 @@ #![warn(clippy::pedantic)] + use inference_ast::t_ast::TypedAst; +use inkwell::{ + context::Context, + targets::{InitializationConfig, Target}, +}; + +use crate::compiler::Compiler; -pub mod module; +mod compiler; +mod utils; /// Generates WebAssembly bytecode from a typed AST. /// @@ -12,19 +20,27 @@ pub mod module; /// /// Returns an error if code generation fails. pub fn codegen(t_ast: &TypedAst) -> anyhow::Result> { - let mut builder = module::WasmModuleBuilder::new(); + Target::initialize_webassembly(&InitializationConfig::default()); + let context = Context::create(); + let compiler = Compiler::new(&context, "wasm_module"); + if t_ast.source_files.is_empty() { - return Ok(builder.finish()); + return compiler.compile_to_wasm("output.wasm", 3); } if t_ast.source_files.len() > 1 { todo!("Multi-file support not yet implemented"); } - let source_file = &t_ast.source_files[0]; - for func_def in source_file.function_definitions() { - let _ = builder.push_function(&func_def); - } + traverse_t_ast_with_compiler(t_ast, &compiler); - let wasm_bytes = builder.finish(); + let wasm_bytes = compiler.compile_to_wasm("output.wasm", 3)?; Ok(wasm_bytes) } + +fn traverse_t_ast_with_compiler(t_ast: &TypedAst, compiler: &Compiler) { + for source_file in &t_ast.source_files { + for func_def in source_file.function_definitions() { + compiler.visit_function_definition(&func_def); + } + } +} diff --git a/core/wasm-codegen/src/module.rs b/core/wasm-codegen/src/module.rs deleted file mode 100644 index b6878d0d..00000000 --- a/core/wasm-codegen/src/module.rs +++ /dev/null @@ -1,248 +0,0 @@ -use std::rc::Rc; - -use inference_ast::nodes::{Expression, FunctionDefinition, Literal, Statement, Type}; -use wasm_encoder::{ - CodeSection, ExportKind, ExportSection, Function, FunctionSection, Module, TypeSection, ValType, -}; - -#[derive(Default)] -pub struct WasmModuleBuilder { - module: Module, - types: TypeSection, - functions: FunctionSection, - exports: ExportSection, - codes: CodeSection, - - next_type_index: u32, -} - -impl WasmModuleBuilder { - #[must_use] - pub fn new() -> Self { - Self { - module: Module::new(), - types: TypeSection::new(), - functions: FunctionSection::new(), - exports: ExportSection::new(), - codes: CodeSection::new(), - next_type_index: 0, - } - } - - #[allow(unused_variables)] - pub fn push_function(&mut self, function_definition: &Rc) -> u32 { - let function_type_index = self.next_type_index; - self.next_type_index += 1; - self.functions.function(function_type_index); - self.exports.export( - function_definition.name().as_str(), - ExportKind::Func, - function_type_index, - ); - let params: Vec<(u32, ValType)> = vec![]; - // if let Some(args) = &function_definition.arguments { - // let mut arg_index = 0; - // for param in args { - // match param { - // ArgumentType::SelfReference(self_reference) => todo!(), - // ArgumentType::IgnoreArgument(ignore_argument) => todo!(), - // ArgumentType::Argument(argument) => { - // let t = match &argument.ty { - // Type::Array(type_array) => todo!(), - // Type::Simple(simple_type) => { - // simple_type.name - // } - // Type::Generic(generic_type) => todo!(), - // Type::Function(function_type) => todo!(), - // Type::QualifiedName(qualified_name) => todo!(), - // Type::Qualified(type_qualified_name) => todo!(), - // Type::Custom(identifier) => todo!(), - // }; - // } - // ArgumentType::Type(ty) => match ty { - // Type::Array(type_array) => todo!(), - // Type::Simple(simple_type) => todo!(), - // Type::Generic(generic_type) => todo!(), - // Type::Function(function_type) => todo!(), - // Type::QualifiedName(qualified_name) => todo!(), - // Type::Qualified(type_qualified_name) => todo!(), - // Type::Custom(identifier) => todo!(), - // }, - // } - // } - // } - let mut results: Vec = vec![]; - if let Some(ret_type) = &function_definition.returns { - match ret_type { - Type::Array(type_array) => todo!(), - Type::Simple(simple_type) => { - if simple_type.type_info.borrow().is_some() { - } else { - match simple_type.name.to_lowercase().as_str() { - "i32" => results.push(ValType::I32), - "i64" => results.push(ValType::I64), - "f32" => results.push(ValType::F32), - "f64" => results.push(ValType::F64), - _ => {} - } - } - } - Type::Generic(generic_type) => todo!(), - Type::Function(function_type) => todo!(), - Type::QualifiedName(qualified_name) => todo!(), - Type::Qualified(type_qualified_name) => todo!(), - Type::Custom(identifier) => todo!(), - } - } - self.types - .ty() - .function(params.iter().map(|p| p.1), results); - let mut function = Function::new(params); - for stmt in function_definition.body.statements() { - match stmt { - Statement::Block(block_type) => todo!(), - Statement::Expression(expression) => todo!(), - Statement::Assign(assign_statement) => todo!(), - Statement::Return(return_statement) => { - match &*return_statement.expression.borrow() { - Expression::ArrayIndexAccess(array_index_access_expression) => todo!(), - Expression::Binary(binary_expression) => todo!(), - Expression::MemberAccess(member_access_expression) => todo!(), - Expression::TypeMemberAccess(type_member_access_expression) => todo!(), - Expression::FunctionCall(function_call_expression) => todo!(), - Expression::Struct(struct_expression) => todo!(), - Expression::PrefixUnary(prefix_unary_expression) => todo!(), - Expression::Parenthesized(parenthesized_expression) => todo!(), - Expression::Literal(literal) => match literal { - Literal::Array(array_literal) => todo!(), - Literal::Bool(bool_literal) => todo!(), - Literal::String(string_literal) => todo!(), - Literal::Number(number_literal) => { - function.instruction(&wasm_encoder::Instruction::I32Const( - number_literal.value.parse::().unwrap_or(0), - )); - } - Literal::Unit(unit_literal) => todo!(), - }, - Expression::Identifier(identifier) => todo!(), - Expression::Type(_) => todo!(), - Expression::Uzumaki(uzumaki_expression) => todo!(), - } - // function.instruction(&wasm_encoder::Instruction::Return); - } - Statement::Loop(loop_statement) => todo!(), - Statement::Break(break_statement) => todo!(), - Statement::If(if_statement) => todo!(), - Statement::VariableDefinition(variable_definition_statement) => todo!(), - Statement::TypeDefinition(type_definition_statement) => todo!(), - Statement::Assert(assert_statement) => todo!(), - Statement::ConstantDefinition(constant_definition) => todo!(), - } - } - function.instruction(&wasm_encoder::Instruction::End); - self.codes.function(&function); - function_type_index - } - - #[must_use] - pub fn finish(mut self) -> Vec { - // Sections must be appended in canonical order: - self.module.section(&self.types); - self.module.section(&self.functions); - self.module.section(&self.exports); - self.module.section(&self.codes); - self.module.finish() - } -} - -#[cfg(test)] -mod tests { - use wasm_encoder::{Function, ValType}; - use wasmtime::{Caller, Engine, Store}; - - use super::*; - - #[test] - fn test_wasm_module_builder() -> anyhow::Result<()> { - let mut module = Module::new(); - let params = vec![]; - let results = vec![ValType::I32]; - - let mut types = TypeSection::new(); - types.ty().function(params, results); - module.section(&types); - - let mut functions = FunctionSection::new(); - let type_index = 0; - functions.function(type_index); - module.section(&functions); - - let mut exports = ExportSection::new(); - exports.export("helloWorld", ExportKind::Func, 0); - module.section(&exports); - - let mut codes = CodeSection::new(); - let mut f = Function::new(vec![]); - f.instructions().i32_const(42).end(); - codes.function(&f); - module.section(&codes); - - let wasm_bytes = module.finish(); - - let engine = Engine::default(); - let module = wasmtime::Module::new(&engine, &wasm_bytes).unwrap(); - let mut linker = wasmtime::Linker::new(&engine); - linker.func_wrap( - "host", - "host_func", - |caller: Caller<'_, i32>, param: i32| { - println!("Got {param} from WebAssembly"); - println!("my host state is: {}", caller.data()); - }, - )?; - let mut store = Store::new(&engine, 4); - let instance = linker.instantiate(&mut store, &module)?; - let hello = instance.get_typed_func::<(), i32>(&mut store, "helloWorld")?; - let result = hello.call(&mut store, ())?; - assert_eq!(result, 42); - Ok(()) - } - - #[test] - fn test_module_bitwise_reproducable() { - let mut previous: Option> = None; - for _ in 0..10 { - let mut module = Module::new(); - let params = vec![]; - let results = vec![ValType::I32]; - - let mut types = TypeSection::new(); - types.ty().function(params, results); - module.section(&types); - - let mut functions = FunctionSection::new(); - let type_index = 0; - functions.function(type_index); - module.section(&functions); - - let mut exports = ExportSection::new(); - exports.export("helloWorld", ExportKind::Func, 0); - module.section(&exports); - - let mut codes = CodeSection::new(); - let mut f = Function::new(vec![]); - f.instructions().i32_const(42).end(); - codes.function(&f); - module.section(&codes); - - let wasm_bytes = module.finish(); - if let Some(prev) = previous { - assert_eq!(prev.len(), wasm_bytes.len()); - for (b1, b2) in prev.iter().zip(wasm_bytes.iter()) { - assert_eq!(b1, b2); - } - } - previous = Some(wasm_bytes); - } - } -} diff --git a/core/wasm-codegen/src/utils.rs b/core/wasm-codegen/src/utils.rs new file mode 100644 index 00000000..d1a2866a --- /dev/null +++ b/core/wasm-codegen/src/utils.rs @@ -0,0 +1,115 @@ +use std::{path::PathBuf, process::Command}; + +use inkwell::{module::Module, targets::TargetTriple}; +use tempfile::tempdir; + +pub(crate) fn compile_to_wasm( + module: &Module, + output_fname: &str, + optimization_level: u32, +) -> anyhow::Result> { + let llc_path = get_inf_llc_path()?; + let temp_dir = tempdir()?; + let obj_path = temp_dir.path().join(output_fname).with_extension("o"); + let ir_path = temp_dir.path().join(output_fname).with_extension("ll"); + let triple = TargetTriple::create("wasm32-unknown-unknown"); + module.set_triple(&triple); + let ir_str = module.print_to_string().to_string(); + std::fs::write(&ir_path, ir_str)?; + let opt_flag = format!("-O{}", optimization_level.min(3)); + let output = Command::new(&llc_path) + // .arg("-march=wasm32") // same as triple + .arg("-mcpu=mvp") + // .arg("-mattr=+mutable-globals") // https://doc.rust-lang.org/beta/rustc/platform-support/wasm32v1-none.html + .arg("-filetype=obj") + .arg(&ir_path) + .arg(&opt_flag) + .arg("-o") + .arg(&obj_path) + .output()?; + + if !output.status.success() { + return Err(anyhow::anyhow!( + "inf-llc failed with status: {}\nstderr: {}", + output.status, + String::from_utf8_lossy(&output.stderr) + )); + } + let wasm_ld_path = get_rust_lld_path()?; + let wasm_path = temp_dir.path().join(output_fname).with_extension("wasm"); + let wasm_ld_output = Command::new(&wasm_ld_path) + .arg("-flavor") + .arg("wasm") + .arg(&obj_path) + .arg("--no-entry") + // .arg("--export=hello_world") + .arg("-o") + .arg(&wasm_path) + .output()?; + + if !wasm_ld_output.status.success() { + return Err(anyhow::anyhow!( + "wasm-ld failed with status: {}\nstderr: {}", + wasm_ld_output.status, + String::from_utf8_lossy(&wasm_ld_output.stderr) + )); + } + + let wasm_bytes = std::fs::read(&wasm_path)?; + std::fs::remove_file(obj_path)?; + Ok(wasm_bytes) +} + +pub(crate) fn get_inf_llc_path() -> anyhow::Result { + get_bin_path( + "inf-llc", + "This package requires LLVM with Inference intrinsics support.", + ) +} + +pub(crate) fn get_rust_lld_path() -> anyhow::Result { + get_bin_path( + "rust-lld", + "This package requires rust-lld to link WebAssembly modules.", + ) +} + +fn get_bin_path(bin_name: &str, not_found_message: &str) -> anyhow::Result { + let exe_suffix = std::env::consts::EXE_SUFFIX; + let llc_name = format!("{bin_name}{exe_suffix}"); + + let exe_path = std::env::current_exe() + .map_err(|e| anyhow::anyhow!("Failed to get current executable path: {e}"))?; + + let exe_dir = exe_path + .parent() + .ok_or_else(|| anyhow::anyhow!("Failed to get executable directory"))?; + + // Try multiple possible locations: + // 1. For regular binaries: /bin/llc + // 2. For test binaries in deps/: /../bin/llc + let candidates = vec![ + exe_dir.join("bin").join(&llc_name), // target/debug/bin/llc or target/release/bin/llc + exe_dir.parent().map_or_else( + || exe_dir.join("bin").join(&llc_name), + |p| p.join("bin").join(&llc_name), // target/debug/bin/llc when exe is in target/debug/deps/ + ), + ]; + + for llc_path in &candidates { + if llc_path.exists() { + return Ok(llc_path.clone()); + } + } + + Err(anyhow::anyhow!( + "🚫 {bin_name} binary not found\n\ + \n\ + {not_found_message}\n\n\ + Executable: {}\n\ + Searched locations:\n - {}\n - {}", + exe_path.display(), + candidates[0].display(), + candidates[1].display() + )) +} diff --git a/tests/Cargo.toml b/tests/Cargo.toml index d16471cb..5c82092e 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -10,8 +10,9 @@ repository = { workspace = true } inference-ast.workspace = true inference-wasm-codegen.workspace = true inference.workspace = true +inf-wasmparser.workspace = true tree-sitter.workspace = true tree-sitter-inference.workspace = true anyhow.workspace = true serde_json = "1.0.99" - +wasmtime="39.0.1" diff --git a/tests/src/codegen/wasm/base.rs b/tests/src/codegen/wasm/base.rs index 24e5b68a..d786e9b2 100644 --- a/tests/src/codegen/wasm/base.rs +++ b/tests/src/codegen/wasm/base.rs @@ -10,10 +10,86 @@ mod base_codegen_tests { let test_file_path = get_test_file_path(module_path!(), test_name); let source_code = std::fs::read_to_string(&test_file_path) .unwrap_or_else(|_| panic!("Failed to read test file: {test_file_path:?}")); + let actual = wasm_codegen(&source_code); let expected = get_test_wasm_path(module_path!(), test_name); let expected = std::fs::read(&expected) .unwrap_or_else(|_| panic!("Failed to read expected wasm file for test: {test_name}")); + // let test_dir = std::path::Path::new(&test_file_path).parent().unwrap(); + // std::fs::write(test_dir.join("actual.wasm"), &actual) + // .unwrap_or_else(|e| panic!("Failed to write actual.wasm: {}", e)); + assert_wasms_modules_equivalence(&expected, &actual); + } + + #[test] + fn const_test() { + let test_name = "const"; + let test_file_path = get_test_file_path(module_path!(), test_name); + let source_code = std::fs::read_to_string(&test_file_path) + .unwrap_or_else(|_| panic!("Failed to read test file: {test_file_path:?}")); let actual = wasm_codegen(&source_code); + let expected = get_test_wasm_path(module_path!(), test_name); + let expected = std::fs::read(&expected) + .unwrap_or_else(|_| panic!("Failed to read expected wasm file for test: {test_name}")); + // let test_dir = std::path::Path::new(&test_file_path).parent().unwrap(); + // std::fs::write(test_dir.join("actual-const.wasm"), &actual) + // .unwrap_or_else(|e| panic!("Failed to write actual-const.wasm: {}", e)); + assert_wasms_modules_equivalence(&expected, &actual); + } + + #[test] + fn trivial_test_execution() { + use wasmtime::{Engine, Linker, Memory, MemoryType, Module, Store, TypedFunc}; + + let test_name = "trivial"; + let test_file_path = get_test_file_path(module_path!(), test_name); + let source_code = std::fs::read_to_string(&test_file_path) + .unwrap_or_else(|_| panic!("Failed to read test file: {test_file_path:?}")); + let wasm_bytes = wasm_codegen(&source_code); + + let engine = Engine::default(); + let module = Module::new(&engine, &wasm_bytes) + .unwrap_or_else(|e| panic!("Failed to create Wasm module: {}", e)); + + let mut store = Store::new(&engine, ()); + + let mut linker = Linker::new(&engine); + let memory_type = MemoryType::new(1, None); + let memory = Memory::new(&mut store, memory_type) + .unwrap_or_else(|e| panic!("Failed to create memory: {}", e)); + linker + .define(&mut store, "env", "__linear_memory", memory) + .unwrap_or_else(|e| panic!("Failed to define memory import: {}", e)); + + let instance = linker + .instantiate(&mut store, &module) + .unwrap_or_else(|e| panic!("Failed to instantiate Wasm module: {}", e)); + + let hello_world_func: TypedFunc<(), i32> = instance + .get_typed_func(&mut store, "hello_world") + .unwrap_or_else(|e| panic!("Failed to get 'hello_world' function: {}", e)); + + let result = hello_world_func + .call(&mut store, ()) + .unwrap_or_else(|e| panic!("Failed to execute 'hello_world' function: {}", e)); + + assert_eq!(result, 42, "Expected 'hello_world' function to return 42"); + } + + #[test] + fn nondet_test() { + let test_name = "nondet"; + let test_file_path = get_test_file_path(module_path!(), test_name); + let source_code = std::fs::read_to_string(&test_file_path) + .unwrap_or_else(|_| panic!("Failed to read test file: {test_file_path:?}")); + let actual = wasm_codegen(&source_code); + inf_wasmparser::validate(&actual) + .unwrap_or_else(|e| panic!("Generated Wasm module is invalid: {}", e)); + let expected = get_test_wasm_path(module_path!(), test_name); + let expected = std::fs::read(&expected) + .unwrap_or_else(|_| panic!("Failed to read expected wasm file for test: {test_name}")); + // let test_dir = std::path::Path::new(&test_file_path).parent().unwrap(); + // std::fs::write(test_dir.join("actual-nondet.wasm"), &actual) + // .unwrap_or_else(|e| panic!("Failed to write actual-nondet.wasm: {}", e)); assert_wasms_modules_equivalence(&expected, &actual); } } diff --git a/tests/test_data/codegen/wasm/base/const.inf b/tests/test_data/codegen/wasm/base/const.inf new file mode 100644 index 00000000..c2a1be57 --- /dev/null +++ b/tests/test_data/codegen/wasm/base/const.inf @@ -0,0 +1,4 @@ +fn hello_const_i32() -> i32 { + const a: i32 = 42; + return a; +} diff --git a/tests/test_data/codegen/wasm/base/const.wasm b/tests/test_data/codegen/wasm/base/const.wasm new file mode 100644 index 00000000..06932da6 Binary files /dev/null and b/tests/test_data/codegen/wasm/base/const.wasm differ diff --git a/tests/test_data/codegen/wasm/base/nondet.inf b/tests/test_data/codegen/wasm/base/nondet.inf new file mode 100644 index 00000000..d76ebf12 --- /dev/null +++ b/tests/test_data/codegen/wasm/base/nondet.inf @@ -0,0 +1,27 @@ +fn hello_uzumaki() -> i32 { + return @; +} + +fn hello_world() { + forall { + const a: i32 = 42; + } +} + +fn hello_exists() { + exists { + const a: i32 = 42; + } +} + +fn hello_assume() { + assume { + const a: i32 = 42; + } +} + +fn hello_unique() { + unique { + const a: i32 = 42; + } +} diff --git a/tests/test_data/codegen/wasm/base/nondet.wasm b/tests/test_data/codegen/wasm/base/nondet.wasm new file mode 100644 index 00000000..63c6b203 Binary files /dev/null and b/tests/test_data/codegen/wasm/base/nondet.wasm differ diff --git a/tests/test_data/codegen/wasm/base/trivial.wasm b/tests/test_data/codegen/wasm/base/trivial.wasm index dabb06d8..1538e954 100644 Binary files a/tests/test_data/codegen/wasm/base/trivial.wasm and b/tests/test_data/codegen/wasm/base/trivial.wasm differ