diff --git a/crates/herkos-core/src/optimizer/branch_fold.rs b/crates/herkos-core/src/optimizer/branch_fold.rs
new file mode 100644
index 0000000..488bf7b
--- /dev/null
+++ b/crates/herkos-core/src/optimizer/branch_fold.rs
@@ -0,0 +1,462 @@
+//! Branch condition folding.
+//!
+//! Simplifies `BranchIf` terminators by looking at the instruction that
+//! defines the condition variable:
+//!
+//! - `Eqz(x)` as condition → swap branch targets, use `x` directly
+//! - `Ne(x, 0)` as condition → use `x` directly
+//! - `Eq(x, 0)` as condition → swap branch targets, use `x` directly
+//!
+//! After substitution, the defining instruction becomes dead (single use was
+//! the branch) and is cleaned up by `dead_instrs`.
+//!
+//! This pass runs on lowered IR, where phi nodes have been replaced by
+//! assignments and a variable may therefore be written in more than one place.
+//! A branch is only folded when the condition has a single definition and `x`
+//! is known to still hold, at the branch, the value the comparison read.
+
+use super::utils::{build_global_const_map, build_global_use_count, instr_dest, is_zero};
+use crate::ir::{BinOp, IrFunction, IrInstr, IrTerminator, IrValue, UnOp, VarId};
+use std::collections::HashMap;
+
+/// Folds every eligible `BranchIf` condition in `func`.
+///
+/// Passes are repeated until one changes nothing. The use-count and constant
+/// maps are rebuilt once per pass, not once per folded branch.
+pub fn eliminate(func: &mut IrFunction) {
+    while fold_pass(func) {}
+}
+
+/// Runs one folding pass over every block of `func`. Returns `true` if at
+/// least one terminator was rewritten.
+fn fold_pass(func: &mut IrFunction) -> bool {
+    let global_uses = build_global_use_count(func);
+    let global_consts = build_global_const_map(func);
+
+    // Count the definitions of every variable and remember where each foldable
+    // comparison (Eqz, Ne(x,0), Eq(x,0)) is defined.
+    // NOTE(review): assumes `instr_dest` reports every instruction that writes a
+    // variable — confirm.
+    let mut def_counts: HashMap<VarId, usize> = HashMap::new();
+    let mut var_defs: HashMap<VarId, DefSite> = HashMap::new();
+
+    for (block_idx, block) in func.blocks.iter().enumerate() {
+        for (instr_idx, instr) in block.instructions.iter().enumerate() {
+            if let Some(dest) = instr_dest(instr) {
+                *def_counts.entry(dest).or_insert(0) += 1;
+                if let Some(def) = classify(instr, &global_consts) {
+                    let site = DefSite {
+                        def,
+                        block_idx,
+                        instr_idx,
+                    };
+                    var_defs.insert(dest, site);
+                }
+            }
+        }
+    }
+
+    // Now scan terminators for BranchIf with a foldable condition.
+    let mut changed = false;
+    for (block_idx, block) in func.blocks.iter_mut().enumerate() {
+        let cond = match &block.terminator {
+            IrTerminator::BranchIf { condition, .. } => *condition,
+            _ => continue,
+        };
+
+        // Only fold if the condition has exactly one use (the BranchIf), so
+        // the comparison becomes dead afterwards.
+        if global_uses.get(&cond).copied().unwrap_or(0) != 1 {
+            continue;
+        }
+
+        // With several definitions the recorded comparison is not necessarily
+        // the one that reaches this branch.
+        if def_counts.get(&cond).copied().unwrap_or(0) != 1 {
+            continue;
+        }
+
+        let site = match var_defs.get(&cond) {
+            Some(site) => *site,
+            None => continue,
+        };
+
+        let (inner, swap_targets) = match site.def {
+            // eqz(x) != 0 ≡ x == 0, so swap targets and use x
+            VarDef::Eqz(inner) | VarDef::EqZero(inner) => (inner, true),
+            // ne(x, 0) != 0 ≡ x != 0, so just use x
+            VarDef::NeZero(inner) => (inner, false),
+        };
+
+        // `x` must not be overwritten between the comparison and the branch.
+        let operand_intact = if site.block_idx == block_idx {
+            // Same block: nothing from the comparison onwards may write `x`.
+            block.instructions[site.instr_idx..]
+                .iter()
+                .all(|instr| instr_dest(instr) != Some(inner))
+        } else {
+            // Different block: be conservative and require that `x` is
+            // written at most once in the whole function.
+            inner != cond && def_counts.get(&inner).copied().unwrap_or(0) <= 1
+        };
+        if !operand_intact {
+            continue;
+        }
+
+        if let IrTerminator::BranchIf {
+            condition,
+            if_true,
+            if_false,
+        } = &mut block.terminator
+        {
+            *condition = inner;
+            if swap_targets {
+                std::mem::swap(if_true, if_false);
+            }
+            changed = true;
+        }
+    }
+
+    changed
+}
+
+/// Returns the comparison-with-zero that `instr` computes, if any.
+fn classify(instr: &IrInstr, consts: &HashMap<VarId, IrValue>) -> Option<VarDef> {
+    match instr {
+        IrInstr::UnOp {
+            op: UnOp::I32Eqz | UnOp::I64Eqz,
+            operand,
+            ..
+        } => Some(VarDef::Eqz(*operand)),
+        IrInstr::BinOp {
+            op: BinOp::I32Ne | BinOp::I64Ne,
+            lhs,
+            rhs,
+            ..
+        } => compared_with_zero(*lhs, *rhs, consts).map(VarDef::NeZero),
+        IrInstr::BinOp {
+            op: BinOp::I32Eq | BinOp::I64Eq,
+            lhs,
+            rhs,
+            ..
+        } => compared_with_zero(*lhs, *rhs, consts).map(VarDef::EqZero),
+        _ => None,
+    }
+}
+
+/// Returns the operand that is compared against a known zero: `lhs` when `rhs`
+/// is zero, otherwise `rhs` when `lhs` is zero, otherwise `None`.
+fn compared_with_zero(lhs: VarId, rhs: VarId, consts: &HashMap<VarId, IrValue>) -> Option<VarId> {
+    if is_zero(rhs, consts) {
+        Some(lhs)
+    } else if is_zero(lhs, consts) {
+        Some(rhs)
+    } else {
+        None
+    }
+}
+
+/// Comparison-with-zero shape of a defining instruction; the payload is the
+/// operand that is not the zero constant.
+#[derive(Clone, Copy)]
+enum VarDef {
+    /// `Eqz(x)`
+    Eqz(VarId),
+    /// `Ne(x, 0)` or `Ne(0, x)`
+    NeZero(VarId),
+    /// `Eq(x, 0)` or `Eq(0, x)`
+    EqZero(VarId),
+}
+
+/// A foldable comparison together with its position in the function.
+#[derive(Clone, Copy)]
+struct DefSite {
+    def: VarDef,
+    /// Index into `IrFunction::blocks`.
+    block_idx: usize,
+    /// Index into that block's `instructions`.
+    instr_idx: usize,
+}
+
+// ── Tests ─────────────────────────────────────────────────────────────────────
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ir::{BlockId, IrBlock, TypeIdx};
+
+    fn make_func(blocks: Vec<IrBlock>) -> IrFunction {
+        IrFunction {
+            params: vec![],
+            locals: vec![],
+            blocks,
+            entry_block: BlockId(0),
+            return_type: None,
+            type_idx: TypeIdx::new(0),
+        }
+    }
+
+    #[test]
+    fn eqz_swaps_targets() {
+        // v1 = Eqz(v0); BranchIf(v1, B1, B2) → BranchIf(v0, B2, B1)
+        let mut func = make_func(vec![IrBlock {
+            id: BlockId(0),
+            instructions: vec![IrInstr::UnOp {
+                dest: VarId(1),
+                op: UnOp::I32Eqz,
+                operand: VarId(0),
+            }],
+            terminator: IrTerminator::BranchIf {
+                condition: VarId(1),
+                if_true: BlockId(1),
+                if_false: BlockId(2),
+            },
+        }]);
+        eliminate(&mut func);
+        match &func.blocks[0].terminator {
+            IrTerminator::BranchIf {
+                condition,
+                if_true,
+                if_false,
+            } => {
+                assert_eq!(*condition, VarId(0));
+                assert_eq!(*if_true, BlockId(2), "targets should be swapped");
+                assert_eq!(*if_false, BlockId(1));
+            }
+            other => panic!("expected BranchIf, got {other:?}"),
+        }
+    }
+
+    #[test]
+    fn ne_zero_simplifies() {
+        // v1 = 0; v2 = Ne(v0, v1); BranchIf(v2, B1, B2) → BranchIf(v0, B1, B2)
+        let mut func = make_func(vec![IrBlock {
+            id: BlockId(0),
+            instructions: vec![
+                IrInstr::Const {
+                    dest: VarId(1),
+                    value: IrValue::I32(0),
+                },
+                IrInstr::BinOp {
+                    dest: VarId(2),
+                    op: BinOp::I32Ne,
+                    lhs: VarId(0),
+                    rhs: VarId(1),
+                },
+            ],
+            terminator: IrTerminator::BranchIf {
+                condition: VarId(2),
+                if_true: BlockId(1),
+                if_false: BlockId(2),
+            },
+        }]);
+        eliminate(&mut func);
+        match &func.blocks[0].terminator {
+            IrTerminator::BranchIf {
+                condition,
+                if_true,
+                if_false,
+            } => {
+                assert_eq!(*condition, VarId(0));
+                assert_eq!(*if_true, BlockId(1), "targets should NOT be swapped");
+                assert_eq!(*if_false, BlockId(2));
+            }
+            other => panic!("expected BranchIf, got {other:?}"),
+        }
+    }
+
+    #[test]
+    fn eq_zero_swaps() {
+        // v1 = 0; v2 = Eq(v0, v1); BranchIf(v2, B1, B2) → BranchIf(v0, B2, B1)
+        let mut func = make_func(vec![IrBlock {
+            id: BlockId(0),
+            instructions: vec![
+                IrInstr::Const {
+                    dest: VarId(1),
+                    value: IrValue::I32(0),
+                },
+                IrInstr::BinOp {
+                    dest: VarId(2),
+                    op: BinOp::I32Eq,
+                    lhs: VarId(0),
+                    rhs: VarId(1),
+                },
+            ],
+            terminator: IrTerminator::BranchIf {
+                condition: VarId(2),
+                if_true: BlockId(1),
+                if_false: BlockId(2),
+            },
+        }]);
+        eliminate(&mut func);
+        match &func.blocks[0].terminator {
+            IrTerminator::BranchIf {
+                condition,
+                if_true,
+                if_false,
+            } => {
+                assert_eq!(*condition, VarId(0));
+                assert_eq!(*if_true, BlockId(2), "targets should be swapped");
+                assert_eq!(*if_false, BlockId(1));
+            }
+            other => panic!("expected BranchIf, got {other:?}"),
+        }
+    }
+
+    #[test]
+    fn multi_use_not_folded() {
+        // v1 = Eqz(v0); use(v1) elsewhere → don't fold
+        let mut func = make_func(vec![
+            IrBlock {
+                id: BlockId(0),
+                instructions: vec![IrInstr::UnOp {
+                    dest: VarId(1),
+                    op: UnOp::I32Eqz,
+                    operand: VarId(0),
+                }],
+                terminator: IrTerminator::BranchIf {
+                    condition: VarId(1),
+                    if_true: BlockId(1),
+                    if_false: BlockId(2),
+                },
+            },
+            IrBlock {
+                id: BlockId(1),
+                instructions: vec![],
+                terminator: IrTerminator::Return {
+                    value: Some(VarId(1)), // second use of v1
+                },
+            },
+            IrBlock {
+                id: BlockId(2),
+                instructions: vec![],
+                terminator: IrTerminator::Return { value: None },
+            },
+        ]);
+        eliminate(&mut func);
+        // Should NOT fold — v1 has 2 uses
+        match &func.blocks[0].terminator {
+            IrTerminator::BranchIf { condition, .. } => {
+                assert_eq!(*condition, VarId(1), "should not have been folded");
+            }
+            other => panic!("expected BranchIf, got {other:?}"),
+        }
+    }
+
+    #[test]
+    fn cross_block_zero_const() {
+        // B0: v1 = 0; Jump(B1)
+        // B1: v2 = Ne(v0, v1); BranchIf(v2, B2, B3) → BranchIf(v0, B2, B3)
+        let mut func = make_func(vec![
+            IrBlock {
+                id: BlockId(0),
+                instructions: vec![IrInstr::Const {
+                    dest: VarId(1),
+                    value: IrValue::I32(0),
+                }],
+                terminator: IrTerminator::Jump { target: BlockId(1) },
+            },
+            IrBlock {
+                id: BlockId(1),
+                instructions: vec![IrInstr::BinOp {
+                    dest: VarId(2),
+                    op: BinOp::I32Ne,
+                    lhs: VarId(0),
+                    rhs: VarId(1),
+                }],
+                terminator: IrTerminator::BranchIf {
+                    condition: VarId(2),
+                    if_true: BlockId(2),
+                    if_false: BlockId(3),
+                },
+            },
+        ]);
+        eliminate(&mut func);
+        match &func.blocks[1].terminator {
+            IrTerminator::BranchIf { condition, .. } => {
+                assert_eq!(*condition, VarId(0));
+            }
+            other => panic!("expected BranchIf, got {other:?}"),
+        }
+    }
+
+    #[test]
+    fn overwritten_operand_not_folded() {
+        // v1 = Eqz(v0); v0 = 7; BranchIf(v1, B1, B2) → unchanged, because v0
+        // no longer holds the value Eqz read
+        let mut func = make_func(vec![IrBlock {
+            id: BlockId(0),
+            instructions: vec![
+                IrInstr::UnOp {
+                    dest: VarId(1),
+                    op: UnOp::I32Eqz,
+                    operand: VarId(0),
+                },
+                IrInstr::Const {
+                    dest: VarId(0),
+                    value: IrValue::I32(7),
+                },
+            ],
+            terminator: IrTerminator::BranchIf {
+                condition: VarId(1),
+                if_true: BlockId(1),
+                if_false: BlockId(2),
+            },
+        }]);
+        eliminate(&mut func);
+        match &func.blocks[0].terminator {
+            IrTerminator::BranchIf {
+                condition,
+                if_true,
+                if_false,
+            } => {
+                assert_eq!(*condition, VarId(1), "should not have been folded");
+                assert_eq!(*if_true, BlockId(1));
+                assert_eq!(*if_false, BlockId(2));
+            }
+            other => panic!("expected BranchIf, got {other:?}"),
+        }
+    }
+
+    #[test]
+    fn multi_def_condition_not_folded() {
+        // B0: v1 = Eqz(v0); Jump(B2)
+        // B1: v1 = 1; Jump(B2)
+        // B2: BranchIf(v1, B3, B4) → unchanged, v1 has two definitions
+        let mut func = make_func(vec![
+            IrBlock {
+                id: BlockId(0),
+                instructions: vec![IrInstr::UnOp {
+                    dest: VarId(1),
+                    op: UnOp::I32Eqz,
+                    operand: VarId(0),
+                }],
+                terminator: IrTerminator::Jump { target: BlockId(2) },
+            },
+            IrBlock {
+                id: BlockId(1),
+                instructions: vec![IrInstr::Const {
+                    dest: VarId(1),
+                    value: IrValue::I32(1),
+                }],
+                terminator: IrTerminator::Jump { target: BlockId(2) },
+            },
+            IrBlock {
+                id: BlockId(2),
+                instructions: vec![],
+                terminator: IrTerminator::BranchIf {
+                    condition: VarId(1),
+                    if_true: BlockId(3),
+                    if_false: BlockId(4),
+                },
+            },
+        ]);
+        eliminate(&mut func);
+        match &func.blocks[2].terminator {
+            IrTerminator::BranchIf { condition, .. } => {
+                assert_eq!(*condition, VarId(1), "should not have been folded");
+            }
+            other => panic!("expected BranchIf, got {other:?}"),
+        }
+    }
+}
diff --git a/crates/herkos-core/src/optimizer/mod.rs b/crates/herkos-core/src/optimizer/mod.rs
index 06e0d84..33936f0 100644
--- a/crates/herkos-core/src/optimizer/mod.rs
+++ b/crates/herkos-core/src/optimizer/mod.rs
@@ -20,6 +20,7 @@ mod copy_prop;
 mod dead_blocks;
 
 // ── Post-lowering passes ─────────────────────────────────────────────────────
+mod branch_fold;
 mod dead_instrs;
 mod empty_blocks;
 mod merge_blocks;
@@ -46,16 +47,11 @@ pub fn optimize_ir(module_info: ModuleInfo, do_opt: bool) -> Result
 
 /// Optimizes the lowered IR after phi nodes have been eliminated.
 ///
-/// Passes here operate on [`LoweredModuleInfo`] where all `IrInstr::Phi`
-/// nodes have been replaced by `IrInstr::Assign` in predecessor blocks.
-///
-/// ## Structural and copy passes
-///
-/// Run multiple iterations of structural cleanup + copy propagation.
+/// Runs post-lowering structural passes, copy propagation and branch folding.
 /// dead_instrs may leave empty blocks, which empty_blocks and merge_blocks then
-/// eliminate, potentially exposing new dead instructions. copy_prop forwards the
-/// assignments that lower_phis inserted. We repeat until reaching a fixed point
-/// (typically 2 iterations).
+/// eliminate, potentially exposing new dead instructions. branch_fold simplifies
+/// `BranchIf` terminators whose condition is a known comparison. The sequence is
+/// run twice, which is typically enough to reach a fixed point.
 pub fn optimize_lowered_ir(
     module_info: LoweredModuleInfo,
     do_opt: bool,
@@ -63,8 +59,6 @@ pub fn optimize_lowered_ir(
     let mut module_info = module_info;
     if do_opt {
         for func in &mut module_info.ir_functions {
-            // Two passes: dead_instrs may create empty blocks, and copy_prop
-            // may reveal new dead instrs.
             for _ in 0..2 {
                 empty_blocks::eliminate(func);
                 dead_blocks::eliminate(func)?;
@@ -72,6 +66,8 @@ pub fn optimize_lowered_ir(
                 dead_blocks::eliminate(func)?;
                 copy_prop::eliminate(func);
                 dead_instrs::eliminate(func);
+                branch_fold::eliminate(func);
+                dead_instrs::eliminate(func);
             }
         }
     }
diff --git a/crates/herkos-core/src/optimizer/utils.rs b/crates/herkos-core/src/optimizer/utils.rs
index dcd385f..52c82a7 100644
--- a/crates/herkos-core/src/optimizer/utils.rs
+++ b/crates/herkos-core/src/optimizer/utils.rs
@@ -495,6 +495,14 @@ pub fn rewrite_terminator_target(term: &mut IrTerminator, old: BlockId, new: Blo
     }
 }
 
+/// Returns `true` if `consts` maps `var` to the integer constant `0` (i32 or i64).
+pub fn is_zero(var: VarId, consts: &HashMap<VarId, IrValue>) -> bool {
+    matches!(
+        consts.get(&var),
+        Some(IrValue::I32(0)) | Some(IrValue::I64(0))
+    )
+}
+
 /// Variables with exactly one definition across the function that is a `Const`
 /// instruction. These can be treated as constants in any block that uses them.
 pub fn build_global_const_map(func: &IrFunction) -> HashMap<VarId, IrValue> {