Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
334 changes: 334 additions & 0 deletions crates/herkos-core/src/optimizer/branch_fold.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,334 @@
//! Branch condition folding.
//!
//! Simplifies `BranchIf` terminators by looking at the instruction that
//! defines the condition variable:
//!
//! - `Eqz(x)` as condition → swap branch targets, use `x` directly
//! - `Ne(x, 0)` as condition → use `x` directly
//! - `Eq(x, 0)` as condition → swap branch targets, use `x` directly
//!
//! After substitution, the defining instruction becomes dead (single use was
//! the branch) and is cleaned up by `dead_instrs`.

use super::utils::{build_global_const_map, build_global_use_count, instr_dest, is_zero};
use crate::ir::{BinOp, IrFunction, IrInstr, IrTerminator, IrValue, UnOp, VarId};
use std::collections::HashMap;

pub fn eliminate(func: &mut IrFunction) {
loop {
let global_uses = build_global_use_count(func);
let global_consts = build_global_const_map(func);
if !fold_one(func, &global_uses, &global_consts) {
break;
}
}
}

/// Attempt a single branch fold across the function. Returns `true` if a
/// change was made.
fn fold_one(
func: &mut IrFunction,
global_uses: &HashMap<VarId, usize>,
global_consts: &HashMap<VarId, IrValue>,
) -> bool {
// Build a map of VarId → defining instruction info.
// We only care about single-use vars defined by Eqz, Ne(x,0), or Eq(x,0).
let mut var_defs: HashMap<VarId, VarDef> = HashMap::new();

for block in &func.blocks {
for instr in &block.instructions {
if let Some(dest) = instr_dest(instr) {
match instr {
IrInstr::UnOp {
op: UnOp::I32Eqz | UnOp::I64Eqz,
operand,
..
} => {
var_defs.insert(dest, VarDef::Eqz(*operand));
}
IrInstr::BinOp {
op: BinOp::I32Ne | BinOp::I64Ne,
lhs,
rhs,
..
} => {
if is_zero(*rhs, global_consts) {
var_defs.insert(dest, VarDef::NeZero(*lhs));
} else if is_zero(*lhs, global_consts) {
var_defs.insert(dest, VarDef::NeZero(*rhs));
}
}
IrInstr::BinOp {
op: BinOp::I32Eq | BinOp::I64Eq,
lhs,
rhs,
..
} => {
if is_zero(*rhs, global_consts) {
var_defs.insert(dest, VarDef::EqZero(*lhs));
} else if is_zero(*lhs, global_consts) {
var_defs.insert(dest, VarDef::EqZero(*rhs));
}
}
_ => {}
}
}
}
}

// Now scan terminators for BranchIf with a foldable condition.
for block in &mut func.blocks {
let condition = match &block.terminator {
IrTerminator::BranchIf { condition, .. } => *condition,
_ => continue,
};

// Only fold if the condition has exactly one use (the BranchIf).
if global_uses.get(&condition).copied().unwrap_or(0) != 1 {
continue;
}

let def = match var_defs.get(&condition) {
Some(d) => d,
None => continue,
};

match def {
VarDef::Eqz(inner) | VarDef::EqZero(inner) => {
// eqz(x) != 0 ≡ x == 0, so swap targets and use x
if let IrTerminator::BranchIf {
condition: cond,
if_true,
if_false,
} = &mut block.terminator
{
*cond = *inner;
std::mem::swap(if_true, if_false);
}
return true;
}
VarDef::NeZero(inner) => {
// ne(x, 0) != 0 ≡ x != 0, so just use x
if let IrTerminator::BranchIf {
condition: cond, ..
} = &mut block.terminator
{
*cond = *inner;
}
return true;
}
}
}

false
}

#[derive(Clone, Copy)]
enum VarDef {
Eqz(VarId),
NeZero(VarId),
EqZero(VarId),
}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
use super::*;
use crate::ir::{BlockId, IrBlock, TypeIdx};

fn make_func(blocks: Vec<IrBlock>) -> IrFunction {
IrFunction {
params: vec![],
locals: vec![],
blocks,
entry_block: BlockId(0),
return_type: None,
type_idx: TypeIdx::new(0),
}
}

#[test]
fn eqz_swaps_targets() {
// v1 = Eqz(v0); BranchIf(v1, B1, B2) → BranchIf(v0, B2, B1)
let mut func = make_func(vec![IrBlock {
id: BlockId(0),
instructions: vec![IrInstr::UnOp {
dest: VarId(1),
op: UnOp::I32Eqz,
operand: VarId(0),
}],
terminator: IrTerminator::BranchIf {
condition: VarId(1),
if_true: BlockId(1),
if_false: BlockId(2),
},
}]);
eliminate(&mut func);
match &func.blocks[0].terminator {
IrTerminator::BranchIf {
condition,
if_true,
if_false,
} => {
assert_eq!(*condition, VarId(0));
assert_eq!(*if_true, BlockId(2), "targets should be swapped");
assert_eq!(*if_false, BlockId(1));
}
other => panic!("expected BranchIf, got {other:?}"),
}
}

#[test]
fn ne_zero_simplifies() {
// v1 = 0; v2 = Ne(v0, v1); BranchIf(v2, B1, B2) → BranchIf(v0, B1, B2)
let mut func = make_func(vec![IrBlock {
id: BlockId(0),
instructions: vec![
IrInstr::Const {
dest: VarId(1),
value: IrValue::I32(0),
},
IrInstr::BinOp {
dest: VarId(2),
op: BinOp::I32Ne,
lhs: VarId(0),
rhs: VarId(1),
},
],
terminator: IrTerminator::BranchIf {
condition: VarId(2),
if_true: BlockId(1),
if_false: BlockId(2),
},
}]);
eliminate(&mut func);
match &func.blocks[0].terminator {
IrTerminator::BranchIf {
condition,
if_true,
if_false,
} => {
assert_eq!(*condition, VarId(0));
assert_eq!(*if_true, BlockId(1), "targets should NOT be swapped");
assert_eq!(*if_false, BlockId(2));
}
other => panic!("expected BranchIf, got {other:?}"),
}
}

#[test]
fn eq_zero_swaps() {
// v1 = 0; v2 = Eq(v0, v1); BranchIf(v2, B1, B2) → BranchIf(v0, B2, B1)
let mut func = make_func(vec![IrBlock {
id: BlockId(0),
instructions: vec![
IrInstr::Const {
dest: VarId(1),
value: IrValue::I32(0),
},
IrInstr::BinOp {
dest: VarId(2),
op: BinOp::I32Eq,
lhs: VarId(0),
rhs: VarId(1),
},
],
terminator: IrTerminator::BranchIf {
condition: VarId(2),
if_true: BlockId(1),
if_false: BlockId(2),
},
}]);
eliminate(&mut func);
match &func.blocks[0].terminator {
IrTerminator::BranchIf {
condition,
if_true,
if_false,
} => {
assert_eq!(*condition, VarId(0));
assert_eq!(*if_true, BlockId(2), "targets should be swapped");
assert_eq!(*if_false, BlockId(1));
}
other => panic!("expected BranchIf, got {other:?}"),
}
}

#[test]
fn multi_use_not_folded() {
// v1 = Eqz(v0); use(v1) elsewhere → don't fold
let mut func = make_func(vec![
IrBlock {
id: BlockId(0),
instructions: vec![IrInstr::UnOp {
dest: VarId(1),
op: UnOp::I32Eqz,
operand: VarId(0),
}],
terminator: IrTerminator::BranchIf {
condition: VarId(1),
if_true: BlockId(1),
if_false: BlockId(2),
},
},
IrBlock {
id: BlockId(1),
instructions: vec![],
terminator: IrTerminator::Return {
value: Some(VarId(1)), // second use of v1
},
},
IrBlock {
id: BlockId(2),
instructions: vec![],
terminator: IrTerminator::Return { value: None },
},
]);
eliminate(&mut func);
// Should NOT fold — v1 has 2 uses
match &func.blocks[0].terminator {
IrTerminator::BranchIf { condition, .. } => {
assert_eq!(*condition, VarId(1), "should not have been folded");
}
other => panic!("expected BranchIf, got {other:?}"),
}
}

#[test]
fn cross_block_zero_const() {
// B0: v1 = 0; Jump(B1)
// B1: v2 = Ne(v0, v1); BranchIf(v2, B2, B3) → BranchIf(v0, B2, B3)
let mut func = make_func(vec![
IrBlock {
id: BlockId(0),
instructions: vec![IrInstr::Const {
dest: VarId(1),
value: IrValue::I32(0),
}],
terminator: IrTerminator::Jump { target: BlockId(1) },
},
IrBlock {
id: BlockId(1),
instructions: vec![IrInstr::BinOp {
dest: VarId(2),
op: BinOp::I32Ne,
lhs: VarId(0),
rhs: VarId(1),
}],
terminator: IrTerminator::BranchIf {
condition: VarId(2),
if_true: BlockId(2),
if_false: BlockId(3),
},
},
]);
eliminate(&mut func);
match &func.blocks[1].terminator {
IrTerminator::BranchIf { condition, .. } => {
assert_eq!(*condition, VarId(0));
}
other => panic!("expected BranchIf, got {other:?}"),
}
}
}
18 changes: 7 additions & 11 deletions crates/herkos-core/src/optimizer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ mod copy_prop;
mod dead_blocks;

// ── Post-lowering passes ─────────────────────────────────────────────────────
mod branch_fold;
mod dead_instrs;
mod empty_blocks;
mod merge_blocks;
Expand All @@ -46,32 +47,27 @@ pub fn optimize_ir(module_info: ModuleInfo, do_opt: bool) -> Result<ModuleInfo>

/// Optimizes the lowered IR after phi nodes have been eliminated.
///
/// Passes here operate on [`LoweredModuleInfo`] where all `IrInstr::Phi`
/// nodes have been replaced by `IrInstr::Assign` in predecessor blocks.
///
/// ## Structural and copy passes
///
/// Run multiple iterations of structural cleanup + copy propagation.
/// Runs post-lowering structural passes and branch condition folding.
/// dead_instrs may leave empty blocks, which empty_blocks and merge_blocks then
/// eliminate, potentially exposing new dead instructions. copy_prop forwards the
/// assignments that lower_phis inserted. We repeat until reaching a fixed point
/// (typically 2 iterations).
/// eliminate, potentially exposing new dead instructions. branch_fold simplifies
/// `BranchIf` terminators whose condition is a known comparison. We repeat until
/// reaching a fixed point (typically 2 iterations).
pub fn optimize_lowered_ir(
module_info: LoweredModuleInfo,
do_opt: bool,
) -> Result<LoweredModuleInfo> {
let mut module_info = module_info;
if do_opt {
for func in &mut module_info.ir_functions {
// Two passes: dead_instrs may create empty blocks, and copy_prop
// may reveal new dead instrs.
for _ in 0..2 {
empty_blocks::eliminate(func);
dead_blocks::eliminate(func)?;
merge_blocks::eliminate(func);
dead_blocks::eliminate(func)?;
copy_prop::eliminate(func);
dead_instrs::eliminate(func);
branch_fold::eliminate(func);
dead_instrs::eliminate(func);
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions crates/herkos-core/src/optimizer/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,14 @@ pub fn rewrite_terminator_target(term: &mut IrTerminator, old: BlockId, new: Blo
}
}

/// Returns `true` if `var` is known to be zero according to `consts`.
pub fn is_zero(var: VarId, consts: &HashMap<VarId, IrValue>) -> bool {
matches!(
consts.get(&var),
Some(IrValue::I32(0)) | Some(IrValue::I64(0))
)
}

/// Variables with exactly one definition across the function that is a `Const`
/// instruction. These can be treated as constants in any block that uses them.
pub fn build_global_const_map(func: &IrFunction) -> HashMap<VarId, IrValue> {
Expand Down
Loading