From d8ad3a17622624238ec383a4811c3dffb6208ce2 Mon Sep 17 00:00:00 2001 From: bench Date: Fri, 27 Mar 2026 08:16:19 +0000 Subject: [PATCH 1/2] feat: add branch condition folding pass Adds `branch_fold` as a new post-lowering optimizer pass. Simplifies `BranchIf` terminators when the condition variable is defined by a known comparison (Eqz, Ne(x,0), Eq(x,0)), replacing the condition with its operand directly and swapping targets where needed. Dead defining instructions are cleaned up by the subsequent `dead_instrs` pass. Co-Authored-By: Claude Sonnet 4.6 --- .../herkos-core/src/optimizer/branch_fold.rs | 365 ++++++++++++++++++ crates/herkos-core/src/optimizer/mod.rs | 18 +- 2 files changed, 372 insertions(+), 11 deletions(-) create mode 100644 crates/herkos-core/src/optimizer/branch_fold.rs diff --git a/crates/herkos-core/src/optimizer/branch_fold.rs b/crates/herkos-core/src/optimizer/branch_fold.rs new file mode 100644 index 0000000..3a7f978 --- /dev/null +++ b/crates/herkos-core/src/optimizer/branch_fold.rs @@ -0,0 +1,365 @@ +//! Branch condition folding. +//! +//! Simplifies `BranchIf` terminators by looking at the instruction that +//! defines the condition variable: +//! +//! - `Eqz(x)` as condition → swap branch targets, use `x` directly +//! - `Ne(x, 0)` as condition → use `x` directly +//! - `Eq(x, 0)` as condition → swap branch targets, use `x` directly +//! +//! After substitution, the defining instruction becomes dead (single use was +//! the branch) and is cleaned up by `dead_instrs`. + +use super::utils::{build_global_use_count, instr_dest}; +use crate::ir::{BinOp, IrFunction, IrInstr, IrTerminator, IrValue, UnOp, VarId}; +use std::collections::HashMap; + +pub fn eliminate(func: &mut IrFunction) { + loop { + let global_uses = build_global_use_count(func); + if !fold_one(func, &global_uses) { + break; + } + } +} + +/// Attempt a single branch fold across the function. Returns `true` if a +/// change was made. 
+fn fold_one(func: &mut IrFunction, global_uses: &HashMap) -> bool { + // Build a map of VarId → defining instruction info. + // We only care about single-use vars defined by Eqz, Ne(x,0), or Eq(x,0). + let mut var_defs: HashMap = HashMap::new(); + + // Also build a global constant map for checking if an operand is zero. + let global_consts = build_global_const_map(func); + + for block in &func.blocks { + let mut local_consts = global_consts.clone(); + for instr in &block.instructions { + if let IrInstr::Const { dest, value } = instr { + local_consts.insert(*dest, *value); + } + + if let Some(dest) = instr_dest(instr) { + match instr { + IrInstr::UnOp { + op: UnOp::I32Eqz | UnOp::I64Eqz, + operand, + .. + } => { + var_defs.insert(dest, VarDef::Eqz(*operand)); + } + IrInstr::BinOp { + op: BinOp::I32Ne | BinOp::I64Ne, + lhs, + rhs, + .. + } => { + if is_zero(rhs, &local_consts) { + var_defs.insert(dest, VarDef::NeZero(*lhs)); + } else if is_zero(lhs, &local_consts) { + var_defs.insert(dest, VarDef::NeZero(*rhs)); + } + } + IrInstr::BinOp { + op: BinOp::I32Eq | BinOp::I64Eq, + lhs, + rhs, + .. + } => { + if is_zero(rhs, &local_consts) { + var_defs.insert(dest, VarDef::EqZero(*lhs)); + } else if is_zero(lhs, &local_consts) { + var_defs.insert(dest, VarDef::EqZero(*rhs)); + } + } + _ => {} + } + } + } + } + + // Now scan terminators for BranchIf with a foldable condition. + for block in &mut func.blocks { + let condition = match &block.terminator { + IrTerminator::BranchIf { condition, .. } => *condition, + _ => continue, + }; + + // Only fold if the condition has exactly one use (the BranchIf). 
+ if global_uses.get(&condition).copied().unwrap_or(0) != 1 { + continue; + } + + let def = match var_defs.get(&condition) { + Some(d) => d, + None => continue, + }; + + match def { + VarDef::Eqz(inner) | VarDef::EqZero(inner) => { + // eqz(x) != 0 ≡ x == 0, so swap targets and use x + if let IrTerminator::BranchIf { + condition: cond, + if_true, + if_false, + } = &mut block.terminator + { + *cond = *inner; + std::mem::swap(if_true, if_false); + } + return true; + } + VarDef::NeZero(inner) => { + // ne(x, 0) != 0 ≡ x != 0, so just use x + if let IrTerminator::BranchIf { + condition: cond, .. + } = &mut block.terminator + { + *cond = *inner; + } + return true; + } + } + } + + false +} + +#[derive(Clone, Copy)] +enum VarDef { + Eqz(VarId), + NeZero(VarId), + EqZero(VarId), +} + +fn is_zero(var: &VarId, consts: &HashMap) -> bool { + matches!( + consts.get(var), + Some(IrValue::I32(0)) | Some(IrValue::I64(0)) + ) +} + +fn build_global_const_map(func: &IrFunction) -> HashMap { + let mut total_defs: HashMap = HashMap::new(); + let mut const_defs: HashMap = HashMap::new(); + + for block in &func.blocks { + for instr in &block.instructions { + if let Some(dest) = instr_dest(instr) { + *total_defs.entry(dest).or_insert(0) += 1; + if let IrInstr::Const { dest, value } = instr { + const_defs.insert(*dest, *value); + } + } + } + } + + const_defs + .into_iter() + .filter(|(v, _)| total_defs.get(v).copied().unwrap_or(0) == 1) + .collect() +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{BlockId, IrBlock, TypeIdx}; + + fn make_func(blocks: Vec) -> IrFunction { + IrFunction { + params: vec![], + locals: vec![], + blocks, + entry_block: BlockId(0), + return_type: None, + type_idx: TypeIdx::new(0), + } + } + + #[test] + fn eqz_swaps_targets() { + // v1 = Eqz(v0); BranchIf(v1, B1, B2) → BranchIf(v0, B2, B1) + let mut func = make_func(vec![IrBlock { + id: BlockId(0), + instructions: 
vec![IrInstr::UnOp { + dest: VarId(1), + op: UnOp::I32Eqz, + operand: VarId(0), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(1), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }]); + eliminate(&mut func); + match &func.blocks[0].terminator { + IrTerminator::BranchIf { + condition, + if_true, + if_false, + } => { + assert_eq!(*condition, VarId(0)); + assert_eq!(*if_true, BlockId(2), "targets should be swapped"); + assert_eq!(*if_false, BlockId(1)); + } + other => panic!("expected BranchIf, got {other:?}"), + } + } + + #[test] + fn ne_zero_simplifies() { + // v1 = 0; v2 = Ne(v0, v1); BranchIf(v2, B1, B2) → BranchIf(v0, B1, B2) + let mut func = make_func(vec![IrBlock { + id: BlockId(0), + instructions: vec![ + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(0), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Ne, + lhs: VarId(0), + rhs: VarId(1), + }, + ], + terminator: IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }]); + eliminate(&mut func); + match &func.blocks[0].terminator { + IrTerminator::BranchIf { + condition, + if_true, + if_false, + } => { + assert_eq!(*condition, VarId(0)); + assert_eq!(*if_true, BlockId(1), "targets should NOT be swapped"); + assert_eq!(*if_false, BlockId(2)); + } + other => panic!("expected BranchIf, got {other:?}"), + } + } + + #[test] + fn eq_zero_swaps() { + // v1 = 0; v2 = Eq(v0, v1); BranchIf(v2, B1, B2) → BranchIf(v0, B2, B1) + let mut func = make_func(vec![IrBlock { + id: BlockId(0), + instructions: vec![ + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(0), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Eq, + lhs: VarId(0), + rhs: VarId(1), + }, + ], + terminator: IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }]); + eliminate(&mut func); + match &func.blocks[0].terminator { + IrTerminator::BranchIf { + condition, + if_true, + if_false, + } => { + 
assert_eq!(*condition, VarId(0)); + assert_eq!(*if_true, BlockId(2), "targets should be swapped"); + assert_eq!(*if_false, BlockId(1)); + } + other => panic!("expected BranchIf, got {other:?}"), + } + } + + #[test] + fn multi_use_not_folded() { + // v1 = Eqz(v0); use(v1) elsewhere → don't fold + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::UnOp { + dest: VarId(1), + op: UnOp::I32Eqz, + operand: VarId(0), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(1), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Return { + value: Some(VarId(1)), // second use of v1 + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + eliminate(&mut func); + // Should NOT fold — v1 has 2 uses + match &func.blocks[0].terminator { + IrTerminator::BranchIf { condition, .. } => { + assert_eq!(*condition, VarId(1), "should not have been folded"); + } + other => panic!("expected BranchIf, got {other:?}"), + } + } + + #[test] + fn cross_block_zero_const() { + // B0: v1 = 0; Jump(B1) + // B1: v2 = Ne(v0, v1); BranchIf(v2, B2, B3) → BranchIf(v0, B2, B3) + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(0), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Ne, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + ]); + eliminate(&mut func); + match &func.blocks[1].terminator { + IrTerminator::BranchIf { condition, .. 
} => { + assert_eq!(*condition, VarId(0)); + } + other => panic!("expected BranchIf, got {other:?}"), + } + } +} diff --git a/crates/herkos-core/src/optimizer/mod.rs b/crates/herkos-core/src/optimizer/mod.rs index 06e0d84..33936f0 100644 --- a/crates/herkos-core/src/optimizer/mod.rs +++ b/crates/herkos-core/src/optimizer/mod.rs @@ -20,6 +20,7 @@ mod copy_prop; mod dead_blocks; // ── Post-lowering passes ───────────────────────────────────────────────────── +mod branch_fold; mod dead_instrs; mod empty_blocks; mod merge_blocks; @@ -46,16 +47,11 @@ pub fn optimize_ir(module_info: ModuleInfo, do_opt: bool) -> Result /// Optimizes the lowered IR after phi nodes have been eliminated. /// -/// Passes here operate on [`LoweredModuleInfo`] where all `IrInstr::Phi` -/// nodes have been replaced by `IrInstr::Assign` in predecessor blocks. -/// -/// ## Structural and copy passes -/// -/// Run multiple iterations of structural cleanup + copy propagation. /// dead_instrs may leave empty blocks, which empty_blocks and merge_blocks then -/// eliminate, potentially exposing new dead instructions. copy_prop forwards the -/// assignments that lower_phis inserted. We repeat until reaching a fixed point -/// (typically 2 iterations). +/// Runs post-lowering structural passes and branch condition folding. /// dead_instrs may leave empty blocks, which empty_blocks and merge_blocks then +/// eliminate, potentially exposing new dead instructions. branch_fold simplifies +/// `BranchIf` terminators whose condition is a known comparison. The sequence +/// runs exactly twice, which is typically enough to reach a fixed point. pub fn optimize_lowered_ir( module_info: LoweredModuleInfo, do_opt: bool, @@ -63,8 +59,6 @@ pub fn optimize_lowered_ir( let mut module_info = module_info; if do_opt { for func in &mut module_info.ir_functions { - // Two passes: dead_instrs may create empty blocks, and copy_prop - // may reveal new dead instrs. 
for _ in 0..2 { empty_blocks::eliminate(func); dead_blocks::eliminate(func)?; @@ -72,6 +66,8 @@ pub fn optimize_lowered_ir( dead_blocks::eliminate(func)?; copy_prop::eliminate(func); dead_instrs::eliminate(func); + branch_fold::eliminate(func); + dead_instrs::eliminate(func); } } } From edd1a96945896ddfe23a25007f5fb6a2e978f3e3 Mon Sep 17 00:00:00 2001 From: bench Date: Fri, 27 Mar 2026 08:35:33 +0000 Subject: [PATCH 2/2] refactor: move is_zero to utils, deduplicate build_global_const_map MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove local `build_global_const_map` from branch_fold — utils already provides an identical implementation; import it instead - Move `is_zero` helper to utils so it can be shared by other passes - Compute both `global_uses` and `global_consts` in `eliminate` and pass both into `fold_one`, making the two maps symmetric at the call site - Drop the per-block `local_consts` clone — `build_global_const_map` already captures all Const-defined variables across the function in SSA form, so the block-level augmentation was redundant Co-Authored-By: Claude Sonnet 4.6 --- .../herkos-core/src/optimizer/branch_fold.rs | 55 ++++--------------- crates/herkos-core/src/optimizer/utils.rs | 8 +++ 2 files changed, 20 insertions(+), 43 deletions(-) diff --git a/crates/herkos-core/src/optimizer/branch_fold.rs b/crates/herkos-core/src/optimizer/branch_fold.rs index 3a7f978..488bf7b 100644 --- a/crates/herkos-core/src/optimizer/branch_fold.rs +++ b/crates/herkos-core/src/optimizer/branch_fold.rs @@ -10,14 +10,15 @@ //! After substitution, the defining instruction becomes dead (single use was //! the branch) and is cleaned up by `dead_instrs`. 
-use super::utils::{build_global_use_count, instr_dest}; +use super::utils::{build_global_const_map, build_global_use_count, instr_dest, is_zero}; use crate::ir::{BinOp, IrFunction, IrInstr, IrTerminator, IrValue, UnOp, VarId}; use std::collections::HashMap; pub fn eliminate(func: &mut IrFunction) { loop { let global_uses = build_global_use_count(func); - if !fold_one(func, &global_uses) { + let global_consts = build_global_const_map(func); + if !fold_one(func, &global_uses, &global_consts) { break; } } @@ -25,21 +26,17 @@ pub fn eliminate(func: &mut IrFunction) { /// Attempt a single branch fold across the function. Returns `true` if a /// change was made. -fn fold_one(func: &mut IrFunction, global_uses: &HashMap) -> bool { +fn fold_one( + func: &mut IrFunction, + global_uses: &HashMap, + global_consts: &HashMap, +) -> bool { // Build a map of VarId → defining instruction info. // We only care about single-use vars defined by Eqz, Ne(x,0), or Eq(x,0). let mut var_defs: HashMap = HashMap::new(); - // Also build a global constant map for checking if an operand is zero. - let global_consts = build_global_const_map(func); - for block in &func.blocks { - let mut local_consts = global_consts.clone(); for instr in &block.instructions { - if let IrInstr::Const { dest, value } = instr { - local_consts.insert(*dest, *value); - } - if let Some(dest) = instr_dest(instr) { match instr { IrInstr::UnOp { @@ -55,9 +52,9 @@ fn fold_one(func: &mut IrFunction, global_uses: &HashMap) -> bool rhs, .. } => { - if is_zero(rhs, &local_consts) { + if is_zero(*rhs, global_consts) { var_defs.insert(dest, VarDef::NeZero(*lhs)); - } else if is_zero(lhs, &local_consts) { + } else if is_zero(*lhs, global_consts) { var_defs.insert(dest, VarDef::NeZero(*rhs)); } } @@ -67,9 +64,9 @@ fn fold_one(func: &mut IrFunction, global_uses: &HashMap) -> bool rhs, .. 
} => { - if is_zero(rhs, &local_consts) { + if is_zero(*rhs, global_consts) { var_defs.insert(dest, VarDef::EqZero(*lhs)); - } else if is_zero(lhs, &local_consts) { + } else if is_zero(*lhs, global_consts) { var_defs.insert(dest, VarDef::EqZero(*rhs)); } } @@ -133,34 +130,6 @@ enum VarDef { EqZero(VarId), } -fn is_zero(var: &VarId, consts: &HashMap) -> bool { - matches!( - consts.get(var), - Some(IrValue::I32(0)) | Some(IrValue::I64(0)) - ) -} - -fn build_global_const_map(func: &IrFunction) -> HashMap { - let mut total_defs: HashMap = HashMap::new(); - let mut const_defs: HashMap = HashMap::new(); - - for block in &func.blocks { - for instr in &block.instructions { - if let Some(dest) = instr_dest(instr) { - *total_defs.entry(dest).or_insert(0) += 1; - if let IrInstr::Const { dest, value } = instr { - const_defs.insert(*dest, *value); - } - } - } - } - - const_defs - .into_iter() - .filter(|(v, _)| total_defs.get(v).copied().unwrap_or(0) == 1) - .collect() -} - // ── Tests ───────────────────────────────────────────────────────────────────── #[cfg(test)] diff --git a/crates/herkos-core/src/optimizer/utils.rs b/crates/herkos-core/src/optimizer/utils.rs index dcd385f..52c82a7 100644 --- a/crates/herkos-core/src/optimizer/utils.rs +++ b/crates/herkos-core/src/optimizer/utils.rs @@ -495,6 +495,14 @@ pub fn rewrite_terminator_target(term: &mut IrTerminator, old: BlockId, new: Blo } } +/// Returns `true` if `var` is known to be zero according to `consts`. +pub fn is_zero(var: VarId, consts: &HashMap) -> bool { + matches!( + consts.get(&var), + Some(IrValue::I32(0)) | Some(IrValue::I64(0)) + ) +} + /// Variables with exactly one definition across the function that is a `Const` /// instruction. These can be treated as constants in any block that uses them. pub fn build_global_const_map(func: &IrFunction) -> HashMap {