From bf541960cd651da0b9f056502a7294df513445b2 Mon Sep 17 00:00:00 2001 From: bench Date: Mon, 23 Mar 2026 09:07:01 +0000 Subject: [PATCH 1/4] feat: add redundancy elimination and loop optimization passes Add four post-lowering optimization passes: - local_cse: Local common subexpression elimination within blocks - gvn: Global value numbering across blocks using dominator tree - licm: Loop invariant code motion - branch_fold: Branch condition simplification Also includes shared optimizer utilities and refactored structural passes. WIP: Tests and full integration in progress. Co-Authored-By: Claude Haiku 4.5 --- .../herkos-core/src/optimizer/branch_fold.rs | 365 +++++ .../herkos-core/src/optimizer/dead_blocks.rs | 23 +- .../herkos-core/src/optimizer/dead_instrs.rs | 361 +++++ .../herkos-core/src/optimizer/empty_blocks.rs | 332 +++++ crates/herkos-core/src/optimizer/gvn.rs | 618 ++++++++ crates/herkos-core/src/optimizer/licm.rs | 1306 +++++++++++++++++ crates/herkos-core/src/optimizer/local_cse.rs | 574 ++++++++ .../herkos-core/src/optimizer/merge_blocks.rs | 382 +++++ crates/herkos-core/src/optimizer/mod.rs | 45 +- crates/herkos-core/src/optimizer/utils.rs | 658 +++++++++ 10 files changed, 4636 insertions(+), 28 deletions(-) create mode 100644 crates/herkos-core/src/optimizer/branch_fold.rs create mode 100644 crates/herkos-core/src/optimizer/dead_instrs.rs create mode 100644 crates/herkos-core/src/optimizer/empty_blocks.rs create mode 100644 crates/herkos-core/src/optimizer/gvn.rs create mode 100644 crates/herkos-core/src/optimizer/licm.rs create mode 100644 crates/herkos-core/src/optimizer/local_cse.rs create mode 100644 crates/herkos-core/src/optimizer/merge_blocks.rs create mode 100644 crates/herkos-core/src/optimizer/utils.rs diff --git a/crates/herkos-core/src/optimizer/branch_fold.rs b/crates/herkos-core/src/optimizer/branch_fold.rs new file mode 100644 index 0000000..3a7f978 --- /dev/null +++ b/crates/herkos-core/src/optimizer/branch_fold.rs @@ -0,0 +1,365 @@ 
+//! Branch condition folding. +//! +//! Simplifies `BranchIf` terminators by looking at the instruction that +//! defines the condition variable: +//! +//! - `Eqz(x)` as condition → swap branch targets, use `x` directly +//! - `Ne(x, 0)` as condition → use `x` directly +//! - `Eq(x, 0)` as condition → swap branch targets, use `x` directly +//! +//! After substitution, the defining instruction becomes dead (single use was +//! the branch) and is cleaned up by `dead_instrs`. + +use super::utils::{build_global_use_count, instr_dest}; +use crate::ir::{BinOp, IrFunction, IrInstr, IrTerminator, IrValue, UnOp, VarId}; +use std::collections::HashMap; + +pub fn eliminate(func: &mut IrFunction) { + loop { + let global_uses = build_global_use_count(func); + if !fold_one(func, &global_uses) { + break; + } + } +} + +/// Attempt a single branch fold across the function. Returns `true` if a +/// change was made. +fn fold_one(func: &mut IrFunction, global_uses: &HashMap) -> bool { + // Build a map of VarId → defining instruction info. + // We only care about single-use vars defined by Eqz, Ne(x,0), or Eq(x,0). + let mut var_defs: HashMap = HashMap::new(); + + // Also build a global constant map for checking if an operand is zero. + let global_consts = build_global_const_map(func); + + for block in &func.blocks { + let mut local_consts = global_consts.clone(); + for instr in &block.instructions { + if let IrInstr::Const { dest, value } = instr { + local_consts.insert(*dest, *value); + } + + if let Some(dest) = instr_dest(instr) { + match instr { + IrInstr::UnOp { + op: UnOp::I32Eqz | UnOp::I64Eqz, + operand, + .. + } => { + var_defs.insert(dest, VarDef::Eqz(*operand)); + } + IrInstr::BinOp { + op: BinOp::I32Ne | BinOp::I64Ne, + lhs, + rhs, + .. 
+ } => { + if is_zero(rhs, &local_consts) { + var_defs.insert(dest, VarDef::NeZero(*lhs)); + } else if is_zero(lhs, &local_consts) { + var_defs.insert(dest, VarDef::NeZero(*rhs)); + } + } + IrInstr::BinOp { + op: BinOp::I32Eq | BinOp::I64Eq, + lhs, + rhs, + .. + } => { + if is_zero(rhs, &local_consts) { + var_defs.insert(dest, VarDef::EqZero(*lhs)); + } else if is_zero(lhs, &local_consts) { + var_defs.insert(dest, VarDef::EqZero(*rhs)); + } + } + _ => {} + } + } + } + } + + // Now scan terminators for BranchIf with a foldable condition. + for block in &mut func.blocks { + let condition = match &block.terminator { + IrTerminator::BranchIf { condition, .. } => *condition, + _ => continue, + }; + + // Only fold if the condition has exactly one use (the BranchIf). + if global_uses.get(&condition).copied().unwrap_or(0) != 1 { + continue; + } + + let def = match var_defs.get(&condition) { + Some(d) => d, + None => continue, + }; + + match def { + VarDef::Eqz(inner) | VarDef::EqZero(inner) => { + // eqz(x) != 0 ≡ x == 0, so swap targets and use x + if let IrTerminator::BranchIf { + condition: cond, + if_true, + if_false, + } = &mut block.terminator + { + *cond = *inner; + std::mem::swap(if_true, if_false); + } + return true; + } + VarDef::NeZero(inner) => { + // ne(x, 0) != 0 ≡ x != 0, so just use x + if let IrTerminator::BranchIf { + condition: cond, .. 
+ } = &mut block.terminator + { + *cond = *inner; + } + return true; + } + } + } + + false +} + +#[derive(Clone, Copy)] +enum VarDef { + Eqz(VarId), + NeZero(VarId), + EqZero(VarId), +} + +fn is_zero(var: &VarId, consts: &HashMap) -> bool { + matches!( + consts.get(var), + Some(IrValue::I32(0)) | Some(IrValue::I64(0)) + ) +} + +fn build_global_const_map(func: &IrFunction) -> HashMap { + let mut total_defs: HashMap = HashMap::new(); + let mut const_defs: HashMap = HashMap::new(); + + for block in &func.blocks { + for instr in &block.instructions { + if let Some(dest) = instr_dest(instr) { + *total_defs.entry(dest).or_insert(0) += 1; + if let IrInstr::Const { dest, value } = instr { + const_defs.insert(*dest, *value); + } + } + } + } + + const_defs + .into_iter() + .filter(|(v, _)| total_defs.get(v).copied().unwrap_or(0) == 1) + .collect() +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{BlockId, IrBlock, TypeIdx}; + + fn make_func(blocks: Vec) -> IrFunction { + IrFunction { + params: vec![], + locals: vec![], + blocks, + entry_block: BlockId(0), + return_type: None, + type_idx: TypeIdx::new(0), + } + } + + #[test] + fn eqz_swaps_targets() { + // v1 = Eqz(v0); BranchIf(v1, B1, B2) → BranchIf(v0, B2, B1) + let mut func = make_func(vec![IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::UnOp { + dest: VarId(1), + op: UnOp::I32Eqz, + operand: VarId(0), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(1), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }]); + eliminate(&mut func); + match &func.blocks[0].terminator { + IrTerminator::BranchIf { + condition, + if_true, + if_false, + } => { + assert_eq!(*condition, VarId(0)); + assert_eq!(*if_true, BlockId(2), "targets should be swapped"); + assert_eq!(*if_false, BlockId(1)); + } + other => panic!("expected BranchIf, got {other:?}"), + } + } + + #[test] + fn ne_zero_simplifies() { + // v1 = 0; v2 
= Ne(v0, v1); BranchIf(v2, B1, B2) → BranchIf(v0, B1, B2) + let mut func = make_func(vec![IrBlock { + id: BlockId(0), + instructions: vec![ + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(0), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Ne, + lhs: VarId(0), + rhs: VarId(1), + }, + ], + terminator: IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }]); + eliminate(&mut func); + match &func.blocks[0].terminator { + IrTerminator::BranchIf { + condition, + if_true, + if_false, + } => { + assert_eq!(*condition, VarId(0)); + assert_eq!(*if_true, BlockId(1), "targets should NOT be swapped"); + assert_eq!(*if_false, BlockId(2)); + } + other => panic!("expected BranchIf, got {other:?}"), + } + } + + #[test] + fn eq_zero_swaps() { + // v1 = 0; v2 = Eq(v0, v1); BranchIf(v2, B1, B2) → BranchIf(v0, B2, B1) + let mut func = make_func(vec![IrBlock { + id: BlockId(0), + instructions: vec![ + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(0), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Eq, + lhs: VarId(0), + rhs: VarId(1), + }, + ], + terminator: IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }]); + eliminate(&mut func); + match &func.blocks[0].terminator { + IrTerminator::BranchIf { + condition, + if_true, + if_false, + } => { + assert_eq!(*condition, VarId(0)); + assert_eq!(*if_true, BlockId(2), "targets should be swapped"); + assert_eq!(*if_false, BlockId(1)); + } + other => panic!("expected BranchIf, got {other:?}"), + } + } + + #[test] + fn multi_use_not_folded() { + // v1 = Eqz(v0); use(v1) elsewhere → don't fold + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::UnOp { + dest: VarId(1), + op: UnOp::I32Eqz, + operand: VarId(0), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(1), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }, + IrBlock { + id: BlockId(1), + 
instructions: vec![], + terminator: IrTerminator::Return { + value: Some(VarId(1)), // second use of v1 + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + eliminate(&mut func); + // Should NOT fold — v1 has 2 uses + match &func.blocks[0].terminator { + IrTerminator::BranchIf { condition, .. } => { + assert_eq!(*condition, VarId(1), "should not have been folded"); + } + other => panic!("expected BranchIf, got {other:?}"), + } + } + + #[test] + fn cross_block_zero_const() { + // B0: v1 = 0; Jump(B1) + // B1: v2 = Ne(v0, v1); BranchIf(v2, B2, B3) → BranchIf(v0, B2, B3) + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(0), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Ne, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + ]); + eliminate(&mut func); + match &func.blocks[1].terminator { + IrTerminator::BranchIf { condition, .. } => { + assert_eq!(*condition, VarId(0)); + } + other => panic!("expected BranchIf, got {other:?}"), + } + } +} diff --git a/crates/herkos-core/src/optimizer/dead_blocks.rs b/crates/herkos-core/src/optimizer/dead_blocks.rs index bbe6b01..085283a 100644 --- a/crates/herkos-core/src/optimizer/dead_blocks.rs +++ b/crates/herkos-core/src/optimizer/dead_blocks.rs @@ -4,28 +4,11 @@ //! arise naturally during IR translation when code follows an `unreachable` or //! `return` instruction inside a Wasm structured control flow construct. 
-use crate::ir::{BlockId, IrBlock, IrFunction, IrTerminator}; +use super::utils::terminator_successors; +use crate::ir::{BlockId, IrBlock, IrFunction}; use anyhow::{bail, Result}; use std::collections::{HashMap, HashSet}; -/// Returns the successor block IDs for a terminator. -fn terminator_successors(term: &IrTerminator) -> Vec { - match term { - IrTerminator::Return { .. } | IrTerminator::Unreachable => vec![], - IrTerminator::Jump { target } => vec![*target], - IrTerminator::BranchIf { - if_true, if_false, .. - } => vec![*if_true, *if_false], - IrTerminator::BranchTable { - targets, default, .. - } => targets - .iter() - .chain(std::iter::once(default)) - .copied() - .collect(), - } -} - /// Computes the set of block IDs reachable from the entry block via BFS. fn reachable_blocks(func: &IrFunction) -> Result> { // Index blocks by ID for O(1) lookup during traversal. @@ -61,7 +44,7 @@ pub fn eliminate(func: &mut IrFunction) -> Result<()> { #[cfg(test)] mod tests { use super::*; - use crate::ir::{IrInstr, IrValue, TypeIdx, VarId, WasmType}; + use crate::ir::{IrInstr, IrTerminator, IrValue, TypeIdx, VarId, WasmType}; /// Build a minimal `IrFunction` with the given blocks. /// Entry block is always `BlockId(0)`. diff --git a/crates/herkos-core/src/optimizer/dead_instrs.rs b/crates/herkos-core/src/optimizer/dead_instrs.rs new file mode 100644 index 0000000..6763ae7 --- /dev/null +++ b/crates/herkos-core/src/optimizer/dead_instrs.rs @@ -0,0 +1,361 @@ +//! Dead instruction elimination. +//! +//! Removes instructions whose destination `VarId` has zero uses across the +//! entire function and whose operation is side-effect-free. +//! +//! ## Algorithm +//! +//! 1. Build the global use-count map (`VarId → number of reads`). +//! 2. For each instruction that produces a value (`instr_dest` returns `Some`): +//! if the use count is zero **and** the instruction is side-effect-free, +//! mark it for removal. +//! 3. Remove all marked instructions. +//! 4. 
Repeat to fixpoint — removing an instruction may make its operands' +//! definitions unused. +//! 5. Prune dead locals from `IrFunction::locals`. + +use super::utils::{build_global_use_count, instr_dest, is_side_effect_free, prune_dead_locals}; +use crate::ir::IrFunction; + +/// Run dead instruction elimination to fixpoint, then prune dead locals. +pub fn eliminate(func: &mut IrFunction) { + loop { + let uses = build_global_use_count(func); + let mut changed = false; + + for block in &mut func.blocks { + block.instructions.retain(|instr| { + if let Some(dest) = instr_dest(instr) { + if uses.get(&dest).copied().unwrap_or(0) == 0 && is_side_effect_free(instr) { + changed = true; + return false; // remove + } + } + true // keep + }); + } + + if !changed { + break; + } + } + + prune_dead_locals(func); +} + +// ── Tests ──────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ + BinOp, BlockId, GlobalIdx, IrBlock, IrFunction, IrInstr, IrTerminator, IrValue, + MemoryAccessWidth, TypeIdx, VarId, WasmType, + }; + + fn make_func(blocks: Vec) -> IrFunction { + IrFunction { + params: vec![], + locals: vec![], + blocks, + entry_block: BlockId(0), + return_type: None, + type_idx: TypeIdx::new(0), + } + } + + fn make_func_with_locals(blocks: Vec, locals: Vec<(VarId, WasmType)>) -> IrFunction { + IrFunction { + params: vec![], + locals, + blocks, + entry_block: BlockId(0), + return_type: None, + type_idx: TypeIdx::new(0), + } + } + + fn single_block(instrs: Vec, term: IrTerminator) -> Vec { + vec![IrBlock { + id: BlockId(0), + instructions: instrs, + terminator: term, + }] + } + + fn ret_none() -> IrTerminator { + IrTerminator::Return { value: None } + } + + // ── Basic: unused side-effect-free instruction is removed ───────────── + + #[test] + fn unused_const_removed() { + let mut func = make_func(single_block( + vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }], + ret_none(), + )); 
+ eliminate(&mut func); + assert_eq!(func.blocks[0].instructions.len(), 0); + } + + #[test] + fn unused_binop_removed() { + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(2), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + ], + ret_none(), + )); + eliminate(&mut func); + // v2 unused → removed; then v0, v1 become unused → removed + assert_eq!(func.blocks[0].instructions.len(), 0); + } + + // ── Used instruction is kept ───────────────────────────────────────── + + #[test] + fn used_const_kept() { + let mut func = make_func(single_block( + vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }], + IrTerminator::Return { + value: Some(VarId(0)), + }, + )); + eliminate(&mut func); + assert_eq!(func.blocks[0].instructions.len(), 1); + } + + // ── Side-effectful instructions are kept even when unused ───────────── + + #[test] + fn unused_load_kept() { + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(0), + }, + IrInstr::Load { + dest: VarId(1), + ty: WasmType::I32, + addr: VarId(0), + offset: 0, + width: MemoryAccessWidth::Full, + sign: None, + }, + ], + ret_none(), + )); + eliminate(&mut func); + // Load may trap → kept; v0 is used by Load → kept + assert_eq!(func.blocks[0].instructions.len(), 2); + } + + #[test] + fn store_kept() { + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(0), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(99), + }, + IrInstr::Store { + ty: WasmType::I32, + addr: VarId(0), + value: VarId(1), + offset: 0, + width: MemoryAccessWidth::Full, + }, + ], + ret_none(), + )); + eliminate(&mut func); + // Store has side effects → kept; v0, v1 used by Store → kept + assert_eq!(func.blocks[0].instructions.len(), 3); + } + + // ── 
Fixpoint: cascading removal ────────────────────────────────────── + + #[test] + fn fixpoint_cascading_removal() { + // v0 = Const(1) + // v1 = Const(2) + // v2 = BinOp(v0, v1) ← only use of v0, v1 + // v3 = BinOp(v2, v2) ← only use of v2 + // Return(None) ← v3 unused + // + // Round 1: v3 unused → remove v3's BinOp + // Round 2: v2 unused → remove v2's BinOp + // Round 3: v0, v1 unused → remove both Consts + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(2), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Mul, + lhs: VarId(2), + rhs: VarId(2), + }, + ], + ret_none(), + )); + eliminate(&mut func); + assert_eq!(func.blocks[0].instructions.len(), 0); + } + + // ── Mixed: some dead, some live ────────────────────────────────────── + + #[test] + fn mixed_dead_and_live() { + // v0 = Const(1) ← used by Return + // v1 = Const(2) ← unused → dead + // v2 = BinOp(v1, v1) ← unused → dead + let mut func = make_func(single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(2), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(1), + rhs: VarId(1), + }, + ], + IrTerminator::Return { + value: Some(VarId(0)), + }, + )); + eliminate(&mut func); + assert_eq!(func.blocks[0].instructions.len(), 1); + match &func.blocks[0].instructions[0] { + IrInstr::Const { + dest, + value: IrValue::I32(1), + } => assert_eq!(*dest, VarId(0)), + other => panic!("expected Const(v0, 1), got {other:?}"), + } + } + + // ── Dead locals are pruned ─────────────────────────────────────────── + + #[test] + fn dead_locals_pruned() { + let mut func = make_func_with_locals( + single_block( + vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + }, + 
IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(2), + }, + ], + IrTerminator::Return { + value: Some(VarId(0)), + }, + ), + vec![(VarId(0), WasmType::I32), (VarId(1), WasmType::I32)], + ); + eliminate(&mut func); + // v1 is dead → removed from instructions and locals + assert!(!func.locals.iter().any(|(v, _)| *v == VarId(1))); + assert!(func.locals.iter().any(|(v, _)| *v == VarId(0))); + } + + // ── Multi-block: dead in one, live in another ──────────────────────── + + #[test] + fn multi_block_cross_reference_kept() { + // Block 0: v0 = Const(1); Jump(B1) + // Block 1: Return(v0) + // v0 is used in B1 → must NOT be removed from B0 + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Return { + value: Some(VarId(0)), + }, + }, + ]); + eliminate(&mut func); + assert_eq!(func.blocks[0].instructions.len(), 1); + } + + // ── GlobalGet (side-effect-free) is removed when unused ────────────── + + #[test] + fn unused_global_get_removed() { + let mut func = make_func(single_block( + vec![IrInstr::GlobalGet { + dest: VarId(0), + index: GlobalIdx::new(0), + }], + ret_none(), + )); + eliminate(&mut func); + assert_eq!(func.blocks[0].instructions.len(), 0); + } + + // ── No-op: empty function ──────────────────────────────────────────── + + #[test] + fn empty_function_unchanged() { + let mut func = make_func(single_block(vec![], ret_none())); + eliminate(&mut func); + assert_eq!(func.blocks[0].instructions.len(), 0); + } +} diff --git a/crates/herkos-core/src/optimizer/empty_blocks.rs b/crates/herkos-core/src/optimizer/empty_blocks.rs new file mode 100644 index 0000000..ae0b965 --- /dev/null +++ b/crates/herkos-core/src/optimizer/empty_blocks.rs @@ -0,0 +1,332 @@ +//! Empty block / passthrough elimination. +//! +//! 
A passthrough block contains no instructions and ends with an unconditional +//! `Jump`. All references to such a block can be replaced by references to its +//! ultimate target, eliminating the block entirely. +//! +//! Example (from the fibo transpilation): +//! B5: {} → Jump(B6) +//! B6: {} → Jump(B7) +//! +//! After this pass B4's branch-false target is rewritten from B5 to B7, and +//! B5/B6 become unreferenced dead blocks, removed by `dead_blocks::eliminate` +//! in the next pass. + +use crate::ir::{BlockId, IrFunction, IrTerminator}; +use std::collections::HashMap; + +/// Replace every reference to a passthrough block with its ultimate target. +/// +/// After this call, all passthrough blocks are unreferenced and will be +/// removed by the subsequent `dead_blocks::eliminate` pass. +pub fn eliminate(func: &mut IrFunction) { + // ── Step 1: Build the raw forwarding map ──────────────────────────── + // A block is a passthrough if it has no instructions and its terminator + // is an unconditional Jump. + let mut forward: HashMap = HashMap::new(); + for block in &func.blocks { + if block.instructions.is_empty() { + if let IrTerminator::Jump { target } = block.terminator { + forward.insert(block.id, target); + } + } + } + + if forward.is_empty() { + return; + } + + // ── Step 2: Resolve chains, cycle-safe ────────────────────────────── + // Collapse A → B → C chains into A → C. + // Bound hop count to func.blocks.len() to handle cycles (e.g. A→B→A). 
+ let max_hops = func.blocks.len(); + let resolved: HashMap = forward + .keys() + .copied() + .map(|start| { + let mut cur = start; + for _ in 0..max_hops { + match forward.get(&cur) { + Some(&next) => cur = next, + None => break, + } + } + (start, cur) + }) + .collect(); + + // ── Step 3: Rewrite all terminator targets ─────────────────────────── + let fwd = |id: BlockId| resolved.get(&id).copied().unwrap_or(id); + + for block in &mut func.blocks { + match &mut block.terminator { + IrTerminator::Jump { target } => { + *target = fwd(*target); + } + IrTerminator::BranchIf { + if_true, if_false, .. + } => { + *if_true = fwd(*if_true); + *if_false = fwd(*if_false); + } + IrTerminator::BranchTable { + targets, default, .. + } => { + for t in targets.iter_mut() { + *t = fwd(*t); + } + *default = fwd(*default); + } + IrTerminator::Return { .. } | IrTerminator::Unreachable => {} + } + } + // Passthrough blocks are now unreferenced; dead_blocks::eliminate will + // remove them in the next pass. 
+} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{IrBlock, IrFunction, IrInstr, IrTerminator, IrValue, TypeIdx, VarId}; + + fn make_func(blocks: Vec) -> IrFunction { + IrFunction { + params: vec![], + locals: vec![], + blocks, + entry_block: BlockId(0), + return_type: None, + type_idx: TypeIdx::new(0), + } + } + + fn jump(id: u32, target: u32) -> IrBlock { + IrBlock { + id: BlockId(id), + instructions: vec![], + terminator: IrTerminator::Jump { + target: BlockId(target), + }, + } + } + + fn ret(id: u32) -> IrBlock { + IrBlock { + id: BlockId(id), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + } + } + + fn branch(id: u32, cond: u32, if_true: u32, if_false: u32) -> IrBlock { + IrBlock { + id: BlockId(id), + instructions: vec![], + terminator: IrTerminator::BranchIf { + condition: VarId(cond), + if_true: BlockId(if_true), + if_false: BlockId(if_false), + }, + } + } + + fn target_of(func: &IrFunction, id: u32) -> Option { + func.blocks + .iter() + .find(|b| b.id == BlockId(id)) + .and_then(|b| match b.terminator { + IrTerminator::Jump { target } => Some(target), + _ => None, + }) + } + + fn branch_targets(func: &IrFunction, id: u32) -> Option<(BlockId, BlockId)> { + func.blocks + .iter() + .find(|b| b.id == BlockId(id)) + .and_then(|b| match b.terminator { + IrTerminator::BranchIf { + if_true, if_false, .. 
+ } => Some((if_true, if_false)), + _ => None, + }) + } + + // ── Basic cases ────────────────────────────────────────────────────── + + #[test] + fn no_passthrough_unchanged() { + // B0: instr → Jump(B1), B1: Return — no passthrough, nothing changes + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + ret(1), + ]); + eliminate(&mut func); + assert_eq!(target_of(&func, 0), Some(BlockId(1))); + assert_eq!(func.blocks.len(), 2); + } + + #[test] + fn single_passthrough_redirected() { + // B0 → B1(pass) → B2: B0's target becomes B2 + let mut func = make_func(vec![jump(0, 1), jump(1, 2), ret(2)]); + eliminate(&mut func); + assert_eq!(target_of(&func, 0), Some(BlockId(2))); + } + + #[test] + fn chain_collapsed() { + // B0 → B1(pass) → B2(pass) → B3: B0's target becomes B3 + let mut func = make_func(vec![jump(0, 1), jump(1, 2), jump(2, 3), ret(3)]); + eliminate(&mut func); + assert_eq!(target_of(&func, 0), Some(BlockId(3))); + // B1 should also forward to B3 + assert_eq!(target_of(&func, 1), Some(BlockId(3))); + } + + // ── BranchIf ──────────────────────────────────────────────────────── + + #[test] + fn branch_if_both_arms_redirected() { + // B0: BranchIf(true→B1(pass)→B3, false→B2(pass)→B4) + let mut func = make_func(vec![ + branch(0, 0, 1, 2), + jump(1, 3), + jump(2, 4), + ret(3), + ret(4), + ]); + eliminate(&mut func); + let (t, f) = branch_targets(&func, 0).unwrap(); + assert_eq!(t, BlockId(3)); + assert_eq!(f, BlockId(4)); + } + + #[test] + fn branch_if_one_arm_redirected() { + // B0: BranchIf(true→B1(non-pass), false→B2(pass)→B3) + let mut func = make_func(vec![ + branch(0, 0, 1, 2), + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(42), + }], + terminator: IrTerminator::Return { value: None }, + }, + jump(2, 3), + ret(3), + ]); + eliminate(&mut 
func); + let (t, f) = branch_targets(&func, 0).unwrap(); + assert_eq!(t, BlockId(1)); // unchanged + assert_eq!(f, BlockId(3)); // forwarded + } + + // ── BranchTable ────────────────────────────────────────────────────── + + #[test] + fn branch_table_redirected() { + // B0: BranchTable(targets:[B1(pass)→B3, B2(pass)→B4], default:B5) + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::BranchTable { + index: VarId(0), + targets: vec![BlockId(1), BlockId(2)], + default: BlockId(5), + }, + }, + jump(1, 3), + jump(2, 4), + ret(3), + ret(4), + ret(5), + ]); + eliminate(&mut func); + let b = func.blocks.iter().find(|b| b.id == BlockId(0)).unwrap(); + match &b.terminator { + IrTerminator::BranchTable { + targets, default, .. + } => { + assert_eq!(targets[0], BlockId(3)); + assert_eq!(targets[1], BlockId(4)); + assert_eq!(*default, BlockId(5)); // non-passthrough, unchanged + } + _ => panic!("expected BranchTable"), + } + } + + // ── Edge cases ─────────────────────────────────────────────────────── + + #[test] + fn cycle_safe() { + // B0 → B1(pass) → B2(pass) → B1 (cycle) + // Should not infinite loop; B0 ends up pointing somewhere in the cycle + let mut func = make_func(vec![jump(0, 1), jump(1, 2), jump(2, 1)]); + // Must complete without hanging; exact target is unspecified for cycles + eliminate(&mut func); + } + + #[test] + fn entry_passthrough_not_removed() { + // Entry block B0 is itself a passthrough: B0(pass) → B1 → Return + // After pass B0's jump stays (it's a passthrough of a passthrough pointing at B1), + // dead_blocks won't remove B0 (it starts BFS from entry). + let mut func = make_func(vec![jump(0, 1), ret(1)]); + eliminate(&mut func); + // B0 is a passthrough pointing to B1; resolve(B0)=B1 but nobody *jumps to* B0, + // so B0's own terminator remains Jump(B1) (forwarded to itself, i.e. B1). 
+ assert_eq!(target_of(&func, 0), Some(BlockId(1))); + assert_eq!(func.blocks.len(), 2); // dead_blocks not called here, both still present + } + + // ── Realistic fibo pattern ──────────────────────────────────────────── + + #[test] + fn fibo_pattern() { + // Mirrors the B3/B4/B5/B6/B7 structure from func_7 (release build): + // B3: BranchIf(cond→B7, else→B4) + // B4: BranchIf(cond→B3, else→B5) + // B5: {} → Jump(B6) ← passthrough + // B6: {} → Jump(B7) ← passthrough + // B7: Return + // + // After eliminate(): + // B4's false-arm should be B7 (not B5) + // B3 and B7 are unchanged + let mut func = make_func(vec![ + branch(3, 0, 7, 4), + branch(4, 1, 3, 5), + jump(5, 6), // passthrough + jump(6, 7), // passthrough + ret(7), + ]); + func.entry_block = BlockId(3); + + eliminate(&mut func); + + // B4's false-arm: was B5, must now be B7 + let (true_arm, false_arm) = branch_targets(&func, 4).unwrap(); + assert_eq!(true_arm, BlockId(3)); // back-edge unchanged + assert_eq!(false_arm, BlockId(7)); // forwarded through B5→B6→B7 + + // B3's true-arm was already B7 — still B7 + let (t3, f3) = branch_targets(&func, 3).unwrap(); + assert_eq!(t3, BlockId(7)); + assert_eq!(f3, BlockId(4)); + + // B5 and B6 themselves now point to B7 (resolved) + assert_eq!(target_of(&func, 5), Some(BlockId(7))); + assert_eq!(target_of(&func, 6), Some(BlockId(7))); + } +} diff --git a/crates/herkos-core/src/optimizer/gvn.rs b/crates/herkos-core/src/optimizer/gvn.rs new file mode 100644 index 0000000..f5cd732 --- /dev/null +++ b/crates/herkos-core/src/optimizer/gvn.rs @@ -0,0 +1,618 @@ +//! Global value numbering (GVN) — cross-block CSE using the dominator tree. +//! +//! Extends block-local CSE ([`super::local_cse`]) to work across basic blocks. +//! If block A dominates block B (every path to B passes through A), then any +//! pure computation defined in A with the same value key as one in B can be +//! reused in B instead of recomputing. +//! +//! ## Algorithm +//! +//! 1. 
Compute the immediate dominator of each block (Cooper/Harvey/Kennedy
//!    iterative algorithm) to build the dominator tree.
//! 2. Walk the dominator tree in preorder using a scoped value-number table.
//!    On entry to a block, push a new scope; on exit, pop it.
//! 3. For each pure instruction (`Const`, `BinOp`, `UnOp`) in the current
//!    block, compute a value key. If the key already exists in any enclosing
//!    scope (meaning it was computed in a dominating block), record a
//!    replacement: `dest → first_var`. Otherwise insert the key into the
//!    current scope.
//! 4. After the walk, rewrite all recorded destinations to
//!    `Assign { dest, src: first_var }` and let copy-propagation clean up.
//!
//! **Only pure instructions are eligible.** Loads, calls, and memory ops are
//! never deduplicated (they may trap or have observable side effects).

use super::utils::{build_predecessors, instr_dest, prune_dead_locals, terminator_successors};
use crate::ir::{BinOp, BlockId, IrFunction, IrInstr, IrValue, UnOp, VarId};
use std::collections::{HashMap, HashSet};

// ── Value key ────────────────────────────────────────────────────────────────

/// Hashable representation of a pure computation for deduplication.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum ValueKey {
    Const(ConstKey),
    BinOp { op: BinOp, lhs: VarId, rhs: VarId },
    UnOp { op: UnOp, operand: VarId },
}

/// Bit-level constant key that implements `Eq`/`Hash` correctly for floats.
///
/// Floats are stored as their raw bit patterns (`to_bits`), so two constants
/// compare equal exactly when they are bit-identical (NaN payloads included).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum ConstKey {
    I32(i32),
    I64(i64),
    F32(u32),
    F64(u64),
}

// NOTE(review): the generic parameter was stripped by the text extraction
// (`impl From for ConstKey`); restored to `From<IrValue>` here.
impl From<IrValue> for ConstKey {
    fn from(v: IrValue) -> Self {
        match v {
            IrValue::I32(x) => ConstKey::I32(x),
            IrValue::I64(x) => ConstKey::I64(x),
            IrValue::F32(x) => ConstKey::F32(x.to_bits()),
            IrValue::F64(x) => ConstKey::F64(x.to_bits()),
        }
    }
}

/// True for operators where `a op b == b op a`, enabling operand
/// canonicalization so swapped operands map to the same value key.
/// (Float add/mul are commutative even though they are not associative.)
fn is_commutative(op: &BinOp) -> bool {
    matches!(
        op,
        BinOp::I32Add
            | BinOp::I32Mul
            | BinOp::I32And
            | BinOp::I32Or
            | BinOp::I32Xor
            | BinOp::I32Eq
            | BinOp::I32Ne
            | BinOp::I64Add
            | BinOp::I64Mul
            | BinOp::I64And
            | BinOp::I64Or
            | BinOp::I64Xor
            | BinOp::I64Eq
            | BinOp::I64Ne
            | BinOp::F32Add
            | BinOp::F32Mul
            | BinOp::F32Eq
            | BinOp::F32Ne
            | BinOp::F64Add
            | BinOp::F64Mul
            | BinOp::F64Eq
            | BinOp::F64Ne
    )
}

/// Build a `ValueKey` for a binary op, placing the smaller `VarId` first for
/// commutative operators so `a + b` and `b + a` deduplicate to the same key.
fn binop_key(op: BinOp, lhs: VarId, rhs: VarId) -> ValueKey {
    let (lhs, rhs) = if is_commutative(&op) && lhs.0 > rhs.0 {
        (rhs, lhs)
    } else {
        (lhs, rhs)
    };
    ValueKey::BinOp { op, lhs, rhs }
}

// ── Multi-definition detection ───────────────────────────────────────────────

/// Build the set of variables defined more than once across the function.
///
/// After phi lowering the code is no longer in strict SSA form: loop phi
/// variables receive an initial assignment in the pre-loop block and a
/// back-edge update at the end of each iteration. These variables carry
/// different values at different program points, so any BinOp/UnOp that uses
/// them cannot be safely hoisted or deduplicated across blocks.
///
/// `Const` instructions are always safe (they have no operands).
+fn build_multi_def_vars(func: &IrFunction) -> HashSet { + let mut def_count: HashMap = HashMap::new(); + for block in &func.blocks { + for instr in &block.instructions { + if let Some(dest) = instr_dest(instr) { + *def_count.entry(dest).or_insert(0) += 1; + } + } + } + def_count + .into_iter() + .filter(|&(_, count)| count > 1) + .map(|(v, _)| v) + .collect() +} + +// ── Dominator tree ─────────────────────────────────────────────────────────── + +/// Compute the reverse-postorder traversal of the CFG starting from `entry`. +fn compute_rpo(func: &IrFunction) -> Vec { + let block_idx: HashMap = + func.blocks.iter().enumerate().map(|(i, b)| (b.id, i)).collect(); + + let mut visited = vec![false; func.blocks.len()]; + let mut postorder = Vec::with_capacity(func.blocks.len()); + + dfs_postorder(func, func.entry_block, &block_idx, &mut visited, &mut postorder); + + postorder.reverse(); + postorder +} + +fn dfs_postorder( + func: &IrFunction, + block_id: BlockId, + block_idx: &HashMap, + visited: &mut Vec, + postorder: &mut Vec, +) { + let idx = match block_idx.get(&block_id) { + Some(&i) => i, + None => return, + }; + if visited[idx] { + return; + } + visited[idx] = true; + + for succ in terminator_successors(&func.blocks[idx].terminator) { + dfs_postorder(func, succ, block_idx, visited, postorder); + } + postorder.push(block_id); +} + +/// Compute the immediate dominator of each block using Cooper/Harvey/Kennedy. +/// +/// Returns `idom[b] = immediate dominator of b`, with `idom[entry] = entry`. 
+fn compute_idoms(func: &IrFunction) -> HashMap { + let rpo = compute_rpo(func); + // rpo_num[b] = index in RPO order (entry = 0, smallest index = processed first) + let rpo_num: HashMap = + rpo.iter().enumerate().map(|(i, &b)| (b, i)).collect(); + + let preds = build_predecessors(func); + let entry = func.entry_block; + + let mut idom: HashMap = HashMap::new(); + idom.insert(entry, entry); + + let mut changed = true; + while changed { + changed = false; + // Process blocks in RPO order, skipping the entry. + for &b in rpo.iter().skip(1) { + let block_preds = &preds[&b]; + + // Start with the first predecessor that already has an idom assigned. + let mut new_idom = match block_preds + .iter() + .filter(|&&p| idom.contains_key(&p)) + .min_by_key(|&&p| rpo_num[&p]) + { + Some(&p) => p, + None => continue, // unreachable block — skip + }; + + // Intersect (walk up dom tree) with all other processed predecessors. + for &p in block_preds { + if p != new_idom && idom.contains_key(&p) { + new_idom = intersect(p, new_idom, &idom, &rpo_num); + } + } + + if idom.get(&b) != Some(&new_idom) { + idom.insert(b, new_idom); + changed = true; + } + } + } + + idom +} + +/// Walk up both fingers until they meet — the standard Cooper intersect. +fn intersect( + mut a: BlockId, + mut b: BlockId, + idom: &HashMap, + rpo_num: &HashMap, +) -> BlockId { + while a != b { + while rpo_num[&a] > rpo_num[&b] { + a = idom[&a]; + } + while rpo_num[&b] > rpo_num[&a] { + b = idom[&b]; + } + } + a +} + +/// Build dominator-tree children from the `idom` map. +fn build_dom_children( + idom: &HashMap, + entry: BlockId, +) -> HashMap> { + let mut children: HashMap> = HashMap::new(); + for (&b, &d) in idom { + if b != entry { + children.entry(d).or_default().push(b); + } + } + // Sort children for deterministic output. 
+ for v in children.values_mut() { + v.sort_unstable_by_key(|id| id.0); + } + children +} + +// ── GVN walk ───────────────────────────────────────────────────────────────── + +/// Recursively walk the dominator tree in preorder. +/// +/// `value_map` is a flat map that acts as a scoped table: on entry we insert +/// new keys (recording them in `frame_keys`), on exit we remove them, restoring +/// the parent scope. Any key already present in `value_map` when we visit a +/// block was computed in a dominating block — safe to reuse. +fn collect_replacements( + func: &IrFunction, + block_id: BlockId, + dom_children: &HashMap>, + block_idx: &HashMap, + multi_def_vars: &HashSet, + value_map: &mut HashMap, + replacements: &mut HashMap, +) { + let idx = match block_idx.get(&block_id) { + Some(&i) => i, + None => return, + }; + + let mut frame_keys: Vec = Vec::new(); + + for instr in &func.blocks[idx].instructions { + match instr { + IrInstr::Const { dest, value } => { + // A multiply-defined dest (loop phi var) must be skipped + // entirely: adding it to replacements would replace ALL of + // its definitions with Assign(first), clobbering back-edge + // updates; inserting it into value_map would let dominated + // blocks wrongly reuse a value that changes each iteration. + if multi_def_vars.contains(dest) { + continue; + } + let key = ValueKey::Const(ConstKey::from(*value)); + if let Some(&first) = value_map.get(&key) { + replacements.insert(*dest, first); + } else { + value_map.insert(key.clone(), *dest); + frame_keys.push(key); + } + } + + IrInstr::BinOp { dest, op, lhs, rhs, .. } => { + // Skip if dest is multiply-defined (same reason as Const). + // Also skip if any operand is multiply-defined: a loop phi + // var carries different values per iteration, so the same + // BinOp in two dominated blocks can produce different results. 
+ if multi_def_vars.contains(dest) + || multi_def_vars.contains(lhs) + || multi_def_vars.contains(rhs) + { + continue; + } + let key = binop_key(*op, *lhs, *rhs); + if let Some(&first) = value_map.get(&key) { + replacements.insert(*dest, first); + } else { + value_map.insert(key.clone(), *dest); + frame_keys.push(key); + } + } + + IrInstr::UnOp { dest, op, operand } => { + if multi_def_vars.contains(dest) || multi_def_vars.contains(operand) { + continue; + } + let key = ValueKey::UnOp { op: *op, operand: *operand }; + if let Some(&first) = value_map.get(&key) { + replacements.insert(*dest, first); + } else { + value_map.insert(key.clone(), *dest); + frame_keys.push(key); + } + } + + _ => {} + } + } + + // Recurse into dominated children. + if let Some(children) = dom_children.get(&block_id) { + for &child in children { + collect_replacements( + func, + child, + dom_children, + block_idx, + multi_def_vars, + value_map, + replacements, + ); + } + } + + // Pop this block's scope. + for key in frame_keys { + value_map.remove(&key); + } +} + +// ── Pass entry point ───────────────────────────────────────────────────────── + +/// Eliminates common subexpressions across basic blocks using the dominator tree. 
+pub fn eliminate(func: &mut IrFunction) { + if func.blocks.len() < 2 { + return; // nothing to do for single-block functions (local_cse covers those) + } + + let idom = compute_idoms(func); + let dom_children = build_dom_children(&idom, func.entry_block); + let block_idx: HashMap = + func.blocks.iter().enumerate().map(|(i, b)| (b.id, i)).collect(); + + let multi_def_vars = build_multi_def_vars(func); + let mut value_map: HashMap = HashMap::new(); + let mut replacements: HashMap = HashMap::new(); + + collect_replacements( + func, + func.entry_block, + &dom_children, + &block_idx, + &multi_def_vars, + &mut value_map, + &mut replacements, + ); + + if replacements.is_empty() { + return; + } + + for block in &mut func.blocks { + for instr in &mut block.instructions { + if let Some(dest) = instr_dest(instr) { + if let Some(&src) = replacements.get(&dest) { + *instr = IrInstr::Assign { dest, src }; + } + } + } + } + + prune_dead_locals(func); +} + +// ── Tests ──────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{IrBlock, IrTerminator, IrValue, TypeIdx, WasmType}; + + fn make_func(blocks: Vec) -> IrFunction { + IrFunction { + params: vec![], + locals: vec![], + blocks, + entry_block: BlockId(0), + return_type: None, + type_idx: TypeIdx::new(0), + } + } + + /// Entry (B0) → B1: const duplicated across the edge. + /// B0 dominates B1, so the duplicate in B1 should be replaced with Assign. 
+ #[test] + fn cross_block_const_deduplication() { + let b0 = IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { dest: VarId(0), value: IrValue::I32(42) }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }; + let b1 = IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Const { dest: VarId(1), value: IrValue::I32(42) }], + terminator: IrTerminator::Return { value: Some(VarId(1)) }, + }; + let mut func = make_func(vec![b0, b1]); + func.locals = vec![(VarId(0), WasmType::I32), (VarId(1), WasmType::I32)]; + + eliminate(&mut func); + + assert!( + matches!(func.blocks[0].instructions[0], IrInstr::Const { dest: VarId(0), .. }), + "first definition should stay as Const" + ); + assert!( + matches!( + func.blocks[1].instructions[0], + IrInstr::Assign { dest: VarId(1), src: VarId(0) } + ), + "dominated duplicate should become Assign" + ); + } + + /// Entry (B0) → B1: BinOp duplicated across the edge. + #[test] + fn cross_block_binop_deduplication() { + let b0 = IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }; + let b1 = IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::Return { value: Some(VarId(3)) }, + }; + let mut func = make_func(vec![b0, b1]); + func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; + + eliminate(&mut func); + + assert!(matches!(func.blocks[0].instructions[0], IrInstr::BinOp { .. })); + assert!( + matches!( + func.blocks[1].instructions[0], + IrInstr::Assign { dest: VarId(3), src: VarId(2) } + ), + "dominated duplicate BinOp should become Assign" + ); + } + + /// B0 branches to B1 and B2 (diamond). B1 and B2 don't dominate each other, + /// so a const in B1 should NOT eliminate the same const in B2. 
+ #[test] + fn sibling_blocks_not_deduplicated() { + // B0 → B1, B0 → B2, both converge to B3 + let b0 = IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::BranchIf { + condition: VarId(0), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }; + let b1 = IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Const { dest: VarId(1), value: IrValue::I32(7) }], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }; + let b2 = IrBlock { + id: BlockId(2), + instructions: vec![IrInstr::Const { dest: VarId(2), value: IrValue::I32(7) }], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }; + let b3 = IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }; + let mut func = make_func(vec![b0, b1, b2, b3]); + func.locals = vec![(VarId(1), WasmType::I32), (VarId(2), WasmType::I32)]; + + eliminate(&mut func); + + // Both consts should remain — neither block dominates the other. + assert!( + matches!(func.blocks[1].instructions[0], IrInstr::Const { dest: VarId(1), .. }), + "const in B1 must not be eliminated" + ); + assert!( + matches!(func.blocks[2].instructions[0], IrInstr::Const { dest: VarId(2), .. }), + "const in B2 must not be eliminated" + ); + } + + /// A const defined in B0 (entry) should be reused in a deeply dominated block. 
+ #[test] + fn deep_domination_chain() { + // B0 → B1 → B2: const defined in B0, duplicated in B2 + let b0 = IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { dest: VarId(0), value: IrValue::I32(99) }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }; + let b1 = IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(2) }, + }; + let b2 = IrBlock { + id: BlockId(2), + instructions: vec![IrInstr::Const { dest: VarId(1), value: IrValue::I32(99) }], + terminator: IrTerminator::Return { value: Some(VarId(1)) }, + }; + let mut func = make_func(vec![b0, b1, b2]); + func.locals = vec![(VarId(0), WasmType::I32), (VarId(1), WasmType::I32)]; + + eliminate(&mut func); + + assert!( + matches!(func.blocks[0].instructions[0], IrInstr::Const { dest: VarId(0), .. }) + ); + assert!( + matches!( + func.blocks[2].instructions[0], + IrInstr::Assign { dest: VarId(1), src: VarId(0) } + ), + "deeply dominated duplicate should be eliminated" + ); + } + + /// Commutative BinOps with swapped operands in a dominated block should be deduped. 
+ #[test] + fn cross_block_commutative_deduplication() { + let b0 = IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Mul, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }; + let b1 = IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Mul, + lhs: VarId(1), // swapped + rhs: VarId(0), + }], + terminator: IrTerminator::Return { value: Some(VarId(3)) }, + }; + let mut func = make_func(vec![b0, b1]); + func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; + + eliminate(&mut func); + + assert!( + matches!( + func.blocks[1].instructions[0], + IrInstr::Assign { dest: VarId(3), src: VarId(2) } + ), + "commutative cross-block BinOp should be deduplicated" + ); + } + + /// Single-block functions are skipped entirely (handled by local_cse). + #[test] + fn single_block_function_unchanged() { + let b0 = IrBlock { + id: BlockId(0), + instructions: vec![ + IrInstr::Const { dest: VarId(0), value: IrValue::I32(1) }, + IrInstr::Const { dest: VarId(1), value: IrValue::I32(1) }, + ], + terminator: IrTerminator::Return { value: None }, + }; + let mut func = make_func(vec![b0]); + func.locals = vec![(VarId(0), WasmType::I32), (VarId(1), WasmType::I32)]; + + eliminate(&mut func); + + // GVN skips single-block functions; duplicates remain (local_cse's job). + assert!(matches!(func.blocks[0].instructions[0], IrInstr::Const { .. })); + assert!(matches!(func.blocks[0].instructions[1], IrInstr::Const { .. })); + } +} diff --git a/crates/herkos-core/src/optimizer/licm.rs b/crates/herkos-core/src/optimizer/licm.rs new file mode 100644 index 0000000..4def516 --- /dev/null +++ b/crates/herkos-core/src/optimizer/licm.rs @@ -0,0 +1,1306 @@ +//! Loop-invariant code motion (LICM). +//! +//! Identifies instructions in loop headers whose operands don't change across +//! iterations, and moves them to a preheader block. +//! +//! 
## Algorithm +//! +//! 1. Compute dominators (iterative algorithm) +//! 2. Find back edges: edge (src → tgt) where tgt dominates src +//! 3. Find natural loops: for each back edge, collect all blocks that reach +//! the source without going through the header +//! 4. For each loop, identify invariant instructions in the header (fixpoint): +//! - `Const` — trivially invariant +//! - `BinOp`, `UnOp`, `Assign`, `Select` — invariant if all operands are +//! defined outside the loop or by other invariant instructions +//! - Skip: `Load`, `Store`, `Call*`, `Global*`, `Memory*` +//! 5. Create or reuse a preheader block and move invariant instructions there +//! +//! **V1 simplification:** only hoists from the loop header block (which +//! dominates all loop blocks by definition). + +use super::utils::{ + build_predecessors, for_each_use, instr_dest, rewrite_terminator_target, terminator_successors, +}; +use crate::ir::{BlockId, IrBlock, IrFunction, IrInstr, IrTerminator, VarId}; +use std::collections::{HashMap, HashSet}; + +/// Run loop-invariant code motion on `func`. +pub fn eliminate(func: &mut IrFunction) { + if func.blocks.len() < 2 { + return; + } + + let preds = build_predecessors(func); + let dominators = compute_dominators(func, &preds); + let back_edges = find_back_edges(func, &dominators); + + if back_edges.is_empty() { + return; + } + + let loops = find_natural_loops(&back_edges, &preds); + + for (header, loop_blocks) in &loops { + hoist_invariants(func, *header, loop_blocks); + } +} + +// ── Dominator computation ──────────────────────────────────────────────────── + +/// Compute the dominator set for each block using the iterative algorithm. +/// +/// Returns a map from each block to the set of blocks that dominate it. 
+fn compute_dominators( + func: &IrFunction, + preds: &HashMap>, +) -> HashMap> { + let entry = func.entry_block; + let all_block_ids: HashSet = func.blocks.iter().map(|b| b.id).collect(); + + let mut dom: HashMap> = HashMap::new(); + dom.insert(entry, HashSet::from([entry])); + + for block in &func.blocks { + if block.id != entry { + dom.insert(block.id, all_block_ids.clone()); + } + } + + loop { + let mut changed = false; + for block in &func.blocks { + if block.id == entry { + continue; + } + let pred_set = &preds[&block.id]; + if pred_set.is_empty() { + continue; + } + + // new_dom = {self} ∪ ∩(dom[p] for p in preds) + let mut new_dom: Option> = None; + for p in pred_set { + if let Some(p_dom) = dom.get(p) { + new_dom = Some(match new_dom { + None => p_dom.clone(), + Some(current) => current.intersection(p_dom).copied().collect(), + }); + } + } + + let mut new_dom = new_dom.unwrap_or_default(); + new_dom.insert(block.id); + + if new_dom != dom[&block.id] { + dom.insert(block.id, new_dom); + changed = true; + } + } + if !changed { + break; + } + } + + dom +} + +// ── Back edge detection ────────────────────────────────────────────────────── + +/// Find all back edges in the CFG. +/// +/// A back edge is (src, tgt) where tgt dominates src. +fn find_back_edges( + func: &IrFunction, + dominators: &HashMap>, +) -> Vec<(BlockId, BlockId)> { + let mut back_edges = Vec::new(); + for block in &func.blocks { + for succ in terminator_successors(&block.terminator) { + if dominators + .get(&block.id) + .is_some_and(|dom_set| dom_set.contains(&succ)) + { + back_edges.push((block.id, succ)); + } + } + } + back_edges +} + +// ── Natural loop detection ─────────────────────────────────────────────────── + +/// Find natural loops from back edges. +/// +/// For each back edge (src → header), collects all blocks that can reach `src` +/// without going through `header`. Multiple back edges with the same header +/// are merged into one loop. 
+fn find_natural_loops( + back_edges: &[(BlockId, BlockId)], + preds: &HashMap>, +) -> Vec<(BlockId, HashSet)> { + let mut loops: HashMap> = HashMap::new(); + + for &(src, header) in back_edges { + let loop_blocks = loops.entry(header).or_insert_with(|| { + let mut set = HashSet::new(); + set.insert(header); + set + }); + + let mut worklist = vec![src]; + while let Some(n) = worklist.pop() { + if loop_blocks.insert(n) { + if let Some(n_preds) = preds.get(&n) { + for &p in n_preds { + worklist.push(p); + } + } + } + } + } + + loops.into_iter().collect() +} + +// ── Invariant identification & hoisting ────────────────────────────────────── + +/// Returns `true` if the instruction type is eligible for LICM hoisting. +/// +/// Only pure, side-effect-free computations are hoistable. Instructions that +/// depend on mutable state (`Global*`, `Memory*`) or have side effects +/// (`Load`, `Store`, `Call*`) are excluded. +fn is_licm_hoistable(instr: &IrInstr) -> bool { + matches!( + instr, + IrInstr::Const { .. } + | IrInstr::BinOp { .. } + | IrInstr::UnOp { .. } + | IrInstr::Assign { .. } + | IrInstr::Select { .. } + ) +} + +/// Identify loop-invariant instructions in the header and hoist them to a preheader. +fn hoist_invariants(func: &mut IrFunction, header: BlockId, loop_blocks: &HashSet) { + let header_idx = match func.blocks.iter().position(|b| b.id == header) { + Some(idx) => idx, + None => return, + }; + + // Collect all VarIds defined in any loop block. + let mut loop_defs: HashSet = HashSet::new(); + for block in &func.blocks { + if loop_blocks.contains(&block.id) { + for instr in &block.instructions { + if let Some(dest) = instr_dest(instr) { + loop_defs.insert(dest); + } + } + } + } + + // Fixpoint: identify invariant instructions in the header. 
+ let mut invariant_dests: HashSet = HashSet::new(); + loop { + let mut changed = false; + for instr in &func.blocks[header_idx].instructions { + if !is_licm_hoistable(instr) { + continue; + } + let dest = match instr_dest(instr) { + Some(d) => d, + None => continue, + }; + if invariant_dests.contains(&dest) { + continue; + } + + let mut all_ops_invariant = true; + for_each_use(instr, |v| { + if loop_defs.contains(&v) && !invariant_dests.contains(&v) { + all_ops_invariant = false; + } + }); + + if all_ops_invariant { + invariant_dests.insert(dest); + changed = true; + } + } + if !changed { + break; + } + } + + if invariant_dests.is_empty() { + return; + } + + // Find or create preheader. + let preheader_id = find_or_create_preheader(func, header, loop_blocks); + + // Re-lookup indices after possible block insertion. + let header_idx = func.blocks.iter().position(|b| b.id == header).unwrap(); + let preheader_idx = func + .blocks + .iter() + .position(|b| b.id == preheader_id) + .unwrap(); + + // Move invariant instructions from header to preheader (in order). + let mut hoisted = Vec::new(); + let mut remaining = Vec::new(); + + for instr in func.blocks[header_idx].instructions.drain(..) { + if let Some(dest) = instr_dest(&instr) { + if invariant_dests.contains(&dest) { + hoisted.push(instr); + continue; + } + } + remaining.push(instr); + } + + func.blocks[header_idx].instructions = remaining; + func.blocks[preheader_idx].instructions.extend(hoisted); +} + +/// Allocate a fresh `BlockId` that doesn't conflict with existing blocks. +fn fresh_block_id(func: &IrFunction) -> BlockId { + let max_id = func.blocks.iter().map(|b| b.id.0).max().unwrap_or(0); + BlockId(max_id + 1) +} + +/// Find an existing preheader or create a new one. +/// +/// A preheader is reused if it is the sole non-loop predecessor and ends +/// with an unconditional jump to the header. Otherwise a new preheader +/// block is created and non-loop predecessors are redirected to it. 
+fn find_or_create_preheader( + func: &mut IrFunction, + header: BlockId, + loop_blocks: &HashSet, +) -> BlockId { + let preds = build_predecessors(func); + let header_preds = &preds[&header]; + let non_loop_preds: Vec = header_preds + .iter() + .filter(|p| !loop_blocks.contains(p)) + .copied() + .collect(); + + if non_loop_preds.is_empty() { + // Header has no non-loop predecessors (entry block or unreachable from outside). + let preheader_id = fresh_block_id(func); + func.blocks.push(IrBlock { + id: preheader_id, + instructions: vec![], + terminator: IrTerminator::Jump { target: header }, + }); + if header == func.entry_block { + func.entry_block = preheader_id; + } + return preheader_id; + } + + // Reuse if single non-loop predecessor with unconditional jump to header. + if non_loop_preds.len() == 1 { + let pred = non_loop_preds[0]; + let pred_idx = func.blocks.iter().position(|b| b.id == pred).unwrap(); + if matches!(func.blocks[pred_idx].terminator, IrTerminator::Jump { target } if target == header) + { + return pred; + } + } + + // Create a new preheader and redirect non-loop predecessors. 
+ let preheader_id = fresh_block_id(func); + func.blocks.push(IrBlock { + id: preheader_id, + instructions: vec![], + terminator: IrTerminator::Jump { target: header }, + }); + + for block in &mut func.blocks { + if non_loop_preds.contains(&block.id) { + rewrite_terminator_target(&mut block.terminator, header, preheader_id); + } + } + + preheader_id +} + +// ── Tests ──────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ + BinOp, IrBlock, IrFunction, IrInstr, IrTerminator, IrValue, TypeIdx, VarId, WasmType, + }; + + fn make_func(blocks: Vec) -> IrFunction { + IrFunction { + params: vec![], + locals: vec![], + blocks, + entry_block: BlockId(0), + return_type: None, + type_idx: TypeIdx::new(0), + } + } + + // ── No loops → no changes ──────────────────────────────────────────── + + #[test] + fn no_loop_no_change() { + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Return { + value: Some(VarId(0)), + }, + }, + ]); + + eliminate(&mut func); + + // No loops, so the const stays in block 0. + assert_eq!(func.blocks[0].instructions.len(), 1); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { dest: VarId(0), .. 
} + )); + } + + // ── Simple loop: const in header → hoisted to preheader ────────────── + + #[test] + fn simple_loop_const_hoisted() { + // B0 (entry): Jump(B1) + // B1 (header): v0 = Const(42), BranchIf(v1, B2, B3) + // B2 (body): Jump(B1) ← back edge + // B3 (exit): Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(1), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // B0 is the sole non-loop predecessor with Jump → reused as preheader. + // v0 = Const(42) should be hoisted to B0. + assert_eq!(func.blocks[0].instructions.len(), 1); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + } + )); + + // B1 (header) should have no instructions. 
+ let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 0); + } + + // ── BinOp with operands from outside loop → hoisted ────────────────── + + #[test] + fn invariant_binop_hoisted() { + // B0 (entry): v0 = Const(10), v1 = Const(20), Jump(B1) + // B1 (header): v2 = BinOp::Add(v0, v1), BranchIf(v3, B2, B3) + // B2 (body): Jump(B1) + // B3 (exit): Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(10), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(20), + }, + ], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(3), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // v2 = BinOp should be hoisted to B0 (preheader). + assert_eq!(func.blocks[0].instructions.len(), 3); + assert!(matches!( + func.blocks[0].instructions[2], + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + .. + } + )); + + // Header should be empty. 
+ let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 0); + } + + // ── Chained invariants: const → binop using that const ─────────────── + + #[test] + fn chained_invariants_hoisted() { + // B0 (entry): v0 = Const(10), Jump(B1) + // B1 (header): v1 = Const(65536), v2 = BinOp::Add(v0, v1), BranchIf(v3, B2, B3) + // B2 (body): Jump(B1) + // B3 (exit): Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(10), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![ + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(65536), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + ], + terminator: IrTerminator::BranchIf { + condition: VarId(3), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // Both v1 = Const and v2 = BinOp should be hoisted to B0. + // B0 now has: v0 = Const(10), v1 = Const(65536), v2 = Add(v0, v1). + assert_eq!(func.blocks[0].instructions.len(), 3); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(65536), + } + )); + assert!(matches!( + func.blocks[0].instructions[2], + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + .. + } + )); + + // Header should be empty. 
+ let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 0); + } + + // ── Non-hoistable instructions stay in the header ──────────────────── + + #[test] + fn side_effectful_not_hoisted() { + use crate::ir::MemoryAccessWidth; + + // B0: Jump(B1) + // B1 (header): v0 = Const(0), v1 = Load(v0), BranchIf(v2, B2, B3) + // B2: Jump(B1) + // B3: Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(0), + }, + IrInstr::Load { + dest: VarId(1), + ty: WasmType::I32, + addr: VarId(0), + offset: 0, + width: MemoryAccessWidth::Full, + sign: None, + }, + ], + terminator: IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // v0 = Const is hoisted (invariant), but Load stays (not hoistable). + assert_eq!(func.blocks[0].instructions.len(), 1); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { dest: VarId(0), .. } + )); + + let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 1); + assert!(matches!(header.instructions[0], IrInstr::Load { .. 
})); + } + + // ── BinOp with operand from loop body → NOT hoisted ────────────────── + + #[test] + fn loop_dependent_not_hoisted() { + // B0: v0 = Const(1), Jump(B1) + // B1 (header): v2 = BinOp::Add(v0, v1), BranchIf(v3, B2, B3) + // v1 is defined in B2 (loop body) → v2 is NOT invariant + // B2: v1 = Const(5), Jump(B1) + // B3: Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(3), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(5), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // v2 = BinOp should NOT be hoisted because v1 is defined in B2 (loop body). + let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 1); + assert!(matches!(header.instructions[0], IrInstr::BinOp { .. 
})); + } + + // ── Preheader reuse: single non-loop predecessor with Jump ─────────── + + #[test] + fn preheader_reused_when_possible() { + // B0 (entry): v0 = Const(99), Jump(B1) + // B1 (header): v1 = Const(42), BranchIf(v2, B2, B3) + // B2: Jump(B1) + // B3: Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(99), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(42), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // B0 should be reused as preheader (sole non-loop pred with Jump). + // No new blocks should be created. 
+ assert_eq!(func.blocks.len(), 4); + assert_eq!(func.blocks[0].instructions.len(), 2); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(99), + } + )); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(42), + } + )); + } + + // ── Preheader creation: multiple non-loop predecessors ─────────────── + + #[test] + fn preheader_created_when_needed() { + // B0 (entry): BranchIf(v0, B1, B2) + // B1: Jump(B3) + // B2: Jump(B3) + // B3 (header): v1 = Const(42), BranchIf(v2, B4, B5) + // B4 (body): Jump(B3) ← back edge + // B5 (exit): Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::BranchIf { + condition: VarId(0), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(42), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(4), + if_false: BlockId(5), + }, + }, + IrBlock { + id: BlockId(4), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }, + IrBlock { + id: BlockId(5), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // A new preheader (B6) should be created. 
+ assert_eq!(func.blocks.len(), 7); + + let preheader = func.blocks.iter().find(|b| b.id == BlockId(6)).unwrap(); + assert_eq!(preheader.instructions.len(), 1); + assert!(matches!( + preheader.instructions[0], + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(42), + } + )); + assert!(matches!( + preheader.terminator, + IrTerminator::Jump { target: BlockId(3) } + )); + + // B1 and B2 should now jump to the preheader (B6). + let b1 = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert!(matches!( + b1.terminator, + IrTerminator::Jump { target: BlockId(6) } + )); + let b2 = func.blocks.iter().find(|b| b.id == BlockId(2)).unwrap(); + assert!(matches!( + b2.terminator, + IrTerminator::Jump { target: BlockId(6) } + )); + + // Header (B3) should be empty. + let header = func.blocks.iter().find(|b| b.id == BlockId(3)).unwrap(); + assert_eq!(header.instructions.len(), 0); + } + + // ── GlobalGet not hoisted (depends on mutable state) ───────────────── + + #[test] + fn global_get_not_hoisted() { + use crate::ir::GlobalIdx; + + // B0: Jump(B1) + // B1 (header): v0 = GlobalGet(0), BranchIf(v1, B2, B3) + // B2: Jump(B1) + // B3: Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::GlobalGet { + dest: VarId(0), + index: GlobalIdx::new(0), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(1), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // GlobalGet should NOT be hoisted (mutable global may change each iteration). 
+ let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 1); + assert!(matches!(header.instructions[0], IrInstr::GlobalGet { .. })); + } + + // ── Self-loop: header is also the back-edge source ─────────────────── + + #[test] + fn self_loop_const_hoisted() { + // B0: Jump(B1) + // B1: v0 = Const(42), BranchIf(v1, B1, B2) ← self-loop + // B2: Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(1), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // Const should be hoisted to B0 (preheader). + assert_eq!(func.blocks[0].instructions.len(), 1); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + } + )); + + let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 0); + } + + // ── No invariant instructions → no changes ─────────────────────────── + + #[test] + fn no_invariants_no_change() { + use crate::ir::MemoryAccessWidth; + + // B0: v0 = Const(0), Jump(B1) + // B1 (header): v1 = Load(v0), BranchIf(v2, B2, B3) + // B2: Jump(B1) + // B3: Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(0), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Load { + dest: VarId(1), + ty: WasmType::I32, + addr: VarId(0), + offset: 0, + width: MemoryAccessWidth::Full, + sign: None, + }], + terminator: 
IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // No invariants to hoist — no new blocks, header unchanged. + assert_eq!(func.blocks.len(), 4); + let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 1); + assert!(matches!(header.instructions[0], IrInstr::Load { .. })); + } + + // ── Entry block as loop header ─────────────────────────────────────── + + #[test] + fn entry_block_loop_header() { + // B0 (entry/header): v0 = Const(42), BranchIf(v1, B1, B2) + // B1 (body): Jump(B0) ← back edge + // B2 (exit): Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(1), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(0) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // A preheader should be created, and entry_block updated. + assert_eq!(func.blocks.len(), 4); + let preheader_id = func.entry_block; + assert_ne!(preheader_id, BlockId(0)); + + let preheader = func.blocks.iter().find(|b| b.id == preheader_id).unwrap(); + assert_eq!(preheader.instructions.len(), 1); + assert!(matches!( + preheader.instructions[0], + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + } + )); + assert!(matches!( + preheader.terminator, + IrTerminator::Jump { target: BlockId(0) } + )); + + // Original header (B0) should be empty. 
+ let header = func.blocks.iter().find(|b| b.id == BlockId(0)).unwrap(); + assert_eq!(header.instructions.len(), 0); + } + + // ── Mixed: some hoistable, some not ────────────────────────────────── + + #[test] + fn mixed_hoistable_and_non_hoistable() { + use crate::ir::MemoryAccessWidth; + + // B0: Jump(B1) + // B1 (header): v0 = Const(100), v1 = Load(v0), v2 = Const(200) + // BranchIf(v3, B2, B3) + // B2: Jump(B1) + // B3: Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(100), + }, + IrInstr::Load { + dest: VarId(1), + ty: WasmType::I32, + addr: VarId(0), + offset: 0, + width: MemoryAccessWidth::Full, + sign: None, + }, + IrInstr::Const { + dest: VarId(2), + value: IrValue::I32(200), + }, + ], + terminator: IrTerminator::BranchIf { + condition: VarId(3), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // v0 and v2 (Consts) should be hoisted; Load stays. + assert_eq!(func.blocks[0].instructions.len(), 2); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(100), + } + )); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::Const { + dest: VarId(2), + value: IrValue::I32(200), + } + )); + + let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 1); + assert!(matches!(header.instructions[0], IrInstr::Load { .. 
})); + } + + // ── Single-block function → no change ──────────────────────────────── + + #[test] + fn single_block_function_no_change() { + let mut func = make_func(vec![IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }], + terminator: IrTerminator::Return { + value: Some(VarId(0)), + }, + }]); + + eliminate(&mut func); + + assert_eq!(func.blocks.len(), 1); + assert_eq!(func.blocks[0].instructions.len(), 1); + } + + // ── Dominator computation tests ────────────────────────────────────── + + #[test] + fn dominators_linear_chain() { + // B0 → B1 → B2 + let func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(2) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + let preds = build_predecessors(&func); + let dom = compute_dominators(&func, &preds); + + assert_eq!(dom[&BlockId(0)], HashSet::from([BlockId(0)])); + assert_eq!(dom[&BlockId(1)], HashSet::from([BlockId(0), BlockId(1)])); + assert_eq!( + dom[&BlockId(2)], + HashSet::from([BlockId(0), BlockId(1), BlockId(2)]) + ); + } + + #[test] + fn dominators_diamond() { + // B0 → B1, B0 → B2, B1 → B3, B2 → B3 + let func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::BranchIf { + condition: VarId(0), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + let preds = build_predecessors(&func); + 
let dom = compute_dominators(&func, &preds); + + // B3 is dominated by B0 (only common dominator of B1 and B2). + assert_eq!(dom[&BlockId(3)], HashSet::from([BlockId(0), BlockId(3)])); + } + + #[test] + fn back_edges_detected() { + // B0 → B1 → B2 → B1 (back edge) + let func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(2) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + ]); + + let preds = build_predecessors(&func); + let dom = compute_dominators(&func, &preds); + let back_edges = find_back_edges(&func, &dom); + + assert_eq!(back_edges.len(), 1); + assert_eq!(back_edges[0], (BlockId(2), BlockId(1))); + } + + #[test] + fn natural_loop_blocks() { + // B0 → B1 → B2 → B3 → B1 (back edge) + // Loop = {B1, B2, B3} + let func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(2) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + ]); + + let preds = build_predecessors(&func); + let dom = compute_dominators(&func, &preds); + let back_edges = find_back_edges(&func, &dom); + let loops = find_natural_loops(&back_edges, &preds); + + assert_eq!(loops.len(), 1); + let (header, loop_blocks) = &loops[0]; + assert_eq!(*header, BlockId(1)); + assert_eq!( + *loop_blocks, + HashSet::from([BlockId(1), BlockId(2), BlockId(3)]) + ); + } +} diff --git a/crates/herkos-core/src/optimizer/local_cse.rs b/crates/herkos-core/src/optimizer/local_cse.rs new 
file mode 100644
index 0000000..d431a19
--- /dev/null
+++ b/crates/herkos-core/src/optimizer/local_cse.rs
@@ -0,0 +1,574 @@
+//! Local common subexpression elimination (CSE) via value numbering.
+//!
+//! Within each block, identifies identical computations and replaces duplicates
+//! with references to the first result. Only side-effect-free instructions are
+//! considered (`BinOp`, `UnOp`, `Const`). Duplicates are replaced with
+//! `Assign { dest, src: previous_result }`, which copy propagation cleans up.
+
+use crate::ir::{BinOp, IrFunction, IrInstr, IrValue, UnOp, VarId};
+use std::collections::HashMap;
+
+use super::utils::prune_dead_locals;
+
+// ── Value key ────────────────────────────────────────────────────────────────
+
+/// Hashable representation of a pure computation for deduplication.
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+enum ValueKey {
+    /// Constant value (using bit-level equality for floats).
+    Const(ConstKey),
+
+    /// Binary operation with operand variable IDs.
+    BinOp { op: BinOp, lhs: VarId, rhs: VarId },
+
+    /// Unary operation with operand variable ID.
+    UnOp { op: UnOp, operand: VarId },
+}
+
+/// Bit-level constant key that implements Eq/Hash correctly for floats.
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+enum ConstKey {
+    I32(i32),
+    I64(i64),
+    F32(u32),
+    F64(u64),
+}
+
+impl From<IrValue> for ConstKey {
+    fn from(v: IrValue) -> Self {
+        match v {
+            IrValue::I32(x) => ConstKey::I32(x),
+            IrValue::I64(x) => ConstKey::I64(x),
+            IrValue::F32(x) => ConstKey::F32(x.to_bits()),
+            IrValue::F64(x) => ConstKey::F64(x.to_bits()),
+        }
+    }
+}
+
+// ── Commutative op detection ─────────────────────────────────────────────────
+
+/// Returns true for operations where `op(a, b) == op(b, a)`.
+fn is_commutative(op: &BinOp) -> bool {
+    matches!(
+        op,
+        BinOp::I32Add
+            | BinOp::I32Mul
+            | BinOp::I32And
+            | BinOp::I32Or
+            | BinOp::I32Xor
+            | BinOp::I32Eq
+            | BinOp::I32Ne
+            | BinOp::I64Add
+            | BinOp::I64Mul
+            | BinOp::I64And
+            | BinOp::I64Or
+            | BinOp::I64Xor
+            | BinOp::I64Eq
+            | BinOp::I64Ne
+            | BinOp::F32Add
+            | BinOp::F32Mul
+            | BinOp::F32Eq
+            | BinOp::F32Ne
+            | BinOp::F64Add
+            | BinOp::F64Mul
+            | BinOp::F64Eq
+            | BinOp::F64Ne
+    )
+}
+
+/// Build a `ValueKey` for a `BinOp`, normalizing operand order for commutative ops.
+fn binop_key(op: BinOp, lhs: VarId, rhs: VarId) -> ValueKey {
+    let (lhs, rhs) = if is_commutative(&op) && lhs.0 > rhs.0 {
+        (rhs, lhs)
+    } else {
+        (lhs, rhs)
+    };
+    ValueKey::BinOp { op, lhs, rhs }
+}
+
+// ── Pass entry point ─────────────────────────────────────────────────────────
+
+/// Eliminates common subexpressions within each block of `func`.
+pub fn eliminate(func: &mut IrFunction) {
+    let mut changed = false;
+
+    for block in &mut func.blocks {
+        // Maps a pure computation to the first VarId that computed it.
+        let mut value_map: HashMap<ValueKey, VarId> = HashMap::new();
+
+        for instr in &mut block.instructions {
+            // In strict SSA form each variable is defined exactly once, so there
+            // is no need to invalidate cached CSE entries on redefinition.
+            match instr {
+                IrInstr::Const { dest, value } => {
+                    let key = ValueKey::Const(ConstKey::from(*value));
+                    if let Some(&first) = value_map.get(&key) {
+                        *instr = IrInstr::Assign {
+                            dest: *dest,
+                            src: first,
+                        };
+                        changed = true;
+                    } else {
+                        value_map.insert(key, *dest);
+                    }
+                }
+
+                IrInstr::BinOp {
+                    dest, op, lhs, rhs, ..
+                } => {
+                    let key = binop_key(*op, *lhs, *rhs);
+                    if let Some(&first) = value_map.get(&key) {
+                        *instr = IrInstr::Assign {
+                            dest: *dest,
+                            src: first,
+                        };
+                        changed = true;
+                    } else {
+                        value_map.insert(key, *dest);
+                    }
+                }
+
+                IrInstr::UnOp { dest, op, operand } => {
+                    let key = ValueKey::UnOp {
+                        op: *op,
+                        operand: *operand,
+                    };
+                    if let Some(&first) = value_map.get(&key) {
+                        *instr = IrInstr::Assign {
+                            dest: *dest,
+                            src: first,
+                        };
+                        changed = true;
+                    } else {
+                        value_map.insert(key, *dest);
+                    }
+                }
+
+                // All other instructions are not eligible for CSE.
+                _ => {}
+            }
+        }
+    }
+
+    if changed {
+        prune_dead_locals(func);
+    }
+}
+
+// ── Tests ────────────────────────────────────────────────────────────────────
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ir::{BlockId, IrBlock, IrTerminator, TypeIdx, WasmType};
+
+    /// Helper: create a minimal IrFunction with the given blocks.
+    fn make_func(blocks: Vec<IrBlock>) -> IrFunction {
+        IrFunction {
+            params: vec![],
+            locals: vec![],
+            blocks,
+            entry_block: BlockId(0),
+            return_type: None,
+            type_idx: TypeIdx::new(0),
+        }
+    }
+
+    /// Helper: create a block with given instructions and a simple return terminator.
+    fn make_block(id: u32, instructions: Vec<IrInstr>) -> IrBlock {
+        IrBlock {
+            id: BlockId(id),
+            instructions,
+            terminator: IrTerminator::Return { value: None },
+        }
+    }
+
+    #[test]
+    fn duplicate_binop_is_eliminated() {
+        let instrs = vec![
+            IrInstr::BinOp {
+                dest: VarId(2),
+                op: BinOp::I32Add,
+                lhs: VarId(0),
+                rhs: VarId(1),
+            },
+            IrInstr::BinOp {
+                dest: VarId(3),
+                op: BinOp::I32Add,
+                lhs: VarId(0),
+                rhs: VarId(1),
+            },
+        ];
+
+        let mut func = make_func(vec![make_block(0, instrs)]);
+        func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)];
+        eliminate(&mut func);
+
+        let block = &func.blocks[0];
+        assert!(matches!(block.instructions[0], IrInstr::BinOp { ..
})); + assert!( + matches!( + block.instructions[1], + IrInstr::Assign { + dest: VarId(3), + src: VarId(2) + } + ), + "Duplicate BinOp should be replaced with Assign" + ); + } + + #[test] + fn commutative_binop_is_deduplicated() { + // v2 = v0 + v1, v3 = v1 + v0 → v3 should become Assign from v2 + let instrs = vec![ + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Add, + lhs: VarId(1), + rhs: VarId(0), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; + eliminate(&mut func); + + assert!( + matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(3), + src: VarId(2) + } + ), + "Commutative BinOp with swapped operands should be deduplicated" + ); + } + + #[test] + fn non_commutative_binop_not_deduplicated() { + // v2 = v0 - v1, v3 = v1 - v0 → different computations, keep both + let instrs = vec![ + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Sub, + lhs: VarId(0), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Sub, + lhs: VarId(1), + rhs: VarId(0), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; + eliminate(&mut func); + + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::BinOp { .. } + )); + assert!( + matches!(func.blocks[0].instructions[1], IrInstr::BinOp { .. 
}), + "Non-commutative BinOp with swapped operands should NOT be deduplicated" + ); + } + + #[test] + fn duplicate_unop_is_eliminated() { + let instrs = vec![ + IrInstr::UnOp { + dest: VarId(1), + op: UnOp::I32Clz, + operand: VarId(0), + }, + IrInstr::UnOp { + dest: VarId(2), + op: UnOp::I32Clz, + operand: VarId(0), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(1), WasmType::I32), (VarId(2), WasmType::I32)]; + eliminate(&mut func); + + assert!( + matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(2), + src: VarId(1) + } + ), + "Duplicate UnOp should be replaced with Assign" + ); + } + + #[test] + fn duplicate_const_is_eliminated() { + let instrs = vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(42), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(0), WasmType::I32), (VarId(1), WasmType::I32)]; + eliminate(&mut func); + + assert!( + matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(1), + src: VarId(0) + } + ), + "Duplicate Const should be replaced with Assign" + ); + } + + #[test] + fn float_const_nan_bits_handled() { + // Two NaN constants with the same bit pattern should be deduplicated. 
+ let instrs = vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::F32(f32::NAN), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::F32(f32::NAN), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(0), WasmType::F32), (VarId(1), WasmType::F32)]; + eliminate(&mut func); + + assert!( + matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(1), + src: VarId(0) + } + ), + "NaN constants with same bit pattern should be deduplicated" + ); + } + + #[test] + fn different_ops_not_deduplicated() { + let instrs = vec![ + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Sub, + lhs: VarId(0), + rhs: VarId(1), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; + eliminate(&mut func); + + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::BinOp { .. } + )); + assert!( + matches!(func.blocks[0].instructions[1], IrInstr::BinOp { .. }), + "Different operations should not be deduplicated" + ); + } + + #[test] + fn cross_block_not_deduplicated() { + // Each block should have its own value map — no cross-block CSE. + let block0 = IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }; + let block1 = IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::Return { value: None }, + }; + + let mut func = make_func(vec![block0, block1]); + func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; + eliminate(&mut func); + + // Both should remain as BinOp (no cross-block elimination). 
+ assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::BinOp { .. } + )); + assert!( + matches!(func.blocks[1].instructions[0], IrInstr::BinOp { .. }), + "Cross-block duplicate should NOT be eliminated" + ); + } + + /// In strict SSA form every variable is defined exactly once within a block, + /// so (v0 + v1) always refers to the same computation and can be CSE'd. + #[test] + fn ssa_unique_defs_allow_cse() { + // v2 = v0 + v1 ← first occurrence + // v3 = v0 + v1 ← identical keys with same VarIds → should be eliminated + let instrs = vec![ + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; + eliminate(&mut func); + + // v3 should be eliminated to Assign(v3, v2). + assert!( + matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(3), + src: VarId(2) + } + ), + "duplicate (v0 + v1) should be CSE'd to Assign in strict SSA" + ); + } + + #[test] + fn side_effect_instructions_not_eliminated() { + // Load, Store, Call, etc. should never be CSE'd. + use crate::ir::MemoryAccessWidth; + + let instrs = vec![ + IrInstr::Load { + dest: VarId(1), + ty: WasmType::I32, + addr: VarId(0), + offset: 0, + width: MemoryAccessWidth::Full, + sign: None, + }, + IrInstr::Load { + dest: VarId(2), + ty: WasmType::I32, + addr: VarId(0), + offset: 0, + width: MemoryAccessWidth::Full, + sign: None, + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(1), WasmType::I32), (VarId(2), WasmType::I32)]; + eliminate(&mut func); + + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Load { .. } + )); + assert!( + matches!(func.blocks[0].instructions[1], IrInstr::Load { .. 
}), + "Load instructions should not be CSE'd" + ); + } + + #[test] + fn triple_duplicate_eliminates_both() { + // Three identical BinOps: second and third should become Assigns to first. + let instrs = vec![ + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Mul, + lhs: VarId(0), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Mul, + lhs: VarId(0), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(4), + op: BinOp::I32Mul, + lhs: VarId(0), + rhs: VarId(1), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![ + (VarId(2), WasmType::I32), + (VarId(3), WasmType::I32), + (VarId(4), WasmType::I32), + ]; + eliminate(&mut func); + + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::BinOp { .. } + )); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(3), + src: VarId(2) + } + )); + assert!(matches!( + func.blocks[0].instructions[2], + IrInstr::Assign { + dest: VarId(4), + src: VarId(2) + } + )); + } +} diff --git a/crates/herkos-core/src/optimizer/merge_blocks.rs b/crates/herkos-core/src/optimizer/merge_blocks.rs new file mode 100644 index 0000000..4850df1 --- /dev/null +++ b/crates/herkos-core/src/optimizer/merge_blocks.rs @@ -0,0 +1,382 @@ +//! Single-predecessor block merging. +//! +//! When a block `B` has exactly one predecessor `P`, and `P` reaches `B` via an +//! unconditional `Jump`, then `B` can be appended to `P` — its instructions are +//! concatenated and `P` inherits `B`'s terminator. +//! +//! The pass iterates to a fixed point so that chains like +//! B0 → Jump → B1 → Jump → B2 → Return +//! collapse into a single block B0 → Return. +//! +//! After merging, absorbed blocks are removed from `func.blocks`. + +use super::utils::build_predecessors; +use crate::ir::{BlockId, IrFunction, IrTerminator}; +use std::collections::{HashMap, HashSet}; + +/// Merge single-predecessor blocks reached via unconditional `Jump`. 
+///
+/// Iterates to a fixed point, then removes absorbed blocks.
+pub fn eliminate(func: &mut IrFunction) {
+    loop {
+        let preds = build_predecessors(func);
+
+        // Index blocks by ID for lookup during merging.
+        let block_map: HashMap<BlockId, usize> = func
+            .blocks
+            .iter()
+            .enumerate()
+            .map(|(i, b)| (b.id, i))
+            .collect();
+
+        // Collect merge pairs: (predecessor_idx, target_idx) where target has
+        // exactly one predecessor and that predecessor reaches it via Jump.
+        let mut merges: Vec<(usize, usize)> = Vec::new();
+        // Track which blocks are already involved in a merge this round to avoid
+        // conflicting operations (a block can't be both a merge source and target
+        // in the same round).
+        let mut involved: HashSet<usize> = HashSet::new();
+
+        for block in &func.blocks {
+            if let IrTerminator::Jump { target } = block.terminator {
+                // Skip self-loops.
+                if target == block.id {
+                    continue;
+                }
+                // Never merge away the entry block.
+                if target == func.entry_block {
+                    continue;
+                }
+                if let Some(pred_set) = preds.get(&target) {
+                    if pred_set.len() == 1 {
+                        let pred_idx = block_map[&block.id];
+                        let target_idx = block_map[&target];
+                        // Avoid conflicts: each block participates in at most one
+                        // merge per round.
+                        if !involved.contains(&pred_idx) && !involved.contains(&target_idx) {
+                            merges.push((pred_idx, target_idx));
+                            involved.insert(pred_idx);
+                            involved.insert(target_idx);
+                        }
+                    }
+                }
+            }
+        }
+
+        if merges.is_empty() {
+            break;
+        }
+
+        // Perform merges. We collect the target block data first to avoid borrow
+        // conflicts on func.blocks.
+        let absorbed_sorted = {
+            let mut absorbed: Vec<usize> = merges.iter().map(|(_, t)| *t).collect();
+            absorbed.sort_unstable_by(|a, b| b.cmp(a));
+            absorbed
+        };
+
+        for (pred_idx, target_idx) in &merges {
+            // Take target block's data out.
+            let target_instrs = std::mem::take(&mut func.blocks[*target_idx].instructions);
+            let target_term = std::mem::replace(
+                &mut func.blocks[*target_idx].terminator,
+                IrTerminator::Unreachable,
+            );
+            // Append to predecessor.
+            func.blocks[*pred_idx].instructions.extend(target_instrs);
+            func.blocks[*pred_idx].terminator = target_term;
+        }
+
+        // Remove absorbed blocks (iterate in reverse to preserve indices).
+        for idx in absorbed_sorted {
+            func.blocks.remove(idx);
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ir::{IrBlock, IrFunction, IrInstr, IrTerminator, IrValue, TypeIdx, VarId};
+
+    fn make_func(blocks: Vec<IrBlock>) -> IrFunction {
+        IrFunction {
+            params: vec![],
+            locals: vec![],
+            blocks,
+            entry_block: BlockId(0),
+            return_type: None,
+            type_idx: TypeIdx::new(0),
+        }
+    }
+
+    fn block_ids(func: &IrFunction) -> Vec<u32> {
+        func.blocks.iter().map(|b| b.id.0).collect()
+    }
+
+    fn instr_block(id: u32, dest: u32, val: i32, term: IrTerminator) -> IrBlock {
+        IrBlock {
+            id: BlockId(id),
+            instructions: vec![IrInstr::Const {
+                dest: VarId(dest),
+                value: IrValue::I32(val),
+            }],
+            terminator: term,
+        }
+    }
+
+    // ── Basic cases ──────────────────────────────────────────────────────
+
+    #[test]
+    fn single_block_unchanged() {
+        let mut func = make_func(vec![IrBlock {
+            id: BlockId(0),
+            instructions: vec![],
+            terminator: IrTerminator::Return { value: None },
+        }]);
+        eliminate(&mut func);
+        assert_eq!(block_ids(&func), vec![0]);
+    }
+
+    #[test]
+    fn linear_chain_collapses() {
+        // B0 → Jump → B1 → Jump → B2 → Return
+        let mut func = make_func(vec![
+            instr_block(0, 0, 1, IrTerminator::Jump { target: BlockId(1) }),
+            instr_block(1, 1, 2, IrTerminator::Jump { target: BlockId(2) }),
+            instr_block(
+                2,
+                2,
+                3,
+                IrTerminator::Return {
+                    value: Some(VarId(2)),
+                },
+            ),
+        ]);
+        eliminate(&mut func);
+        // All merged into B0.
+ assert_eq!(block_ids(&func), vec![0]); + assert_eq!(func.blocks[0].instructions.len(), 3); + assert!(matches!( + func.blocks[0].terminator, + IrTerminator::Return { + value: Some(VarId(2)) + } + )); + } + + #[test] + fn conditional_predecessor_not_merged() { + // B0: BranchIf → B1 / B2 — both have 1 predecessor but via conditional + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(0), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + eliminate(&mut func); + // Nothing merged — conditional edges are not Jump. + assert_eq!(block_ids(&func), vec![0, 1, 2]); + } + + #[test] + fn multiple_predecessors_not_merged() { + // B0 → Jump → B2, B1 → Jump → B2 — B2 has 2 predecessors + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(2) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(2) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + // B1 is dead (no predecessor), but merge_blocks doesn't remove dead blocks. + // B2 has 2 predecessors (B0 and B1) → not merged. 
+ eliminate(&mut func); + assert_eq!(block_ids(&func), vec![0, 1, 2]); + } + + #[test] + fn self_loop_not_merged() { + // B0 → Jump → B0 (self-loop) + let mut func = make_func(vec![IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(0) }, + }]); + eliminate(&mut func); + assert_eq!(block_ids(&func), vec![0]); + } + + #[test] + fn entry_block_not_absorbed() { + // B1 → Jump → B0 (entry) — B0 has 1 predecessor but is entry + let mut func = IrFunction { + params: vec![], + locals: vec![], + blocks: vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(0) }, + }, + ], + entry_block: BlockId(0), + return_type: None, + type_idx: TypeIdx::new(0), + }; + eliminate(&mut func); + // B0 is entry, must not be absorbed into B1. + assert!(func.blocks.iter().any(|b| b.id == BlockId(0))); + } + + // ── Fixed-point iteration ────────────────────────────────────────── + + #[test] + fn fixed_point_three_block_chain() { + // B0 → B1 → B2 → B3 → Return + // Round 1: B1→B0, B3→B2. Round 2: B2→B0. 
+ let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(2) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + eliminate(&mut func); + assert_eq!(block_ids(&func), vec![0]); + assert!(matches!( + func.blocks[0].terminator, + IrTerminator::Return { value: None } + )); + } + + // ── Realistic pattern ────────────────────────────────────────────── + + #[test] + fn jump_then_branch_merges_prologue() { + // B0 → Jump → B1 → BranchIf(B2, B3) + // B2: Return, B3: Return + // B1 has 1 predecessor (B0) via Jump → merge. + let mut func = make_func(vec![ + instr_block(0, 0, 10, IrTerminator::Jump { target: BlockId(1) }), + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(20), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(0), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + eliminate(&mut func); + // B1 merged into B0. + assert_eq!(block_ids(&func), vec![0, 2, 3]); + assert_eq!(func.blocks[0].instructions.len(), 2); + assert!(matches!( + func.blocks[0].terminator, + IrTerminator::BranchIf { .. 
} + )); + } + + #[test] + fn loop_back_edge_prevents_merge() { + // B0 → Jump → B1 → BranchIf(B2, B3) + // B2 → Jump → B1 (back-edge, B1 now has 2 predecessors) + // B3: Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(0), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + eliminate(&mut func); + // B1 has 2 predecessors (B0 and B2) → not merged. + // No blocks are mergeable. + assert_eq!(block_ids(&func), vec![0, 1, 2, 3]); + } +} diff --git a/crates/herkos-core/src/optimizer/mod.rs b/crates/herkos-core/src/optimizer/mod.rs index 654382f..bf4424a 100644 --- a/crates/herkos-core/src/optimizer/mod.rs +++ b/crates/herkos-core/src/optimizer/mod.rs @@ -3,20 +3,31 @@ //! This module implements optimizations on the intermediate representation (IR) //! to improve code generation quality and runtime performance. //! -//! Each optimization is a self-contained sub-module. The top-level -//! [`optimize_ir`] function runs all passes in order. +//! Passes are split into two phases: +//! - **Pre-lowering** ([`optimize_ir`]): operates on SSA IR with phi nodes +//! 
- **Post-lowering** ([`optimize_lowered_ir`]): operates on lowered IR after phi destruction use crate::ir::{LoweredModuleInfo, ModuleInfo}; use anyhow::Result; -// ── Passes ─────────────────────────────────────────────────────────────────── +// ── Shared utilities ───────────────────────────────────────────────────────── +pub(crate) mod utils; + +// ── Pre-lowering passes ────────────────────────────────────────────────────── mod dead_blocks; +// ── Post-lowering passes ───────────────────────────────────────────────────── +mod dead_instrs; +mod empty_blocks; +mod merge_blocks; +mod branch_fold; +mod gvn; +mod licm; +mod local_cse; + /// Optimizes the pure SSA IR before phi lowering. /// /// Passes here operate on [`ModuleInfo`] with phi nodes still intact. -/// Control-flow based passes (e.g. dead block elimination) belong here -/// because reachability is identical in SSA and lowered form. pub fn optimize_ir(module_info: ModuleInfo, do_opt: bool) -> Result { let mut module_info = module_info; if do_opt { @@ -29,12 +40,30 @@ pub fn optimize_ir(module_info: ModuleInfo, do_opt: bool) -> Result /// Optimizes the lowered IR after phi nodes have been eliminated. /// -/// Passes here operate on [`LoweredModuleInfo`] where all `IrInstr::Phi` -/// nodes have been replaced by `IrInstr::Assign` in predecessor blocks. +/// Runs all post-lowering optimization passes: structural cleanup, redundancy +/// elimination (CSE/GVN), loop optimizations (LICM), and branch simplification. 
pub fn optimize_lowered_ir( module_info: LoweredModuleInfo, - _do_opt: bool, + do_opt: bool, ) -> Result { + let mut module_info = module_info; + if do_opt { + for func in &mut module_info.ir_functions { + for _ in 0..2 { + empty_blocks::eliminate(func); + dead_blocks::eliminate(func)?; + merge_blocks::eliminate(func); + dead_blocks::eliminate(func)?; + local_cse::eliminate(func); + gvn::eliminate(func); + dead_instrs::eliminate(func); + branch_fold::eliminate(func); + dead_instrs::eliminate(func); + licm::eliminate(func); + dead_instrs::eliminate(func); + } + } + } Ok(module_info) } diff --git a/crates/herkos-core/src/optimizer/utils.rs b/crates/herkos-core/src/optimizer/utils.rs new file mode 100644 index 0000000..7d9683c --- /dev/null +++ b/crates/herkos-core/src/optimizer/utils.rs @@ -0,0 +1,658 @@ +//! Shared utility functions for IR optimization passes. +//! +//! Provides common operations on IR instructions, terminators, and control flow +//! that are needed by multiple optimization passes. +//! +//! Some functions are not yet used by existing passes but are provided for +//! upcoming optimization passes (const_prop, dead_instrs, local_cse, licm). +#![allow(dead_code)] + +use crate::ir::{BinOp, BlockId, IrFunction, IrInstr, IrTerminator, UnOp, VarId}; +use std::collections::{HashMap, HashSet}; + +// ── Terminator successors ──────────────────────────────────────────────────── + +/// Returns the successor block IDs for a terminator. +pub fn terminator_successors(term: &IrTerminator) -> Vec { + match term { + IrTerminator::Return { .. } | IrTerminator::Unreachable => vec![], + IrTerminator::Jump { target } => vec![*target], + IrTerminator::BranchIf { + if_true, if_false, .. + } => vec![*if_true, *if_false], + IrTerminator::BranchTable { + targets, default, .. 
+        } => targets
+            .iter()
+            .chain(std::iter::once(default))
+            .copied()
+            .collect(),
+    }
+}
+
+// ── Predecessor map ──────────────────────────────────────────────────────────
+
+/// Build a map from each block ID to the set of *distinct* predecessor block IDs.
+pub fn build_predecessors(func: &IrFunction) -> HashMap<BlockId, HashSet<BlockId>> {
+    let mut preds: HashMap<BlockId, HashSet<BlockId>> = HashMap::new();
+    // Ensure every block has an entry (even if no predecessors).
+    for block in &func.blocks {
+        preds.entry(block.id).or_default();
+    }
+    for block in &func.blocks {
+        for succ in terminator_successors(&block.terminator) {
+            preds.entry(succ).or_default().insert(block.id);
+        }
+    }
+    preds
+}
+
+// ── Instruction variable traversal ───────────────────────────────────────────
+
+/// Calls `f` with every variable read by `instr`.
+pub fn for_each_use<F: FnMut(VarId)>(instr: &IrInstr, mut f: F) {
+    match instr {
+        IrInstr::Const { .. } => {}
+        IrInstr::BinOp { lhs, rhs, .. } => {
+            f(*lhs);
+            f(*rhs);
+        }
+        IrInstr::UnOp { operand, .. } => {
+            f(*operand);
+        }
+        IrInstr::Load { addr, .. } => {
+            f(*addr);
+        }
+        IrInstr::Store { addr, value, .. } => {
+            f(*addr);
+            f(*value);
+        }
+        IrInstr::Call { args, .. } | IrInstr::CallImport { args, .. } => {
+            for a in args {
+                f(*a);
+            }
+        }
+        IrInstr::CallIndirect {
+            table_idx, args, ..
+        } => {
+            f(*table_idx);
+            for a in args {
+                f(*a);
+            }
+        }
+        IrInstr::Assign { src, .. } => {
+            f(*src);
+        }
+        IrInstr::GlobalGet { .. } => {}
+        IrInstr::GlobalSet { value, .. } => {
+            f(*value);
+        }
+        IrInstr::MemorySize { .. } => {}
+        IrInstr::MemoryGrow { delta, .. } => {
+            f(*delta);
+        }
+        IrInstr::MemoryCopy { dst, src, len } => {
+            f(*dst);
+            f(*src);
+            f(*len);
+        }
+        IrInstr::Select {
+            val1,
+            val2,
+            condition,
+            ..
+        } => {
+            f(*val1);
+            f(*val2);
+            f(*condition);
+        }
+        IrInstr::MemoryFill { dst, val, len } => { f(*dst); f(*val); f(*len); }
+        IrInstr::MemoryInit { dst, src_offset, len, .. } => { f(*dst); f(*src_offset); f(*len); }
+        IrInstr::DataDrop { .. } => {}
+        IrInstr::Phi { ..
} => unreachable!("Phi nodes must be lowered before optimization"),
+    }
+}
+
+/// Calls `f` with every variable read by a block terminator.
+pub fn for_each_use_terminator<F: FnMut(VarId)>(term: &IrTerminator, mut f: F) {
+    match term {
+        IrTerminator::Return { value: Some(v) } => {
+            f(*v);
+        }
+        IrTerminator::Return { value: None }
+        | IrTerminator::Jump { .. }
+        | IrTerminator::Unreachable => {}
+        IrTerminator::BranchIf { condition, .. } => {
+            f(*condition);
+        }
+        IrTerminator::BranchTable { index, .. } => {
+            f(*index);
+        }
+    }
+}
+
+// ── Instruction destination ──────────────────────────────────────────────────
+
+/// Returns the variable written by `instr`, or `None` for side-effect-only instructions.
+pub fn instr_dest(instr: &IrInstr) -> Option<VarId> {
+    match instr {
+        IrInstr::Const { dest, .. }
+        | IrInstr::BinOp { dest, .. }
+        | IrInstr::UnOp { dest, .. }
+        | IrInstr::Load { dest, .. }
+        | IrInstr::Assign { dest, .. }
+        | IrInstr::GlobalGet { dest, .. }
+        | IrInstr::MemorySize { dest }
+        | IrInstr::MemoryGrow { dest, .. }
+        | IrInstr::Select { dest, .. } => Some(*dest),
+
+        IrInstr::Call { dest, .. }
+        | IrInstr::CallImport { dest, .. }
+        | IrInstr::CallIndirect { dest, .. } => *dest,
+
+        IrInstr::Store { .. } | IrInstr::GlobalSet { .. } | IrInstr::MemoryCopy { .. } => None,
+
+        IrInstr::MemoryFill { .. }
+        | IrInstr::MemoryInit { .. }
+        | IrInstr::DataDrop { .. } => None,
+        IrInstr::Phi { .. } => unreachable!("Phi nodes must be lowered before optimization"),
+    }
+}
+
+/// Redirects the destination variable of `instr` to `new_dest`.
+///
+/// Only called when `instr_dest(instr)` is `Some(_)`, i.e. the instruction
+/// produces a value. Instructions without a dest are left unchanged.
+pub fn set_instr_dest(instr: &mut IrInstr, new_dest: VarId) {
+    match instr {
+        IrInstr::Const { dest, .. }
+        | IrInstr::BinOp { dest, .. }
+        | IrInstr::UnOp { dest, ..
}
+        | IrInstr::Load { dest, .. }
+        | IrInstr::Assign { dest, .. }
+        | IrInstr::GlobalGet { dest, .. }
+        | IrInstr::MemorySize { dest }
+        | IrInstr::MemoryGrow { dest, .. }
+        | IrInstr::Select { dest, .. } => {
+            *dest = new_dest;
+        }
+        IrInstr::Call { dest, .. }
+        | IrInstr::CallImport { dest, .. }
+        | IrInstr::CallIndirect { dest, .. } => {
+            *dest = Some(new_dest);
+        }
+        // No dest — unreachable given precondition, but harmless to ignore.
+        IrInstr::Store { .. } | IrInstr::GlobalSet { .. } | IrInstr::MemoryCopy { .. } => {}
+
+        IrInstr::MemoryFill { .. }
+        | IrInstr::MemoryInit { .. }
+        | IrInstr::DataDrop { .. } => {}
+        IrInstr::Phi { .. } => unreachable!("Phi nodes must be lowered before optimization"),
+    }
+}
+
+// ── Use-count helpers ────────────────────────────────────────────────────────
+
+/// Count how many times `var` appears as an operand (read) in `instr`.
+pub fn count_uses_of(instr: &IrInstr, var: VarId) -> usize {
+    let mut count = 0usize;
+    for_each_use(instr, |v| {
+        if v == var {
+            count += 1;
+        }
+    });
+    count
+}
+
+/// Count how many times `var` appears as an operand in `term`.
+pub fn count_uses_of_terminator(term: &IrTerminator, var: VarId) -> usize {
+    let mut count = 0usize;
+    for_each_use_terminator(term, |v| {
+        if v == var {
+            count += 1;
+        }
+    });
+    count
+}
+
+// ── Use-replacement helpers ──────────────────────────────────────────────────
+
+/// Replace every read-occurrence of `old` with `new` in `instr`.
+/// Only touches operand (source) slots; the destination slot is never modified.
+pub fn replace_uses_of(instr: &mut IrInstr, old: VarId, new: VarId) {
+    let sub = |v: &mut VarId| {
+        if *v == old {
+            *v = new;
+        }
+    };
+    match instr {
+        IrInstr::Const { .. } => {}
+        IrInstr::BinOp { lhs, rhs, .. } => {
+            sub(lhs);
+            sub(rhs);
+        }
+        IrInstr::UnOp { operand, .. } => {
+            sub(operand);
+        }
+        IrInstr::Load { addr, ..
} => {
+            sub(addr);
+        }
+        IrInstr::Store { addr, value, .. } => {
+            sub(addr);
+            sub(value);
+        }
+        IrInstr::Call { args, .. } | IrInstr::CallImport { args, .. } => {
+            for a in args {
+                sub(a);
+            }
+        }
+        IrInstr::CallIndirect {
+            table_idx, args, ..
+        } => {
+            sub(table_idx);
+            for a in args {
+                sub(a);
+            }
+        }
+        IrInstr::Assign { src, .. } => {
+            sub(src);
+        }
+        IrInstr::GlobalGet { .. } => {}
+        IrInstr::GlobalSet { value, .. } => {
+            sub(value);
+        }
+        IrInstr::MemorySize { .. } => {}
+        IrInstr::MemoryGrow { delta, .. } => {
+            sub(delta);
+        }
+        IrInstr::MemoryCopy { dst, src, len } => {
+            sub(dst);
+            sub(src);
+            sub(len);
+        }
+        IrInstr::Select {
+            val1,
+            val2,
+            condition,
+            ..
+        } => {
+            sub(val1);
+            sub(val2);
+            sub(condition);
+        }
+        IrInstr::MemoryFill { dst, val, len } => { sub(dst); sub(val); sub(len); }
+        IrInstr::MemoryInit { dst, src_offset, len, .. } => { sub(dst); sub(src_offset); sub(len); }
+        IrInstr::DataDrop { .. } => {}
+        IrInstr::Phi { .. } => unreachable!("Phi nodes must be lowered before optimization"),
+    }
+}
+
+/// Replace every read-occurrence of `old` with `new` in `term`.
+pub fn replace_uses_of_terminator(term: &mut IrTerminator, old: VarId, new: VarId) {
+    let sub = |v: &mut VarId| {
+        if *v == old {
+            *v = new;
+        }
+    };
+    match term {
+        IrTerminator::Return { value: Some(v) } => {
+            sub(v);
+        }
+        IrTerminator::Return { value: None }
+        | IrTerminator::Jump { .. }
+        | IrTerminator::Unreachable => {}
+        IrTerminator::BranchIf { condition, .. } => {
+            sub(condition);
+        }
+        IrTerminator::BranchTable { index, .. } => {
+            sub(index);
+        }
+    }
+}
+
+// ── Global use-count ─────────────────────────────────────────────────────────
+
+/// Counts how many times each variable is *read* across the entire function
+/// (all blocks, all instructions, all terminators).
+pub fn build_global_use_count(func: &IrFunction) -> HashMap { + let mut counts: HashMap = HashMap::new(); + for block in &func.blocks { + for instr in &block.instructions { + for_each_use(instr, |v| { + *counts.entry(v).or_insert(0) += 1; + }); + } + for_each_use_terminator(&block.terminator, |v| { + *counts.entry(v).or_insert(0) += 1; + }); + } + counts +} + +// ── Dead-local pruning ─────────────────────────────────────────────────────── + +/// Remove from `func.locals` any variable that no longer appears in any +/// instruction or terminator of any block. +pub fn prune_dead_locals(func: &mut IrFunction) { + // Collect all variables still referenced anywhere in the function. + let mut live: HashSet = HashSet::new(); + + for block in &func.blocks { + for instr in &block.instructions { + for_each_use(instr, |v| { + live.insert(v); + }); + if let Some(dest) = instr_dest(instr) { + live.insert(dest); + } + } + for_each_use_terminator(&block.terminator, |v| { + live.insert(v); + }); + } + + // Keep params unconditionally; prune locals that are not in `live`. + func.locals.retain(|(var, _)| live.contains(var)); +} + +// ── Side-effect classification ─────────────────────────────────────────────── + +/// Returns `true` if the instruction is side-effect-free and can be safely +/// removed when its result is unused. +/// +/// Instructions that may trap (Load, MemoryGrow, integer div/rem, float-to-int +/// truncation), modify external state (Store, GlobalSet, MemoryCopy), or have +/// unknown effects (Call*) are considered side-effectful and must be retained +/// even if their result is unused — removing them would suppress a Wasm trap. +pub fn is_side_effect_free(instr: &IrInstr) -> bool { + match instr { + // Integer division and remainder trap on divisor == 0 (and i*::MIN / -1 + // for signed division). Must be preserved even when the result is dead. + IrInstr::BinOp { op, .. 
} => !matches!( + op, + BinOp::I32DivS + | BinOp::I32DivU + | BinOp::I32RemS + | BinOp::I32RemU + | BinOp::I64DivS + | BinOp::I64DivU + | BinOp::I64RemS + | BinOp::I64RemU + ), + // Float-to-integer truncations trap on NaN or out-of-range inputs. + IrInstr::UnOp { op, .. } => !matches!( + op, + UnOp::I32TruncF32S + | UnOp::I32TruncF32U + | UnOp::I32TruncF64S + | UnOp::I32TruncF64U + | UnOp::I64TruncF32S + | UnOp::I64TruncF32U + | UnOp::I64TruncF64S + | UnOp::I64TruncF64U + ), + IrInstr::Const { .. } + | IrInstr::Assign { .. } + | IrInstr::Select { .. } + | IrInstr::GlobalGet { .. } + | IrInstr::MemorySize { .. } => true, + IrInstr::MemoryFill { dst, val, len } => { f(*dst); f(*val); f(*len); } + IrInstr::MemoryInit { dst, src_offset, len, .. } => { f(*dst); f(*src_offset); f(*len); } + IrInstr::DataDrop { .. } => {} + IrInstr::Phi { .. } => unreachable!("Phi nodes must be lowered before optimization"), + _ => false, + } +} + +// ── Rewrite terminator block targets ───────────────────────────────────────── + +/// Rewrite all block-ID references in a terminator from `old` to `new`. +pub fn rewrite_terminator_target(term: &mut IrTerminator, old: BlockId, new: BlockId) { + let replace = |b: &mut BlockId| { + if *b == old { + *b = new; + } + }; + match term { + IrTerminator::Jump { target } => replace(target), + IrTerminator::BranchIf { + if_true, if_false, .. + } => { + replace(if_true); + replace(if_false); + } + IrTerminator::BranchTable { + targets, default, .. + } => { + for t in targets.iter_mut() { + replace(t); + } + replace(default); + } + IrTerminator::Return { .. 
} | IrTerminator::Unreachable => {} + } +} + +// ── Tests ──────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{BinOp, IrBlock, IrValue, WasmType}; + + #[test] + fn for_each_use_covers_binop() { + let instr = IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }; + let mut uses = vec![]; + for_each_use(&instr, |v| uses.push(v)); + assert_eq!(uses, vec![VarId(0), VarId(1)]); + } + + #[test] + fn instr_dest_returns_none_for_store() { + let instr = IrInstr::Store { + ty: WasmType::I32, + addr: VarId(0), + value: VarId(1), + offset: 0, + width: crate::ir::MemoryAccessWidth::Full, + }; + assert_eq!(instr_dest(&instr), None); + } + + #[test] + fn instr_dest_returns_some_for_const() { + let instr = IrInstr::Const { + dest: VarId(5), + value: IrValue::I32(42), + }; + assert_eq!(instr_dest(&instr), Some(VarId(5))); + } + + #[test] + fn is_side_effect_free_classification() { + assert!(is_side_effect_free(&IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + })); + assert!(is_side_effect_free(&IrInstr::BinOp { + dest: VarId(0), + op: BinOp::I32Add, + lhs: VarId(1), + rhs: VarId(2), + })); + assert!(is_side_effect_free(&IrInstr::Assign { + dest: VarId(0), + src: VarId(1), + })); + assert!(!is_side_effect_free(&IrInstr::Store { + ty: WasmType::I32, + addr: VarId(0), + value: VarId(1), + offset: 0, + width: crate::ir::MemoryAccessWidth::Full, + })); + assert!(!is_side_effect_free(&IrInstr::Load { + dest: VarId(0), + ty: WasmType::I32, + addr: VarId(1), + offset: 0, + width: crate::ir::MemoryAccessWidth::Full, + sign: None, + })); + } + + #[test] + fn trapping_binops_not_side_effect_free() { + // Integer div/rem must NOT be classified as side-effect-free because they + // can trap at runtime (division by zero, i*::MIN / -1 for signed div). 
+ for op in [ + BinOp::I32DivS, + BinOp::I32DivU, + BinOp::I32RemS, + BinOp::I32RemU, + BinOp::I64DivS, + BinOp::I64DivU, + BinOp::I64RemS, + BinOp::I64RemU, + ] { + let instr = IrInstr::BinOp { + dest: VarId(0), + op, + lhs: VarId(1), + rhs: VarId(2), + }; + assert!( + !is_side_effect_free(&instr), + "{op:?} should NOT be side-effect-free" + ); + } + // Non-trapping BinOps remain side-effect-free. + assert!(is_side_effect_free(&IrInstr::BinOp { + dest: VarId(0), + op: BinOp::I32Mul, + lhs: VarId(1), + rhs: VarId(2), + })); + } + + #[test] + fn trapping_unops_not_side_effect_free() { + use crate::ir::UnOp; + // Float-to-integer truncations trap on NaN or out-of-range values. + for op in [ + UnOp::I32TruncF32S, + UnOp::I32TruncF32U, + UnOp::I32TruncF64S, + UnOp::I32TruncF64U, + UnOp::I64TruncF32S, + UnOp::I64TruncF32U, + UnOp::I64TruncF64S, + UnOp::I64TruncF64U, + ] { + let instr = IrInstr::UnOp { + dest: VarId(0), + op, + operand: VarId(1), + }; + assert!( + !is_side_effect_free(&instr), + "{op:?} should NOT be side-effect-free" + ); + } + // Non-trapping UnOp remains side-effect-free. 
+ assert!(is_side_effect_free(&IrInstr::UnOp { + dest: VarId(0), + op: UnOp::I32Clz, + operand: VarId(1), + })); + } + + #[test] + fn terminator_successors_coverage() { + assert_eq!( + terminator_successors(&IrTerminator::Return { value: None }), + vec![] + ); + assert_eq!( + terminator_successors(&IrTerminator::Jump { target: BlockId(3) }), + vec![BlockId(3)] + ); + assert_eq!( + terminator_successors(&IrTerminator::BranchIf { + condition: VarId(0), + if_true: BlockId(1), + if_false: BlockId(2), + }), + vec![BlockId(1), BlockId(2)] + ); + } + + #[test] + fn build_predecessors_simple() { + let func = IrFunction { + params: vec![], + locals: vec![], + blocks: vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ], + entry_block: BlockId(0), + return_type: None, + type_idx: crate::ir::TypeIdx::new(0), + }; + let preds = build_predecessors(&func); + assert!(preds[&BlockId(0)].is_empty()); + assert_eq!(preds[&BlockId(1)], HashSet::from([BlockId(0)])); + } + + #[test] + fn replace_uses_of_substitutes_correctly() { + let mut instr = IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(0), + }; + replace_uses_of(&mut instr, VarId(0), VarId(5)); + match &instr { + IrInstr::BinOp { lhs, rhs, .. } => { + assert_eq!(*lhs, VarId(5)); + assert_eq!(*rhs, VarId(5)); + } + _ => panic!("expected BinOp"), + } + } + + #[test] + fn rewrite_terminator_target_works() { + let mut term = IrTerminator::BranchIf { + condition: VarId(0), + if_true: BlockId(1), + if_false: BlockId(2), + }; + rewrite_terminator_target(&mut term, BlockId(1), BlockId(5)); + match &term { + IrTerminator::BranchIf { + if_true, if_false, .. 
+ } => { + assert_eq!(*if_true, BlockId(5)); + assert_eq!(*if_false, BlockId(2)); + } + _ => panic!("expected BranchIf"), + } + } +} From f49d37e326e53dfdaabb491d3275ac8dfea55002 Mon Sep 17 00:00:00 2001 From: bench Date: Mon, 23 Mar 2026 09:11:44 +0000 Subject: [PATCH 2/4] docs: add PR #12 split status and TODOs Create comprehensive tracking document for the 5-PR split of PR #12: - Summary table of all 5 PRs (A, D, E, F, G) - Detailed breakdown of each PR with files, tests, and review checklist - Known issues and blockers (PR F pattern match errors) - Merge order and dependencies - Testing checklist and known limitations - Quick reference guide and next steps This enables efficient planning and review of the optimizer implementation work. Co-Authored-By: Claude Haiku 4.5 --- PR_SPLIT_STATUS.md | 394 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 394 insertions(+) create mode 100644 PR_SPLIT_STATUS.md diff --git a/PR_SPLIT_STATUS.md b/PR_SPLIT_STATUS.md new file mode 100644 index 0000000..dc34903 --- /dev/null +++ b/PR_SPLIT_STATUS.md @@ -0,0 +1,394 @@ +# PR #12 Split Status — Detailed Plan & TODOs + +## Overview + +This document tracks the progress of splitting PR #12 ("feat: introduce new ir optimizations") into 5 smaller, reviewable PRs. Each PR adds a distinct layer of the optimization pipeline. 
+ +**Original PR #12**: 9,688 additions / 341 deletions across 31 files +**Split Target**: 5 focused PRs, each independently reviewable but stacked for merging + +--- + +## PR Status Summary + +| # | Branch | Title | Status | GitHub PR | Tests | Notes | +|---|--------|-------|--------|-----------|-------|-------| +| A | `pr-a/wasm-float-ops` | Runtime float ops | ✅ Complete | #28 | ✅ Pass | Ready for review | +| D | `pr-d/optimizer-infrastructure` | Optimizer infrastructure | ✅ Complete | #29 | ✅ Pass | Ready for review | +| E | `pr-e/value-optimizations` | Value optimization passes | ✅ Complete | #30 | ✅ Pass | Ready for review | +| F | `pr-f/redundancy-loop-passes` | Redundancy + loop passes | ⚠️ WIP | #31 | ❌ Needs fix | Pattern match errors in utils | +| G | `pr-g/backend-codegen` | Backend + codegen updates | ❌ Not started | — | — | Planned | + +--- + +## Detailed PR Breakdown + +### PR A #28: Runtime Float Operations ✅ + +**Status**: Complete, ready for review + +**Files**: +- `crates/herkos-runtime/src/ops.rs` (+141 lines) + - `wasm_min_f32`, `wasm_max_f32` — NaN propagation, ±0.0 handling + - `wasm_min_f64`, `wasm_max_f64` — Wasm spec compliance + - `wasm_nearest_f32`, `wasm_nearest_f64` — Banker's rounding without libm +- `crates/herkos-runtime/src/lib.rs` (+2): Re-exports +- `.gitignore` (+1): Add `**/*.wasm.rs` + +**Tests**: ✅ All 121 tests pass, clippy clean + +**Key Points**: +- No `std` required, no heap allocation +- Used by `const_prop` pass for compile-time float evaluation +- Zero runtime dependencies + +**Review Checklist**: +- [ ] Verify Wasm spec compliance for min/max/nearest +- [ ] Check NaN and signed-zero handling +- [ ] Confirm no performance regression + +--- + +### PR D #29: Optimizer Infrastructure ✅ + +**Status**: Complete, ready for review + +**Files**: +- `crates/herkos-core/src/optimizer/utils.rs` (+644, new) + - Shared utilities: `terminator_successors`, `build_predecessors` + - Variable traversal: `for_each_use`, 
`for_each_use_terminator`, `count_uses_of` + - Instruction classification: `instr_dest`, `set_instr_dest`, `is_side_effect_free` + - Variable substitution: `replace_uses_of`, `rewrite_terminator_target` +- `crates/herkos-core/src/optimizer/dead_instrs.rs` (+363, new) + - Dead instruction elimination (post-lowering pass) +- `crates/herkos-core/src/optimizer/empty_blocks.rs` (+333, new) + - Passthrough block elimination +- `crates/herkos-core/src/optimizer/merge_blocks.rs` (+384, new) + - Single-predecessor block merging +- `crates/herkos-core/src/optimizer/dead_blocks.rs` (refactored) + - Now uses shared `terminator_successors` from utils +- `crates/herkos-core/src/optimizer/mod.rs` (updated) + - Adds `optimize_lowered_ir` function + - Registers post-lowering structural passes + +**Tests**: ✅ All 117 tests pass, clippy clean + +**Passes Registered** (post-lowering only): +- `empty_blocks` → `dead_blocks` → `merge_blocks` → `dead_blocks` → `dead_instrs` (2 iterations) + +**Key Points**: +- Foundation for all subsequent optimizer PRs +- Handles new IR instruction types: `MemoryFill`, `MemoryInit`, `DataDrop` +- Runs structural cleanup in iterations until fixed point + +**Review Checklist**: +- [ ] Verify utility functions are correct (especially `build_predecessors`) +- [ ] Check loop termination (2 iterations sufficient?) 
+- [ ] Confirm all new IR instructions handled + +--- + +### PR E #30: Value Optimization Passes ✅ + +**Status**: Complete, ready for review + +**Files**: +- `crates/herkos-core/src/optimizer/const_prop.rs` (+1,309, new) + - Constant folding and propagation + - Uses `herkos_runtime` functions for Wasm-spec-compliant evaluation + - Tracks dataflow of constant values +- `crates/herkos-core/src/optimizer/algebraic.rs` (+782, new) + - Algebraic simplifications (e.g., `x * 1 → x`, `x & x → x`) + - Runs after `const_prop` to operate on known constants +- `crates/herkos-core/src/optimizer/copy_prop.rs` (+1,347, new) + - Backward coalescing and forward substitution + - Registered pre-lowering and post-lowering +- `crates/herkos-core/src/ir/types.rs` (updated) + - Added `PartialEq` to `IrValue` (needed for test assertions) +- `crates/herkos-core/src/optimizer/mod.rs` (updated) + - Registers pre-lowering passes in `optimize_ir` + - Adds `copy_prop` to post-lowering pipeline +- `crates/herkos-core/Cargo.toml` (updated) + - Added `herkos-runtime` dependency (needed for float op evaluation) + +**Tests**: ✅ All 201 tests pass, clippy clean + +**Passes Registered**: +- **Pre-lowering**: `dead_blocks` → `const_prop` → `algebraic` → `copy_prop` +- **Post-lowering**: `copy_prop` + structural passes in iterations + +**Dependencies**: +- ✅ Requires PR A (float ops for const evaluation) +- ✅ Requires PR D (optimizer infrastructure) + +**Key Points**: +- Const prop evaluates Wasm arithmetic using runtime functions +- Pre-lowering const prop simplifies phi nodes before SSA destruction +- Post-lowering copy_prop forwards lower_phis assignments +- IrValue::PartialEq enables test comparisons + +**Review Checklist**: +- [ ] Verify const folding correctness (esp. float ops, overflow handling) +- [ ] Check algebraic simplification rules are sound +- [ ] Confirm copy propagation doesn't break SSA invariants +- [ ] Test with real Wasm modules (fibonacci, etc.) 
+ +--- + +### PR F #31: Redundancy Elimination + Loop Passes ⚠️ WIP + +**Status**: WIP — files copied, needs test fixes + +**Files** (from PR #12, needs adjustment): +- `crates/herkos-core/src/optimizer/local_cse.rs` (+575, new) + - Local common subexpression elimination within blocks +- `crates/herkos-core/src/optimizer/gvn.rs` (+619, new) + - Global value numbering across blocks using dominator tree +- `crates/herkos-core/src/optimizer/licm.rs` (+1,307, new) + - Loop invariant code motion +- `crates/herkos-core/src/optimizer/branch_fold.rs` (+366, new) + - Branch condition simplification +- `crates/herkos-core/src/optimizer/mod.rs` (updated) + - Registers all 4 passes in `optimize_lowered_ir` pipeline + +**Tests**: ❌ Compile errors — pattern matching incomplete + +**Known Issues**: +1. Missing pattern match arms for `MemoryFill`, `MemoryInit`, `DataDrop` in several functions + - Affects: branch_fold, gvn, licm, local_cse tests + - Fix: Add cases to all `match instr` blocks that pattern match IR instructions +2. 
Tests reference removed `IrFunction.needs_host` field + - Already fixed for PRs D & E, needs same fix in PR F files + +**Passes Registered** (post-lowering, in iteration loop): +- `local_cse` → `gvn` → `branch_fold` → `licm` (with dead_instrs between) + +**Dependencies**: +- ✅ Requires PR D (optimizer infrastructure + utils) +- ✅ Requires PR E (value optimizations) + +**TODO Before Merge**: +- [ ] Fix missing IR instruction pattern matches (MemoryFill, MemoryInit, DataDrop) +- [ ] Remove remaining `needs_host` references in test code +- [ ] Run full test suite: `cargo test -p herkos-core --lib` +- [ ] Verify dominator tree construction in GVN +- [ ] Check LICM loop detection correctness +- [ ] Test with loop-heavy Wasm (e.g., array operations) + +--- + +### PR G: Backend + Codegen Updates ❌ NOT STARTED + +**Planned Files**: +- `crates/herkos/Cargo.toml` (+1) + - Add `herkos-runtime` dependency to transpiler crate +- `crates/herkos/src/backend/mod.rs` (+13) + - Add trait methods for float operations +- `crates/herkos/src/backend/safe.rs` (+54) + - Implement `wasm_min_f32`, `wasm_max_f32`, etc. 
in `SafeBackend` +- `crates/herkos/src/codegen/function.rs` (+60) + - Update function code generation +- `crates/herkos/src/codegen/instruction.rs` (+25, -1) + - Add float instruction code generation +- `crates/herkos/src/codegen/mod.rs` (+7, -4) +- `crates/herkos/src/codegen/module.rs` (+4, -1) +- `crates/herkos/tests/e2e.rs` (-3) + - Test cleanup +- `Cargo.lock`: Auto-updated + +**Purpose**: +- Wire runtime float ops into codegen path +- Ensure transpiled Rust code calls the correct runtime functions +- End-to-end integration: Wasm → IR → optimized IR → Rust code + +**Dependencies**: +- Requires PR A (runtime float ops must exist) +- Requires PR D, E, F (optimizers must work) +- No direct code dependency, but all PRs should be merged first + +**TODO**: +- [ ] Extract codegen changes from PR #12 +- [ ] Map Wasm float instructions to backend method calls +- [ ] Update `SafeBackend` implementations +- [ ] Run E2E tests: `cargo test -p herkos-tests` +- [ ] Verify generated Rust code compiles +- [ ] Check performance with benchmarks: `cargo bench -p herkos-tests` + +--- + +## Merge Order & Strategy + +### Strict Linear Order (Recommended) +``` +main + ↓ merge PR A + ├─ main + A + ↓ merge PR D + ├─ main + A + D + ↓ merge PR E + ├─ main + A + D + E + ↓ merge PR F (after fixes) + ├─ main + A + D + E + F + ↓ merge PR G + └─ main + A + D + E + F + G (complete) +``` + +### Rationale +- A is independent (pure runtime addition) +- D is infrastructure foundation +- E depends on A (float ops) + D (utilities) +- F depends on E (builds on value passes) +- G depends on A (must have runtime ops to codegen) + +--- + +## Current Blockers + +### PR F Compilation Errors + +**Error Pattern**: +``` +error[E0004]: non-exhaustive patterns: `MemoryFill`, `MemoryInit`, `DataDrop` +``` + +**Root Cause**: PR #12 was created before `MemoryFill`, `MemoryInit`, `DataDrop` were added to `IrInstr` enum. 
+ +**Solution**: Add pattern match arms to all functions in optimizer files: + +```rust +// Example fix for for_each_use in utils.rs +IrInstr::MemoryFill { dst, val, len } => { + f(*dst); + f(*val); + f(*len); +} +IrInstr::MemoryInit { dst, src_offset, len, .. } => { + f(*dst); + f(*src_offset); + f(*len); +} +IrInstr::DataDrop { .. } => {} +``` + +**Affected Files in PR F**: +- branch_fold.rs (test code) +- gvn.rs (test code) +- licm.rs (test code) +- local_cse.rs (test code) + +**Action Item**: Apply pattern match fixes and re-run tests before final review. + +--- + +## Testing Checklist + +### Unit Tests (Current) +- [x] PR A: `cargo test -p herkos-runtime` ✅ 121 pass +- [x] PR D: `cargo test -p herkos-core --lib` ✅ 117 pass +- [x] PR E: `cargo test -p herkos-core --lib` ✅ 201 pass +- [ ] PR F: `cargo test -p herkos-core --lib` ❌ Needs fixes +- [ ] PR G: `cargo test` (all crates) ⏳ Not started + +### Integration Tests (After Merge) +- [ ] `cargo test -p herkos-tests` — E2E Wasm → Rust +- [ ] `cargo bench -p herkos-tests` — Fibonacci benchmarks + +### Code Quality (All PRs) +- [x] PR A: `cargo clippy` ✅ clean +- [x] PR D: `cargo clippy` ✅ clean +- [x] PR E: `cargo clippy` ✅ clean +- [ ] PR F: `cargo clippy` ⏳ Not checked (blocked on compile) +- [ ] PR G: `cargo clippy` ⏳ Not started + +### Format Check (All PRs) +- [ ] `cargo fmt --check` on all PRs + +--- + +## Known Limitations & Future Work + +### Not Included (Future Enhancements) +- **Verified backend** — Currently only safe backend implemented +- **Hybrid backend** — Mix of safe and verified code +- **Temporal isolation** — Future feature +- **Contract-based verification** — Future feature +- **`--max-pages` CLI effect** — Not yet wired through transpiler +- **WASI traits** — Standard import traits not yet implemented + +### Performance Notes +- Two-pass structural cleanup may iterate more than necessary +- GVN dominator tree construction could be optimized +- LICM may be conservative in loop detection + 
+### Documentation +- Consider adding optimizer pass pipeline diagram to SPECIFICATION.md +- Document dataflow analysis assumptions (SSA form requirements) + +--- + +## Quick Reference: File Locations + +``` +Optimizer Passes: +├── crates/herkos-core/src/optimizer/ +│ ├── utils.rs [shared utilities, PR D] +│ ├── dead_blocks.rs [refactored, PR D] +│ ├── dead_instrs.rs [post-lowering, PR D] +│ ├── empty_blocks.rs [post-lowering, PR D] +│ ├── merge_blocks.rs [post-lowering, PR D] +│ ├── const_prop.rs [pre-lowering, PR E] +│ ├── algebraic.rs [pre-lowering, PR E] +│ ├── copy_prop.rs [pre + post, PR E] +│ ├── local_cse.rs [post-lowering, PR F] +│ ├── gvn.rs [post-lowering, PR F] +│ ├── licm.rs [post-lowering, PR F] +│ ├── branch_fold.rs [post-lowering, PR F] +│ └── mod.rs [pipeline coordination] + +Runtime: +├── crates/herkos-runtime/src/ +│ ├── ops.rs [wasm_min/max/nearest, PR A] +│ └── lib.rs [re-exports, PR A] + +Codegen (PR G): +├── crates/herkos/src/ +│ ├── backend/mod.rs [trait updates] +│ ├── backend/safe.rs [float op implementations] +│ └── codegen/ [instruction generation] +``` + +--- + +## Next Steps + +### Immediate (Complete PR F) +1. [ ] Fix pattern matches for MemoryFill/MemoryInit/DataDrop in PR F +2. [ ] Run `cargo test -p herkos-core --lib` and confirm all pass +3. [ ] Run `cargo clippy` and `cargo fmt --check` +4. [ ] Push fixes to `pr-f/redundancy-loop-passes` + +### Short Term (Prepare PR G) +1. [ ] Extract codegen changes from PR #12 +2. [ ] Create `pr-g/backend-codegen` branch +3. [ ] Implement and test +4. [ ] Verify all E2E tests pass + +### Post-Merge (Validation) +1. [ ] Run full integration suite: `cargo test` +2. [ ] Run benchmarks: `cargo bench` +3. [ ] Profile optimization impact (e.g., % code size reduction) +4. 
[ ] Document lessons learned in SPECIFICATION.md

---

## Contact & References

- **Original PR**: GitHub #12
- **Plan file**: `/home/vscode/.claude/plans/streamed-hatching-castle.md`
- **Branch tracking**: See PRs #28–#31 for current status
- **CLAUDE.md**: Project conventions and architecture

---

*Last updated: 2026-03-23*
*Status: 3 of 5 PRs complete (A, D, E), PR F WIP, PR G planned*
From b628c1dc9109bb68ea0899be003567d03ae6aaab Mon Sep 17 00:00:00 2001
From: bench
Date: Fri, 27 Mar 2026 08:08:13 +0000
Subject: [PATCH 3/4] refactor: reorganize module imports and comment out LICM elimination in optimize_lowered_ir

---
 crates/herkos-core/src/optimizer/gvn.rs | 149 ++++++++++++++++++------
 crates/herkos-core/src/optimizer/mod.rs |   6 +-
 2 files changed, 118 insertions(+), 37 deletions(-)

diff --git a/crates/herkos-core/src/optimizer/gvn.rs b/crates/herkos-core/src/optimizer/gvn.rs
index f5cd732..33a3efd 100644
--- a/crates/herkos-core/src/optimizer/gvn.rs
+++ b/crates/herkos-core/src/optimizer/gvn.rs
@@ -124,13 +124,23 @@ fn build_multi_def_vars(func: &IrFunction) -> HashSet {
 /// Compute the reverse-postorder traversal of the CFG starting from `entry`.
fn compute_rpo(func: &IrFunction) -> Vec { - let block_idx: HashMap = - func.blocks.iter().enumerate().map(|(i, b)| (b.id, i)).collect(); + let block_idx: HashMap = func + .blocks + .iter() + .enumerate() + .map(|(i, b)| (b.id, i)) + .collect(); let mut visited = vec![false; func.blocks.len()]; let mut postorder = Vec::with_capacity(func.blocks.len()); - dfs_postorder(func, func.entry_block, &block_idx, &mut visited, &mut postorder); + dfs_postorder( + func, + func.entry_block, + &block_idx, + &mut visited, + &mut postorder, + ); postorder.reverse(); postorder @@ -164,8 +174,7 @@ fn dfs_postorder( fn compute_idoms(func: &IrFunction) -> HashMap { let rpo = compute_rpo(func); // rpo_num[b] = index in RPO order (entry = 0, smallest index = processed first) - let rpo_num: HashMap = - rpo.iter().enumerate().map(|(i, &b)| (b, i)).collect(); + let rpo_num: HashMap = rpo.iter().enumerate().map(|(i, &b)| (b, i)).collect(); let preds = build_predecessors(func); let entry = func.entry_block; @@ -287,7 +296,9 @@ fn collect_replacements( } } - IrInstr::BinOp { dest, op, lhs, rhs, .. } => { + IrInstr::BinOp { + dest, op, lhs, rhs, .. + } => { // Skip if dest is multiply-defined (same reason as Const). 
// Also skip if any operand is multiply-defined: a loop phi // var carries different values per iteration, so the same @@ -311,7 +322,10 @@ fn collect_replacements( if multi_def_vars.contains(dest) || multi_def_vars.contains(operand) { continue; } - let key = ValueKey::UnOp { op: *op, operand: *operand }; + let key = ValueKey::UnOp { + op: *op, + operand: *operand, + }; if let Some(&first) = value_map.get(&key) { replacements.insert(*dest, first); } else { @@ -355,8 +369,12 @@ pub fn eliminate(func: &mut IrFunction) { let idom = compute_idoms(func); let dom_children = build_dom_children(&idom, func.entry_block); - let block_idx: HashMap = - func.blocks.iter().enumerate().map(|(i, b)| (b.id, i)).collect(); + let block_idx: HashMap = func + .blocks + .iter() + .enumerate() + .map(|(i, b)| (b.id, i)) + .collect(); let multi_def_vars = build_multi_def_vars(func); let mut value_map: HashMap = HashMap::new(); @@ -413,13 +431,21 @@ mod tests { fn cross_block_const_deduplication() { let b0 = IrBlock { id: BlockId(0), - instructions: vec![IrInstr::Const { dest: VarId(0), value: IrValue::I32(42) }], + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }], terminator: IrTerminator::Jump { target: BlockId(1) }, }; let b1 = IrBlock { id: BlockId(1), - instructions: vec![IrInstr::Const { dest: VarId(1), value: IrValue::I32(42) }], - terminator: IrTerminator::Return { value: Some(VarId(1)) }, + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(42), + }], + terminator: IrTerminator::Return { + value: Some(VarId(1)), + }, }; let mut func = make_func(vec![b0, b1]); func.locals = vec![(VarId(0), WasmType::I32), (VarId(1), WasmType::I32)]; @@ -427,13 +453,19 @@ mod tests { eliminate(&mut func); assert!( - matches!(func.blocks[0].instructions[0], IrInstr::Const { dest: VarId(0), .. }), + matches!( + func.blocks[0].instructions[0], + IrInstr::Const { dest: VarId(0), .. 
} + ), "first definition should stay as Const" ); assert!( matches!( func.blocks[1].instructions[0], - IrInstr::Assign { dest: VarId(1), src: VarId(0) } + IrInstr::Assign { + dest: VarId(1), + src: VarId(0) + } ), "dominated duplicate should become Assign" ); @@ -460,18 +492,26 @@ mod tests { lhs: VarId(0), rhs: VarId(1), }], - terminator: IrTerminator::Return { value: Some(VarId(3)) }, + terminator: IrTerminator::Return { + value: Some(VarId(3)), + }, }; let mut func = make_func(vec![b0, b1]); func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; eliminate(&mut func); - assert!(matches!(func.blocks[0].instructions[0], IrInstr::BinOp { .. })); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::BinOp { .. } + )); assert!( matches!( func.blocks[1].instructions[0], - IrInstr::Assign { dest: VarId(3), src: VarId(2) } + IrInstr::Assign { + dest: VarId(3), + src: VarId(2) + } ), "dominated duplicate BinOp should become Assign" ); @@ -493,12 +533,18 @@ mod tests { }; let b1 = IrBlock { id: BlockId(1), - instructions: vec![IrInstr::Const { dest: VarId(1), value: IrValue::I32(7) }], + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(7), + }], terminator: IrTerminator::Jump { target: BlockId(3) }, }; let b2 = IrBlock { id: BlockId(2), - instructions: vec![IrInstr::Const { dest: VarId(2), value: IrValue::I32(7) }], + instructions: vec![IrInstr::Const { + dest: VarId(2), + value: IrValue::I32(7), + }], terminator: IrTerminator::Jump { target: BlockId(3) }, }; let b3 = IrBlock { @@ -513,11 +559,17 @@ mod tests { // Both consts should remain — neither block dominates the other. assert!( - matches!(func.blocks[1].instructions[0], IrInstr::Const { dest: VarId(1), .. }), + matches!( + func.blocks[1].instructions[0], + IrInstr::Const { dest: VarId(1), .. } + ), "const in B1 must not be eliminated" ); assert!( - matches!(func.blocks[2].instructions[0], IrInstr::Const { dest: VarId(2), .. 
}), + matches!( + func.blocks[2].instructions[0], + IrInstr::Const { dest: VarId(2), .. } + ), "const in B2 must not be eliminated" ); } @@ -528,7 +580,10 @@ mod tests { // B0 → B1 → B2: const defined in B0, duplicated in B2 let b0 = IrBlock { id: BlockId(0), - instructions: vec![IrInstr::Const { dest: VarId(0), value: IrValue::I32(99) }], + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(99), + }], terminator: IrTerminator::Jump { target: BlockId(1) }, }; let b1 = IrBlock { @@ -538,21 +593,30 @@ mod tests { }; let b2 = IrBlock { id: BlockId(2), - instructions: vec![IrInstr::Const { dest: VarId(1), value: IrValue::I32(99) }], - terminator: IrTerminator::Return { value: Some(VarId(1)) }, + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(99), + }], + terminator: IrTerminator::Return { + value: Some(VarId(1)), + }, }; let mut func = make_func(vec![b0, b1, b2]); func.locals = vec![(VarId(0), WasmType::I32), (VarId(1), WasmType::I32)]; eliminate(&mut func); - assert!( - matches!(func.blocks[0].instructions[0], IrInstr::Const { dest: VarId(0), .. }) - ); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { dest: VarId(0), .. 
} + )); assert!( matches!( func.blocks[2].instructions[0], - IrInstr::Assign { dest: VarId(1), src: VarId(0) } + IrInstr::Assign { + dest: VarId(1), + src: VarId(0) + } ), "deeply dominated duplicate should be eliminated" ); @@ -579,7 +643,9 @@ mod tests { lhs: VarId(1), // swapped rhs: VarId(0), }], - terminator: IrTerminator::Return { value: Some(VarId(3)) }, + terminator: IrTerminator::Return { + value: Some(VarId(3)), + }, }; let mut func = make_func(vec![b0, b1]); func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; @@ -589,7 +655,10 @@ mod tests { assert!( matches!( func.blocks[1].instructions[0], - IrInstr::Assign { dest: VarId(3), src: VarId(2) } + IrInstr::Assign { + dest: VarId(3), + src: VarId(2) + } ), "commutative cross-block BinOp should be deduplicated" ); @@ -601,8 +670,14 @@ mod tests { let b0 = IrBlock { id: BlockId(0), instructions: vec![ - IrInstr::Const { dest: VarId(0), value: IrValue::I32(1) }, - IrInstr::Const { dest: VarId(1), value: IrValue::I32(1) }, + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(1), + }, ], terminator: IrTerminator::Return { value: None }, }; @@ -612,7 +687,13 @@ mod tests { eliminate(&mut func); // GVN skips single-block functions; duplicates remain (local_cse's job). - assert!(matches!(func.blocks[0].instructions[0], IrInstr::Const { .. })); - assert!(matches!(func.blocks[0].instructions[1], IrInstr::Const { .. })); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { .. } + )); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::Const { .. 
} + )); } } diff --git a/crates/herkos-core/src/optimizer/mod.rs b/crates/herkos-core/src/optimizer/mod.rs index b27f044..d8decf0 100644 --- a/crates/herkos-core/src/optimizer/mod.rs +++ b/crates/herkos-core/src/optimizer/mod.rs @@ -20,13 +20,13 @@ mod copy_prop; mod dead_blocks; // ── Post-lowering passes ───────────────────────────────────────────────────── +mod branch_fold; mod dead_instrs; mod empty_blocks; -mod merge_blocks; -mod branch_fold; mod gvn; mod licm; mod local_cse; +mod merge_blocks; /// Optimizes the pure SSA IR before phi lowering. /// @@ -70,7 +70,7 @@ pub fn optimize_lowered_ir( dead_instrs::eliminate(func); branch_fold::eliminate(func); dead_instrs::eliminate(func); - licm::eliminate(func); + // licm::eliminate(func); } } } From 26cfd69e1a7bcb864f14c35b0c36bf4217b366e4 Mon Sep 17 00:00:00 2001 From: bench Date: Fri, 27 Mar 2026 08:10:47 +0000 Subject: [PATCH 4/4] refactor: split redundancy and loop passes into separate PRs for better isolation and review --- PR_SPLIT_STATUS.md | 250 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 201 insertions(+), 49 deletions(-) diff --git a/PR_SPLIT_STATUS.md b/PR_SPLIT_STATUS.md index dc34903..9c007d7 100644 --- a/PR_SPLIT_STATUS.md +++ b/PR_SPLIT_STATUS.md @@ -16,7 +16,10 @@ This document tracks the progress of splitting PR #12 ("feat: introduce new ir o | A | `pr-a/wasm-float-ops` | Runtime float ops | ✅ Complete | #28 | ✅ Pass | Ready for review | | D | `pr-d/optimizer-infrastructure` | Optimizer infrastructure | ✅ Complete | #29 | ✅ Pass | Ready for review | | E | `pr-e/value-optimizations` | Value optimization passes | ✅ Complete | #30 | ✅ Pass | Ready for review | -| F | `pr-f/redundancy-loop-passes` | Redundancy + loop passes | ⚠️ WIP | #31 | ❌ Needs fix | Pattern match errors in utils | +| F1 | `pr-f1/branch-fold` | Branch condition folding | ⚠️ WIP | #31 | ❌ Needs fix | Split from old PR F | +| F2 | `pr-f2/local-cse` | Local common subexpression elimination | ❌ Not started | — | — | 
Split from old PR F | +| F3 | `pr-f3/gvn` | Global value numbering | ❌ Not started | — | — | Split from old PR F | +| F4 | `pr-f4/licm` | Loop invariant code motion | ❌ Not started | — | — | Split from old PR F, last | | G | `pr-g/backend-codegen` | Backend + codegen updates | ❌ Not started | — | — | Planned | --- @@ -135,45 +138,154 @@ This document tracks the progress of splitting PR #12 ("feat: introduce new ir o --- -### PR F #31: Redundancy Elimination + Loop Passes ⚠️ WIP +### PR F1 #31: Branch Condition Folding ⚠️ WIP -**Status**: WIP — files copied, needs test fixes +**Status**: WIP — file exists on current branch, needs pattern match fixes -**Files** (from PR #12, needs adjustment): -- `crates/herkos-core/src/optimizer/local_cse.rs` (+575, new) - - Local common subexpression elimination within blocks -- `crates/herkos-core/src/optimizer/gvn.rs` (+619, new) - - Global value numbering across blocks using dominator tree -- `crates/herkos-core/src/optimizer/licm.rs` (+1,307, new) - - Loop invariant code motion +**Branch**: `pr-f1/branch-fold` (rename/split from `pr-f/redundancy-loop-passes`) + +**Files**: - `crates/herkos-core/src/optimizer/branch_fold.rs` (+366, new) - - Branch condition simplification + - Simplifies branch conditions: constant branches, always-taken jumps - `crates/herkos-core/src/optimizer/mod.rs` (updated) - - Registers all 4 passes in `optimize_lowered_ir` pipeline + - Registers `branch_fold::eliminate` in `optimize_lowered_ir` pipeline -**Tests**: ❌ Compile errors — pattern matching incomplete +**Tests**: ❌ Compile errors — missing pattern match arms **Known Issues**: -1. Missing pattern match arms for `MemoryFill`, `MemoryInit`, `DataDrop` in several functions - - Affects: branch_fold, gvn, licm, local_cse tests - - Fix: Add cases to all `match instr` blocks that pattern match IR instructions -2. Tests reference removed `IrFunction.needs_host` field - - Already fixed for PRs D & E, needs same fix in PR F files +1. 
Missing pattern match arms for `MemoryFill`, `MemoryInit`, `DataDrop` in test helpers + - Fix: Add exhaustive arms to all `match instr` blocks in branch_fold tests +2. Tests may reference removed `IrFunction.needs_host` field — verify and remove + +**Passes Registered** (post-lowering, added to iteration loop): +- `branch_fold::eliminate` (after `dead_instrs`) + +**Dependencies**: +- ✅ Requires PR D (optimizer infrastructure + utils) +- ✅ Requires PR E (value optimizations) + +**TODO Before Merge**: +- [ ] Create branch `pr-f1/branch-fold` from PR E base +- [ ] Cherry-pick only `branch_fold.rs` + `mod.rs` changes +- [ ] Fix missing IR instruction pattern matches in branch_fold tests +- [ ] Run `cargo test -p herkos-core --lib` +- [ ] Run `cargo clippy` and `cargo fmt --check` + +--- + +### PR F2: Local Common Subexpression Elimination ❌ Not Started + +**Status**: Not started — code exists in `pr-f/redundancy-loop-passes`, needs isolation + +**Branch**: `pr-f2/local-cse` (to be created from PR F1) + +**Files**: +- `crates/herkos-core/src/optimizer/local_cse.rs` (+575, new) + - Eliminates redundant computations within a single basic block + - Keyed on instruction structure (opcode + operands), no cross-block analysis +- `crates/herkos-core/src/optimizer/mod.rs` (updated) + - Registers `local_cse::eliminate` in `optimize_lowered_ir` pipeline + +**Tests**: ❌ Compile errors (same pattern match issues as F1) + +**Known Issues**: +- Same missing `MemoryFill`, `MemoryInit`, `DataDrop` pattern arms as F1 +- Fix: Add exhaustive arms to all `match instr` blocks in local_cse tests **Passes Registered** (post-lowering, in iteration loop): -- `local_cse` → `gvn` → `branch_fold` → `licm` (with dead_instrs between) +- `local_cse::eliminate` (after structural passes, before GVN) **Dependencies**: - ✅ Requires PR D (optimizer infrastructure + utils) - ✅ Requires PR E (value optimizations) +- ✅ Requires PR F1 (branch fold reduces branches before CSE runs) + +**TODO Before Merge**: +- 
[ ] Create branch `pr-f2/local-cse` from PR F1 base +- [ ] Cherry-pick only `local_cse.rs` + `mod.rs` changes +- [ ] Fix missing IR instruction pattern matches in local_cse tests +- [ ] Run `cargo test -p herkos-core --lib` +- [ ] Run `cargo clippy` and `cargo fmt --check` + +--- + +### PR F3: Global Value Numbering ❌ Not Started + +**Status**: Not started — code exists in `pr-f/redundancy-loop-passes`, needs isolation + +**Branch**: `pr-f3/gvn` (to be created from PR F2) + +**Files**: +- `crates/herkos-core/src/optimizer/gvn.rs` (+619, new) + - Value numbering across basic blocks using dominator tree traversal + - Eliminates redundant computations that `local_cse` cannot catch across blocks +- `crates/herkos-core/src/optimizer/mod.rs` (updated) + - Registers `gvn::eliminate` in `optimize_lowered_ir` pipeline + +**Tests**: ❌ Compile errors (same pattern match issues as F1/F2) + +**Known Issues**: +- Same missing `MemoryFill`, `MemoryInit`, `DataDrop` pattern arms +- Dominator tree construction correctness should be carefully verified + +**Passes Registered** (post-lowering, in iteration loop): +- `gvn::eliminate` (after `local_cse`, before `dead_instrs`) + +**Dependencies**: +- ✅ Requires PR D (optimizer infrastructure + utils, including `build_predecessors`) +- ✅ Requires PR E (value optimizations) +- ✅ Requires PR F2 (local CSE runs first, GVN handles cross-block residuals) + +**TODO Before Merge**: +- [ ] Create branch `pr-f3/gvn` from PR F2 base +- [ ] Cherry-pick only `gvn.rs` + `mod.rs` changes +- [ ] Fix missing IR instruction pattern matches in gvn tests +- [ ] Verify dominator tree construction correctness +- [ ] Run `cargo test -p herkos-core --lib` +- [ ] Run `cargo clippy` and `cargo fmt --check` + +--- + +### PR F4: Loop Invariant Code Motion ❌ Not Started + +**Status**: Not started — code exists in `pr-f/redundancy-loop-passes`, needs isolation + +**Branch**: `pr-f4/licm` (to be created from PR F3, **last of the F series**) + +**Files**: +- 
`crates/herkos-core/src/optimizer/licm.rs` (+1,307, new) + - Detects natural loops via back-edges and dominator tree + - Hoists side-effect-free loop-invariant instructions to loop preheaders +- `crates/herkos-core/src/optimizer/mod.rs` (updated) + - Uncomments and registers `licm::eliminate` in `optimize_lowered_ir` pipeline + - Currently commented out: `// licm::eliminate(func);` + +**Tests**: ❌ Compile errors (same pattern match issues + licm is currently commented out) + +**Known Issues**: +- Same missing `MemoryFill`, `MemoryInit`, `DataDrop` pattern arms +- Loop detection may be conservative — verify back-edge identification +- LICM must not hoist instructions with side effects (memory writes, traps) + +**Passes Registered** (post-lowering, end of iteration loop): +- `licm::eliminate` (after GVN and `dead_instrs`, final pass in loop body) + +**Dependencies**: +- ✅ Requires PR D (optimizer infrastructure + utils, `is_side_effect_free`) +- ✅ Requires PR E (value optimizations) +- ✅ Requires PR F1 (branch fold first simplifies loop exit conditions) +- ✅ Requires PR F2 (local CSE) +- ✅ Requires PR F3 (GVN — hoisting is most effective after redundancies removed) **TODO Before Merge**: -- [ ] Fix missing IR instruction pattern matches (MemoryFill, MemoryInit, DataDrop) -- [ ] Remove remaining `needs_host` references in test code -- [ ] Run full test suite: `cargo test -p herkos-core --lib` -- [ ] Verify dominator tree construction in GVN -- [ ] Check LICM loop detection correctness +- [ ] Create branch `pr-f4/licm` from PR F3 base +- [ ] Cherry-pick only `licm.rs` + `mod.rs` changes (uncomment licm call) +- [ ] Fix missing IR instruction pattern matches in licm tests +- [ ] Verify loop detection and preheader insertion correctness +- [ ] Check `is_side_effect_free` covers all hoistable instructions - [ ] Test with loop-heavy Wasm (e.g., array operations) +- [ ] Run `cargo test -p herkos-core --lib` +- [ ] Run `cargo clippy` and `cargo fmt --check` --- @@ -227,24 
+339,33 @@ main ├─ main + A + D ↓ merge PR E ├─ main + A + D + E - ↓ merge PR F (after fixes) - ├─ main + A + D + E + F + ↓ merge PR F1 (branch fold) + ├─ main + A + D + E + F1 + ↓ merge PR F2 (local CSE) + ├─ main + A + D + E + F1 + F2 + ↓ merge PR F3 (GVN) + ├─ main + A + D + E + F1 + F2 + F3 + ↓ merge PR F4 (LICM — last) + ├─ main + A + D + E + F1 + F2 + F3 + F4 ↓ merge PR G - └─ main + A + D + E + F + G (complete) + └─ main + A + D + E + F1–F4 + G (complete) ``` ### Rationale - A is independent (pure runtime addition) - D is infrastructure foundation - E depends on A (float ops) + D (utilities) -- F depends on E (builds on value passes) +- F1 depends on E — branch fold simplifies control flow for later passes +- F2 depends on F1 — local CSE runs after branches are simplified +- F3 depends on F2 — GVN extends CSE globally across blocks +- F4 depends on F3 — LICM is most effective after redundancies are removed (last pass) - G depends on A (must have runtime ops to codegen) --- ## Current Blockers -### PR F Compilation Errors +### PR F1–F4 Compilation Errors **Error Pattern**: ``` @@ -256,7 +377,7 @@ error[E0004]: non-exhaustive patterns: `MemoryFill`, `MemoryInit`, `DataDrop` **Solution**: Add pattern match arms to all functions in optimizer files: ```rust -// Example fix for for_each_use in utils.rs +// Example fix pattern IrInstr::MemoryFill { dst, val, len } => { f(*dst); f(*val); @@ -270,13 +391,13 @@ IrInstr::MemoryInit { dst, src_offset, len, .. } => { IrInstr::DataDrop { .. } => {} ``` -**Affected Files in PR F**: -- branch_fold.rs (test code) -- gvn.rs (test code) -- licm.rs (test code) -- local_cse.rs (test code) +**Affected files (fix needed in each isolated PR)**: +- `branch_fold.rs` test code → fix in PR F1 +- `local_cse.rs` test code → fix in PR F2 +- `gvn.rs` test code → fix in PR F3 +- `licm.rs` test code → fix in PR F4 -**Action Item**: Apply pattern match fixes and re-run tests before final review. 
+**Action Item**: Apply pattern match fixes per file when working on each isolated PR branch. --- @@ -286,7 +407,10 @@ IrInstr::DataDrop { .. } => {} - [x] PR A: `cargo test -p herkos-runtime` ✅ 121 pass - [x] PR D: `cargo test -p herkos-core --lib` ✅ 117 pass - [x] PR E: `cargo test -p herkos-core --lib` ✅ 201 pass -- [ ] PR F: `cargo test -p herkos-core --lib` ❌ Needs fixes +- [ ] PR F1: `cargo test -p herkos-core --lib` ❌ Needs fixes +- [ ] PR F2: `cargo test -p herkos-core --lib` ❌ Needs fixes +- [ ] PR F3: `cargo test -p herkos-core --lib` ❌ Needs fixes +- [ ] PR F4: `cargo test -p herkos-core --lib` ❌ Needs fixes - [ ] PR G: `cargo test` (all crates) ⏳ Not started ### Integration Tests (After Merge) @@ -297,7 +421,10 @@ IrInstr::DataDrop { .. } => {} - [x] PR A: `cargo clippy` ✅ clean - [x] PR D: `cargo clippy` ✅ clean - [x] PR E: `cargo clippy` ✅ clean -- [ ] PR F: `cargo clippy` ⏳ Not checked (blocked on compile) +- [ ] PR F1: `cargo clippy` ⏳ Blocked on compile +- [ ] PR F2: `cargo clippy` ⏳ Not started +- [ ] PR F3: `cargo clippy` ⏳ Not started +- [ ] PR F4: `cargo clippy` ⏳ Not started - [ ] PR G: `cargo clippy` ⏳ Not started ### Format Check (All PRs) @@ -339,10 +466,10 @@ Optimizer Passes: │ ├── const_prop.rs [pre-lowering, PR E] │ ├── algebraic.rs [pre-lowering, PR E] │ ├── copy_prop.rs [pre + post, PR E] -│ ├── local_cse.rs [post-lowering, PR F] -│ ├── gvn.rs [post-lowering, PR F] -│ ├── licm.rs [post-lowering, PR F] -│ ├── branch_fold.rs [post-lowering, PR F] +│ ├── branch_fold.rs [post-lowering, PR F1] +│ ├── local_cse.rs [post-lowering, PR F2] +│ ├── gvn.rs [post-lowering, PR F3] +│ ├── licm.rs [post-lowering, PR F4] │ └── mod.rs [pipeline coordination] Runtime: @@ -361,11 +488,36 @@ Codegen (PR G): ## Next Steps -### Immediate (Complete PR F) -1. [ ] Fix pattern matches for MemoryFill/MemoryInit/DataDrop in PR F -2. [ ] Run `cargo test -p herkos-core --lib` and confirm all pass -3. [ ] Run `cargo clippy` and `cargo fmt --check` -4. 
[ ] Push fixes to `pr-f/redundancy-loop-passes` +### Immediate (Create PR F1 — branch fold) +1. [ ] Create branch `pr-f1/branch-fold` from PR E tip +2. [ ] Cherry-pick only `branch_fold.rs` + `mod.rs` (branch_fold registration) from old `pr-f/redundancy-loop-passes` +3. [ ] Fix missing `MemoryFill`/`MemoryInit`/`DataDrop` pattern arms in `branch_fold.rs` tests +4. [ ] Run `cargo test -p herkos-core --lib` — confirm all pass +5. [ ] Run `cargo clippy` and `cargo fmt --check` +6. [ ] Push and open PR against PR E branch (or main if E is merged) + +### Next (PR F2 — local CSE) +1. [ ] Create branch `pr-f2/local-cse` from PR F1 tip +2. [ ] Cherry-pick only `local_cse.rs` + `mod.rs` changes +3. [ ] Fix pattern match arms in `local_cse.rs` tests +4. [ ] Run tests, clippy, fmt +5. [ ] Open PR + +### Then (PR F3 — GVN) +1. [ ] Create branch `pr-f3/gvn` from PR F2 tip +2. [ ] Cherry-pick only `gvn.rs` + `mod.rs` changes +3. [ ] Fix pattern match arms in `gvn.rs` tests +4. [ ] Verify dominator tree construction +5. [ ] Run tests, clippy, fmt +6. [ ] Open PR + +### Then (PR F4 — LICM, last) +1. [ ] Create branch `pr-f4/licm` from PR F3 tip +2. [ ] Cherry-pick only `licm.rs` + `mod.rs` changes (uncomment `licm::eliminate`) +3. [ ] Fix pattern match arms in `licm.rs` tests +4. [ ] Verify loop detection and preheader insertion +5. [ ] Run tests, clippy, fmt +6. [ ] Open PR ### Short Term (Prepare PR G) 1. [ ] Extract codegen changes from PR #12 @@ -390,5 +542,5 @@ Codegen (PR G): --- -*Last updated: 2026-03-23* -*Status: 4 of 5 PRs complete, PR F WIP, PR G planned* +*Last updated: 2026-03-27* +*Status: 3 of 8 PRs complete (A, D, E); PR F split into F1–F4 (one pass each, LICM last); PR G planned*