diff --git a/crates/rue-air/src/sema/declarations.rs b/crates/rue-air/src/sema/declarations.rs index 6201fffe..1bb4d34c 100644 --- a/crates/rue-air/src/sema/declarations.rs +++ b/crates/rue-air/src/sema/declarations.rs @@ -777,6 +777,8 @@ impl<'a> Sema<'a> { return_type: ret_type, return_type_sym, body, + rir_params_start: params_start, + rir_params_len: params_len, span, is_generic, is_pub, diff --git a/crates/rue-air/src/sema/info.rs b/crates/rue-air/src/sema/info.rs index 5b29f276..9b4c2507 100644 --- a/crates/rue-air/src/sema/info.rs +++ b/crates/rue-air/src/sema/info.rs @@ -20,8 +20,11 @@ pub struct FunctionInfo { pub return_type: Type, /// The return type symbol (before resolution) - needed for generic function specialization pub return_type_sym: Spur, - /// The RIR instruction ref for the function body - needed for generic function specialization + /// RIR body ref for generic specialization pub body: rue_rir::InstRef, + /// RIR params indices for type symbol lookup during specialization + pub rir_params_start: u32, + pub rir_params_len: u32, /// Span of the function declaration pub span: Span, /// Whether this function has any comptime type parameters diff --git a/crates/rue-air/src/specialize.rs b/crates/rue-air/src/specialize.rs index 9bb52683..089c404d 100644 --- a/crates/rue-air/src/specialize.rs +++ b/crates/rue-air/src/specialize.rs @@ -12,6 +12,7 @@ //! The specialization pass runs after semantic analysis but before CFG building. //! It transforms the AIR in-place and adds new specialized functions to the output. +use std::collections::hash_map::Entry; use std::collections::HashMap; use lasso::{Spur, ThreadedRodeo}; @@ -62,15 +63,9 @@ pub fn specialize( return Ok(()); } - // Build a map from key to just the mangled name for the rewrite phase - let name_map: HashMap = specializations - .iter() - .map(|(k, v)| (k.clone(), v.mangled_name)) - .collect(); - // Phase 2: Rewrite CallGeneric to Call in all functions for func in &mut output.functions { - rewrite_call_generic(&mut func.air, &name_map); + rewrite_call_generic(&mut func.air, &specializations); } // Phase 3: Create specialized function bodies by re-analyzing with type substitution @@ -99,7 +94,6 @@ pub fn specialize( Ok(()) } -/// Collect all specializations needed from a function's AIR. fn collect_specializations( air: &Air, interner: &ThreadedRodeo, @@ -113,7 +107,6 @@ fn collect_specializations( .. } = &inst.data { - // Extract type arguments using the public accessor let type_args: Vec = air .get_extra(*type_args_start, *type_args_len) .iter() @@ -122,30 +115,26 @@ fn collect_specializations( let key = SpecializationKey { base_name: *name, - type_args: type_args.clone(), + type_args, }; - if !specializations.contains_key(&key) { - // Generate a mangled name for the specialized function + if let Entry::Vacant(entry) = specializations.entry(key) { let base_name = interner.resolve(name); - let mangled = mangle_specialized_name(base_name, &type_args); + let mangled = mangle_specialized_name(base_name, &entry.key().type_args); let mangled_sym = interner.get_or_intern(&mangled); - specializations.insert( - key, - SpecializationInfo { - mangled_name: mangled_sym, - call_site_span: inst.span, - }, - ); + entry.insert(SpecializationInfo { + mangled_name: mangled_sym, + call_site_span: inst.span, + }); } } } } -/// Rewrite CallGeneric instructions to Call instructions. -fn rewrite_call_generic(air: &mut Air, specializations: &HashMap) { - // We need to collect the rewrites first, then apply them. - // This avoids borrowing issues with the extra array. +fn rewrite_call_generic( + air: &mut Air, + specializations: &HashMap, +) { let mut rewrites: Vec<(usize, AirInstData)> = Vec::new(); for (i, inst) in air.instructions().iter().enumerate() { @@ -157,7 +146,6 @@ fn rewrite_call_generic(air: &mut Air, specializations: &HashMap = air .get_extra(*type_args_start, *type_args_len) .iter() @@ -169,19 +157,19 @@ fn rewrite_call_generic(air: &mut Air, specializations: &HashMap CompileResult { let specialized_name_str = interner.resolve(&specialized_name).to_string(); - // Get parameter data from the arena - let param_names = sema.param_arena.names(base_info.params); - let param_types = sema.param_arena.types(base_info.params); - let param_modes = sema.param_arena.modes(base_info.params); - let param_comptime = sema.param_arena.comptime(base_info.params); - - // Build the type substitution map: comptime param name -> concrete Type let mut type_subst: HashMap = HashMap::new(); let mut type_arg_idx = 0; - for (i, is_comptime) in param_comptime.iter().enumerate() { - if *is_comptime { - if type_arg_idx < key.type_args.len() { - type_subst.insert(param_names[i], key.type_args[type_arg_idx]); - type_arg_idx += 1; - } + for (name, _, _, is_comptime) in sema.param_arena.iter(base_info.params) { + if *is_comptime && type_arg_idx < key.type_args.len() { + type_subst.insert(*name, key.type_args[type_arg_idx]); + type_arg_idx += 1; } } - // Calculate the return type by substituting type parameters let return_type = if base_info.return_type == Type::COMPTIME_TYPE { - // The return type references a type parameter - substitute it type_subst .get(&base_info.return_type_sym) .copied() @@ -240,26 +217,16 @@ fn create_specialized_function( base_info.return_type }; - // Build the specialized parameter list by: - // 1. Filtering out comptime parameters (they're erased at runtime) - // 2. Substituting type parameters in non-comptime parameter types - let specialized_params: Vec<(Spur, Type, RirParamMode)> = param_names - .iter() - .zip(param_types.iter()) - .zip(param_modes.iter()) - .zip(param_comptime.iter()) - .filter(|(((_, _), _), is_comptime)| !*is_comptime) - .map(|(((name, ty), mode), _)| { - // If the type is ComptimeType, look it up in the substitution map - // The param name's type symbol is stored in param_types as ComptimeType, - // but we need to find which type param it references. - // For now, we'll need to look at the original RIR to get the type name. + let specialized_params: Vec<(Spur, Type, RirParamMode)> = sema + .param_arena + .iter(base_info.params) + .filter(|(_, _, _, is_comptime)| !**is_comptime) + .map(|(name, ty, mode, _)| { let concrete_ty = if *ty == Type::COMPTIME_TYPE { - // This parameter's type is a type parameter. We need to find which one. - // The type name in RIR is stored in the param's ty field as a Spur. - // Unfortunately, we've lost that information by this point. - // We need to look at the original function in RIR. - substitute_param_type(sema, base_info, *name, &type_subst).unwrap_or(*ty) + substitute_param_type(sema, base_info, *name, &type_subst).unwrap_or_else(|| { + debug_assert!(false, "type substitution failed for param"); + *ty + }) } else { *ty }; @@ -267,7 +234,6 @@ fn create_specialized_function( }) .collect(); - // Now analyze the function body with the specialized types let ( air, num_locals, @@ -294,40 +260,20 @@ fn create_specialized_function( }) } -/// Substitute a parameter's type using the type substitution map. -/// -/// This looks up the parameter's type symbol in the original RIR function -/// and substitutes it with the concrete type if it's a type parameter. +/// Look up a parameter's concrete type from the type substitution map. fn substitute_param_type( sema: &Sema<'_>, base_info: &FunctionInfo, param_name: Spur, type_subst: &HashMap, ) -> Option { - // Walk up to find the FnDecl that contains this body - for (_, inst) in sema.rir.iter() { - if let rue_rir::InstData::FnDecl { - body, - params_start, - params_len, - .. - } = &inst.data - { - if *body == base_info.body { - // Found the function declaration - let params = sema.rir.get_params(*params_start, *params_len); - for param in params { - if param.name == param_name { - // Found the parameter - get its type symbol - // If the type symbol is in our substitution map, use that - if let Some(&concrete_ty) = type_subst.get(¶m.ty) { - return Some(concrete_ty); - } - } - } - } + let params = sema + .rir + .get_params(base_info.rir_params_start, base_info.rir_params_len); + for param in params { + if param.name == param_name { + return type_subst.get(¶m.ty).copied(); } } - None }