Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/rue-air/src/sema/declarations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,8 @@ impl<'a> Sema<'a> {
return_type: ret_type,
return_type_sym,
body,
rir_params_start: params_start,
rir_params_len: params_len,
span,
is_generic,
is_pub,
Expand Down
5 changes: 4 additions & 1 deletion crates/rue-air/src/sema/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ pub struct FunctionInfo {
pub return_type: Type,
/// The return type symbol (before resolution) - needed for generic function specialization
pub return_type_sym: Spur,
/// The RIR instruction ref for the function body - needed for generic function specialization
/// RIR body ref for generic specialization
pub body: rue_rir::InstRef,
/// RIR params indices for type symbol lookup during specialization
pub rir_params_start: u32,
pub rir_params_len: u32,
/// Span of the function declaration
pub span: Span,
/// Whether this function has any comptime type parameters
Expand Down
138 changes: 42 additions & 96 deletions crates/rue-air/src/specialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//! The specialization pass runs after semantic analysis but before CFG building.
//! It transforms the AIR in-place and adds new specialized functions to the output.

use std::collections::hash_map::Entry;
use std::collections::HashMap;

use lasso::{Spur, ThreadedRodeo};
Expand Down Expand Up @@ -62,15 +63,9 @@ pub fn specialize(
return Ok(());
}

// Build a map from key to just the mangled name for the rewrite phase
let name_map: HashMap<SpecializationKey, Spur> = specializations
.iter()
.map(|(k, v)| (k.clone(), v.mangled_name))
.collect();

// Phase 2: Rewrite CallGeneric to Call in all functions
for func in &mut output.functions {
rewrite_call_generic(&mut func.air, &name_map);
rewrite_call_generic(&mut func.air, &specializations);
}

// Phase 3: Create specialized function bodies by re-analyzing with type substitution
Expand Down Expand Up @@ -99,7 +94,6 @@ pub fn specialize(
Ok(())
}

/// Collect all specializations needed from a function's AIR.
fn collect_specializations(
air: &Air,
interner: &ThreadedRodeo,
Expand All @@ -113,7 +107,6 @@ fn collect_specializations(
..
} = &inst.data
{
// Extract type arguments using the public accessor
let type_args: Vec<Type> = air
.get_extra(*type_args_start, *type_args_len)
.iter()
Expand All @@ -122,30 +115,26 @@ fn collect_specializations(

let key = SpecializationKey {
base_name: *name,
type_args: type_args.clone(),
type_args,
};

if !specializations.contains_key(&key) {
// Generate a mangled name for the specialized function
if let Entry::Vacant(entry) = specializations.entry(key) {
let base_name = interner.resolve(name);
let mangled = mangle_specialized_name(base_name, &type_args);
let mangled = mangle_specialized_name(base_name, &entry.key().type_args);
let mangled_sym = interner.get_or_intern(&mangled);
specializations.insert(
key,
SpecializationInfo {
mangled_name: mangled_sym,
call_site_span: inst.span,
},
);
entry.insert(SpecializationInfo {
mangled_name: mangled_sym,
call_site_span: inst.span,
});
}
}
}
}

/// Rewrite CallGeneric instructions to Call instructions.
fn rewrite_call_generic(air: &mut Air, specializations: &HashMap<SpecializationKey, Spur>) {
// We need to collect the rewrites first, then apply them.
// This avoids borrowing issues with the extra array.
fn rewrite_call_generic(
air: &mut Air,
specializations: &HashMap<SpecializationKey, SpecializationInfo>,
) {
let mut rewrites: Vec<(usize, AirInstData)> = Vec::new();

for (i, inst) in air.instructions().iter().enumerate() {
Expand All @@ -157,7 +146,6 @@ fn rewrite_call_generic(air: &mut Air, specializations: &HashMap<SpecializationK
args_len,
} = &inst.data
{
// Extract type arguments to form the key
let type_args: Vec<Type> = air
.get_extra(*type_args_start, *type_args_len)
.iter()
Expand All @@ -169,19 +157,19 @@ fn rewrite_call_generic(air: &mut Air, specializations: &HashMap<SpecializationK
type_args,
};

if let Some(&specialized_name) = specializations.get(&key) {
// Rewrite to a regular Call
let new_data = AirInstData::Call {
name: specialized_name,
args_start: *args_start,
args_len: *args_len,
};
rewrites.push((i, new_data));
if let Some(info) = specializations.get(&key) {
rewrites.push((
i,
AirInstData::Call {
name: info.mangled_name,
args_start: *args_start,
args_len: *args_len,
},
));
}
}
}

// Apply all rewrites
for (index, new_data) in rewrites {
air.rewrite_inst_data(index, new_data);
}
Expand Down Expand Up @@ -211,27 +199,16 @@ fn create_specialized_function(
) -> CompileResult<AnalyzedFunction> {
let specialized_name_str = interner.resolve(&specialized_name).to_string();

// Get parameter data from the arena
let param_names = sema.param_arena.names(base_info.params);
let param_types = sema.param_arena.types(base_info.params);
let param_modes = sema.param_arena.modes(base_info.params);
let param_comptime = sema.param_arena.comptime(base_info.params);

// Build the type substitution map: comptime param name -> concrete Type
let mut type_subst: HashMap<Spur, Type> = HashMap::new();
let mut type_arg_idx = 0;
for (i, is_comptime) in param_comptime.iter().enumerate() {
if *is_comptime {
if type_arg_idx < key.type_args.len() {
type_subst.insert(param_names[i], key.type_args[type_arg_idx]);
type_arg_idx += 1;
}
for (name, _, _, is_comptime) in sema.param_arena.iter(base_info.params) {
if *is_comptime && type_arg_idx < key.type_args.len() {
type_subst.insert(*name, key.type_args[type_arg_idx]);
type_arg_idx += 1;
}
}

// Calculate the return type by substituting type parameters
let return_type = if base_info.return_type == Type::COMPTIME_TYPE {
// The return type references a type parameter - substitute it
type_subst
.get(&base_info.return_type_sym)
.copied()
Expand All @@ -240,34 +217,23 @@ fn create_specialized_function(
base_info.return_type
};

// Build the specialized parameter list by:
// 1. Filtering out comptime parameters (they're erased at runtime)
// 2. Substituting type parameters in non-comptime parameter types
let specialized_params: Vec<(Spur, Type, RirParamMode)> = param_names
.iter()
.zip(param_types.iter())
.zip(param_modes.iter())
.zip(param_comptime.iter())
.filter(|(((_, _), _), is_comptime)| !*is_comptime)
.map(|(((name, ty), mode), _)| {
// If the type is ComptimeType, look it up in the substitution map
// The param name's type symbol is stored in param_types as ComptimeType,
// but we need to find which type param it references.
// For now, we'll need to look at the original RIR to get the type name.
let specialized_params: Vec<(Spur, Type, RirParamMode)> = sema
.param_arena
.iter(base_info.params)
.filter(|(_, _, _, is_comptime)| !**is_comptime)
.map(|(name, ty, mode, _)| {
let concrete_ty = if *ty == Type::COMPTIME_TYPE {
// This parameter's type is a type parameter. We need to find which one.
// The type name in RIR is stored in the param's ty field as a Spur.
// Unfortunately, we've lost that information by this point.
// We need to look at the original function in RIR.
substitute_param_type(sema, base_info, *name, &type_subst).unwrap_or(*ty)
substitute_param_type(sema, base_info, *name, &type_subst).unwrap_or_else(|| {
debug_assert!(false, "type substitution failed for param");
*ty
})
} else {
*ty
};
(*name, concrete_ty, *mode)
})
.collect();

// Now analyze the function body with the specialized types
let (
air,
num_locals,
Expand All @@ -294,40 +260,20 @@ fn create_specialized_function(
})
}

/// Substitute a parameter's type using the type substitution map.
///
/// This looks up the parameter's type symbol in the original RIR function
/// and substitutes it with the concrete type if it's a type parameter.
/// Look up a parameter's concrete type from the type substitution map.
fn substitute_param_type(
sema: &Sema<'_>,
base_info: &FunctionInfo,
param_name: Spur,
type_subst: &HashMap<Spur, Type>,
) -> Option<Type> {
// Walk up to find the FnDecl that contains this body
for (_, inst) in sema.rir.iter() {
if let rue_rir::InstData::FnDecl {
body,
params_start,
params_len,
..
} = &inst.data
{
if *body == base_info.body {
// Found the function declaration
let params = sema.rir.get_params(*params_start, *params_len);
for param in params {
if param.name == param_name {
// Found the parameter - get its type symbol
// If the type symbol is in our substitution map, use that
if let Some(&concrete_ty) = type_subst.get(&param.ty) {
return Some(concrete_ty);
}
}
}
}
let params = sema
.rir
.get_params(base_info.rir_params_start, base_info.rir_params_len);
for param in params {
if param.name == param_name {
return type_subst.get(&param.ty).copied();
}
}

None
}