diff --git a/PR_SPLIT_STATUS.md b/PR_SPLIT_STATUS.md new file mode 100644 index 0000000..9c007d7 --- /dev/null +++ b/PR_SPLIT_STATUS.md @@ -0,0 +1,546 @@ +# PR #12 Split Status — Detailed Plan & TODOs + +## Overview + +This document tracks the progress of splitting PR #12 ("feat: introduce new ir optimizations") into 5 smaller, reviewable PRs. Each PR adds a distinct layer of the optimization pipeline. + +**Original PR #12**: 9,688 additions / 341 deletions across 31 files +**Split Target**: 5 focused PRs, each independently reviewable but stacked for merging + +--- + +## PR Status Summary + +| # | Branch | Title | Status | GitHub PR | Tests | Notes | +|---|--------|-------|--------|-----------|-------|-------| +| A | `pr-a/wasm-float-ops` | Runtime float ops | ✅ Complete | #28 | ✅ Pass | Ready for review | +| D | `pr-d/optimizer-infrastructure` | Optimizer infrastructure | ✅ Complete | #29 | ✅ Pass | Ready for review | +| E | `pr-e/value-optimizations` | Value optimization passes | ✅ Complete | #30 | ✅ Pass | Ready for review | +| F1 | `pr-f1/branch-fold` | Branch condition folding | ⚠️ WIP | #31 | ❌ Needs fix | Split from old PR F | +| F2 | `pr-f2/local-cse` | Local common subexpression elimination | ❌ Not started | — | — | Split from old PR F | +| F3 | `pr-f3/gvn` | Global value numbering | ❌ Not started | — | — | Split from old PR F | +| F4 | `pr-f4/licm` | Loop invariant code motion | ❌ Not started | — | — | Split from old PR F, last | +| G | `pr-g/backend-codegen` | Backend + codegen updates | ❌ Not started | — | — | Planned | + +--- + +## Detailed PR Breakdown + +### PR A #28: Runtime Float Operations ✅ + +**Status**: Complete, ready for review + +**Files**: +- `crates/herkos-runtime/src/ops.rs` (+141 lines) + - `wasm_min_f32`, `wasm_max_f32` — NaN propagation, ±0.0 handling + - `wasm_min_f64`, `wasm_max_f64` — Wasm spec compliance + - `wasm_nearest_f32`, `wasm_nearest_f64` — Banker's rounding without libm +- `crates/herkos-runtime/src/lib.rs` (+2): 
Re-exports +- `.gitignore` (+1): Add `**/*.wasm.rs` + +**Tests**: ✅ All 121 tests pass, clippy clean + +**Key Points**: +- No `std` required, no heap allocation +- Used by `const_prop` pass for compile-time float evaluation +- Zero runtime dependencies + +**Review Checklist**: +- [ ] Verify Wasm spec compliance for min/max/nearest +- [ ] Check NaN and signed-zero handling +- [ ] Confirm no performance regression + +--- + +### PR D #29: Optimizer Infrastructure ✅ + +**Status**: Complete, ready for review + +**Files**: +- `crates/herkos-core/src/optimizer/utils.rs` (+644, new) + - Shared utilities: `terminator_successors`, `build_predecessors` + - Variable traversal: `for_each_use`, `for_each_use_terminator`, `count_uses_of` + - Instruction classification: `instr_dest`, `set_instr_dest`, `is_side_effect_free` + - Variable substitution: `replace_uses_of`, `rewrite_terminator_target` +- `crates/herkos-core/src/optimizer/dead_instrs.rs` (+363, new) + - Dead instruction elimination (post-lowering pass) +- `crates/herkos-core/src/optimizer/empty_blocks.rs` (+333, new) + - Passthrough block elimination +- `crates/herkos-core/src/optimizer/merge_blocks.rs` (+384, new) + - Single-predecessor block merging +- `crates/herkos-core/src/optimizer/dead_blocks.rs` (refactored) + - Now uses shared `terminator_successors` from utils +- `crates/herkos-core/src/optimizer/mod.rs` (updated) + - Adds `optimize_lowered_ir` function + - Registers post-lowering structural passes + +**Tests**: ✅ All 117 tests pass, clippy clean + +**Passes Registered** (post-lowering only): +- `empty_blocks` → `dead_blocks` → `merge_blocks` → `dead_blocks` → `dead_instrs` (2 iterations) + +**Key Points**: +- Foundation for all subsequent optimizer PRs +- Handles new IR instruction types: `MemoryFill`, `MemoryInit`, `DataDrop` +- Runs structural cleanup in iterations until fixed point + +**Review Checklist**: +- [ ] Verify utility functions are correct (especially `build_predecessors`) +- [ ] Check loop 
termination (2 iterations sufficient?) +- [ ] Confirm all new IR instructions handled + +--- + +### PR E #30: Value Optimization Passes ✅ + +**Status**: Complete, ready for review + +**Files**: +- `crates/herkos-core/src/optimizer/const_prop.rs` (+1,309, new) + - Constant folding and propagation + - Uses `herkos_runtime` functions for Wasm-spec-compliant evaluation + - Tracks dataflow of constant values +- `crates/herkos-core/src/optimizer/algebraic.rs` (+782, new) + - Algebraic simplifications (e.g., `x * 1 → x`, `x & x → x`) + - Runs after `const_prop` to operate on known constants +- `crates/herkos-core/src/optimizer/copy_prop.rs` (+1,347, new) + - Backward coalescing and forward substitution + - Registered pre-lowering and post-lowering +- `crates/herkos-core/src/ir/types.rs` (updated) + - Added `PartialEq` to `IrValue` (needed for test assertions) +- `crates/herkos-core/src/optimizer/mod.rs` (updated) + - Registers pre-lowering passes in `optimize_ir` + - Adds `copy_prop` to post-lowering pipeline +- `crates/herkos-core/Cargo.toml` (updated) + - Added `herkos-runtime` dependency (needed for float op evaluation) + +**Tests**: ✅ All 201 tests pass, clippy clean + +**Passes Registered**: +- **Pre-lowering**: `dead_blocks` → `const_prop` → `algebraic` → `copy_prop` +- **Post-lowering**: `copy_prop` + structural passes in iterations + +**Dependencies**: +- ✅ Requires PR A (float ops for const evaluation) +- ✅ Requires PR D (optimizer infrastructure) + +**Key Points**: +- Const prop evaluates Wasm arithmetic using runtime functions +- Pre-lowering const prop simplifies phi nodes before SSA destruction +- Post-lowering copy_prop forwards lower_phis assignments +- IrValue::PartialEq enables test comparisons + +**Review Checklist**: +- [ ] Verify const folding correctness (esp. 
float ops, overflow handling) +- [ ] Check algebraic simplification rules are sound +- [ ] Confirm copy propagation doesn't break SSA invariants +- [ ] Test with real Wasm modules (fibonacci, etc.) + +--- + +### PR F1 #31: Branch Condition Folding ⚠️ WIP + +**Status**: WIP — file exists on current branch, needs pattern match fixes + +**Branch**: `pr-f1/branch-fold` (rename/split from `pr-f/redundancy-loop-passes`) + +**Files**: +- `crates/herkos-core/src/optimizer/branch_fold.rs` (+366, new) + - Simplifies branch conditions: constant branches, always-taken jumps +- `crates/herkos-core/src/optimizer/mod.rs` (updated) + - Registers `branch_fold::eliminate` in `optimize_lowered_ir` pipeline + +**Tests**: ❌ Compile errors — missing pattern match arms + +**Known Issues**: +1. Missing pattern match arms for `MemoryFill`, `MemoryInit`, `DataDrop` in test helpers + - Fix: Add exhaustive arms to all `match instr` blocks in branch_fold tests +2. Tests may reference removed `IrFunction.needs_host` field — verify and remove + +**Passes Registered** (post-lowering, added to iteration loop): +- `branch_fold::eliminate` (after `dead_instrs`) + +**Dependencies**: +- ✅ Requires PR D (optimizer infrastructure + utils) +- ✅ Requires PR E (value optimizations) + +**TODO Before Merge**: +- [ ] Create branch `pr-f1/branch-fold` from PR E base +- [ ] Cherry-pick only `branch_fold.rs` + `mod.rs` changes +- [ ] Fix missing IR instruction pattern matches in branch_fold tests +- [ ] Run `cargo test -p herkos-core --lib` +- [ ] Run `cargo clippy` and `cargo fmt --check` + +--- + +### PR F2: Local Common Subexpression Elimination ❌ Not Started + +**Status**: Not started — code exists in `pr-f/redundancy-loop-passes`, needs isolation + +**Branch**: `pr-f2/local-cse` (to be created from PR F1) + +**Files**: +- `crates/herkos-core/src/optimizer/local_cse.rs` (+575, new) + - Eliminates redundant computations within a single basic block + - Keyed on instruction structure (opcode + operands), no 
cross-block analysis +- `crates/herkos-core/src/optimizer/mod.rs` (updated) + - Registers `local_cse::eliminate` in `optimize_lowered_ir` pipeline + +**Tests**: ❌ Compile errors (same pattern match issues as F1) + +**Known Issues**: +- Same missing `MemoryFill`, `MemoryInit`, `DataDrop` pattern arms as F1 +- Fix: Add exhaustive arms to all `match instr` blocks in local_cse tests + +**Passes Registered** (post-lowering, in iteration loop): +- `local_cse::eliminate` (after structural passes, before GVN) + +**Dependencies**: +- ✅ Requires PR D (optimizer infrastructure + utils) +- ✅ Requires PR E (value optimizations) +- ✅ Requires PR F1 (branch fold reduces branches before CSE runs) + +**TODO Before Merge**: +- [ ] Create branch `pr-f2/local-cse` from PR F1 base +- [ ] Cherry-pick only `local_cse.rs` + `mod.rs` changes +- [ ] Fix missing IR instruction pattern matches in local_cse tests +- [ ] Run `cargo test -p herkos-core --lib` +- [ ] Run `cargo clippy` and `cargo fmt --check` + +--- + +### PR F3: Global Value Numbering ❌ Not Started + +**Status**: Not started — code exists in `pr-f/redundancy-loop-passes`, needs isolation + +**Branch**: `pr-f3/gvn` (to be created from PR F2) + +**Files**: +- `crates/herkos-core/src/optimizer/gvn.rs` (+619, new) + - Value numbering across basic blocks using dominator tree traversal + - Eliminates redundant computations that `local_cse` cannot catch across blocks +- `crates/herkos-core/src/optimizer/mod.rs` (updated) + - Registers `gvn::eliminate` in `optimize_lowered_ir` pipeline + +**Tests**: ❌ Compile errors (same pattern match issues as F1/F2) + +**Known Issues**: +- Same missing `MemoryFill`, `MemoryInit`, `DataDrop` pattern arms +- Dominator tree construction correctness should be carefully verified + +**Passes Registered** (post-lowering, in iteration loop): +- `gvn::eliminate` (after `local_cse`, before `dead_instrs`) + +**Dependencies**: +- ✅ Requires PR D (optimizer infrastructure + utils, including `build_predecessors`) 
+- ✅ Requires PR E (value optimizations) +- ✅ Requires PR F2 (local CSE runs first, GVN handles cross-block residuals) + +**TODO Before Merge**: +- [ ] Create branch `pr-f3/gvn` from PR F2 base +- [ ] Cherry-pick only `gvn.rs` + `mod.rs` changes +- [ ] Fix missing IR instruction pattern matches in gvn tests +- [ ] Verify dominator tree construction correctness +- [ ] Run `cargo test -p herkos-core --lib` +- [ ] Run `cargo clippy` and `cargo fmt --check` + +--- + +### PR F4: Loop Invariant Code Motion ❌ Not Started + +**Status**: Not started — code exists in `pr-f/redundancy-loop-passes`, needs isolation + +**Branch**: `pr-f4/licm` (to be created from PR F3, **last of the F series**) + +**Files**: +- `crates/herkos-core/src/optimizer/licm.rs` (+1,307, new) + - Detects natural loops via back-edges and dominator tree + - Hoists side-effect-free loop-invariant instructions to loop preheaders +- `crates/herkos-core/src/optimizer/mod.rs` (updated) + - Uncomments and registers `licm::eliminate` in `optimize_lowered_ir` pipeline + - Currently commented out: `// licm::eliminate(func);` + +**Tests**: ❌ Compile errors (same pattern match issues + licm is currently commented out) + +**Known Issues**: +- Same missing `MemoryFill`, `MemoryInit`, `DataDrop` pattern arms +- Loop detection may be conservative — verify back-edge identification +- LICM must not hoist instructions with side effects (memory writes, traps) + +**Passes Registered** (post-lowering, end of iteration loop): +- `licm::eliminate` (after GVN and `dead_instrs`, final pass in loop body) + +**Dependencies**: +- ✅ Requires PR D (optimizer infrastructure + utils, `is_side_effect_free`) +- ✅ Requires PR E (value optimizations) +- ✅ Requires PR F1 (branch fold first simplifies loop exit conditions) +- ✅ Requires PR F2 (local CSE) +- ✅ Requires PR F3 (GVN — hoisting is most effective after redundancies removed) + +**TODO Before Merge**: +- [ ] Create branch `pr-f4/licm` from PR F3 base +- [ ] Cherry-pick only 
`licm.rs` + `mod.rs` changes (uncomment licm call) +- [ ] Fix missing IR instruction pattern matches in licm tests +- [ ] Verify loop detection and preheader insertion correctness +- [ ] Check `is_side_effect_free` covers all hoistable instructions +- [ ] Test with loop-heavy Wasm (e.g., array operations) +- [ ] Run `cargo test -p herkos-core --lib` +- [ ] Run `cargo clippy` and `cargo fmt --check` + +--- + +### PR G: Backend + Codegen Updates ❌ NOT STARTED + +**Planned Files**: +- `crates/herkos/Cargo.toml` (+1) + - Add `herkos-runtime` dependency to transpiler crate +- `crates/herkos/src/backend/mod.rs` (+13) + - Add trait methods for float operations +- `crates/herkos/src/backend/safe.rs` (+54) + - Implement `wasm_min_f32`, `wasm_max_f32`, etc. in `SafeBackend` +- `crates/herkos/src/codegen/function.rs` (+60) + - Update function code generation +- `crates/herkos/src/codegen/instruction.rs` (+25, -1) + - Add float instruction code generation +- `crates/herkos/src/codegen/mod.rs` (+7, -4) +- `crates/herkos/src/codegen/module.rs` (+4, -1) +- `crates/herkos/tests/e2e.rs` (-3) + - Test cleanup +- `Cargo.lock`: Auto-updated + +**Purpose**: +- Wire runtime float ops into codegen path +- Ensure transpiled Rust code calls the correct runtime functions +- End-to-end integration: Wasm → IR → optimized IR → Rust code + +**Dependencies**: +- Requires PR A (runtime float ops must exist) +- Requires PR D, E, F (optimizers must work) +- No direct code dependency, but all PRs should be merged first + +**TODO**: +- [ ] Extract codegen changes from PR #12 +- [ ] Map Wasm float instructions to backend method calls +- [ ] Update `SafeBackend` implementations +- [ ] Run E2E tests: `cargo test -p herkos-tests` +- [ ] Verify generated Rust code compiles +- [ ] Check performance with benchmarks: `cargo bench -p herkos-tests` + +--- + +## Merge Order & Strategy + +### Strict Linear Order (Recommended) +``` +main + ↓ merge PR A + ├─ main + A + ↓ merge PR D + ├─ main + A + D + ↓ merge PR E 
+ ├─ main + A + D + E + ↓ merge PR F1 (branch fold) + ├─ main + A + D + E + F1 + ↓ merge PR F2 (local CSE) + ├─ main + A + D + E + F1 + F2 + ↓ merge PR F3 (GVN) + ├─ main + A + D + E + F1 + F2 + F3 + ↓ merge PR F4 (LICM — last) + ├─ main + A + D + E + F1 + F2 + F3 + F4 + ↓ merge PR G + └─ main + A + D + E + F1–F4 + G (complete) +``` + +### Rationale +- A is independent (pure runtime addition) +- D is infrastructure foundation +- E depends on A (float ops) + D (utilities) +- F1 depends on E — branch fold simplifies control flow for later passes +- F2 depends on F1 — local CSE runs after branches are simplified +- F3 depends on F2 — GVN extends CSE globally across blocks +- F4 depends on F3 — LICM is most effective after redundancies are removed (last pass) +- G depends on A (must have runtime ops to codegen) + +--- + +## Current Blockers + +### PR F1–F4 Compilation Errors + +**Error Pattern**: +``` +error[E0004]: non-exhaustive patterns: `MemoryFill`, `MemoryInit`, `DataDrop` +``` + +**Root Cause**: PR #12 was created before `MemoryFill`, `MemoryInit`, `DataDrop` were added to `IrInstr` enum. + +**Solution**: Add pattern match arms to all functions in optimizer files: + +```rust +// Example fix pattern +IrInstr::MemoryFill { dst, val, len } => { + f(*dst); + f(*val); + f(*len); +} +IrInstr::MemoryInit { dst, src_offset, len, .. } => { + f(*dst); + f(*src_offset); + f(*len); +} +IrInstr::DataDrop { .. } => {} +``` + +**Affected files (fix needed in each isolated PR)**: +- `branch_fold.rs` test code → fix in PR F1 +- `local_cse.rs` test code → fix in PR F2 +- `gvn.rs` test code → fix in PR F3 +- `licm.rs` test code → fix in PR F4 + +**Action Item**: Apply pattern match fixes per file when working on each isolated PR branch. 
+ +--- + +## Testing Checklist + +### Unit Tests (Current) +- [x] PR A: `cargo test -p herkos-runtime` ✅ 121 pass +- [x] PR D: `cargo test -p herkos-core --lib` ✅ 117 pass +- [x] PR E: `cargo test -p herkos-core --lib` ✅ 201 pass +- [ ] PR F1: `cargo test -p herkos-core --lib` ❌ Needs fixes +- [ ] PR F2: `cargo test -p herkos-core --lib` ❌ Needs fixes +- [ ] PR F3: `cargo test -p herkos-core --lib` ❌ Needs fixes +- [ ] PR F4: `cargo test -p herkos-core --lib` ❌ Needs fixes +- [ ] PR G: `cargo test` (all crates) ⏳ Not started + +### Integration Tests (After Merge) +- [ ] `cargo test -p herkos-tests` — E2E Wasm → Rust +- [ ] `cargo bench -p herkos-tests` — Fibonacci benchmarks + +### Code Quality (All PRs) +- [x] PR A: `cargo clippy` ✅ clean +- [x] PR D: `cargo clippy` ✅ clean +- [x] PR E: `cargo clippy` ✅ clean +- [ ] PR F1: `cargo clippy` ⏳ Blocked on compile +- [ ] PR F2: `cargo clippy` ⏳ Not started +- [ ] PR F3: `cargo clippy` ⏳ Not started +- [ ] PR F4: `cargo clippy` ⏳ Not started +- [ ] PR G: `cargo clippy` ⏳ Not started + +### Format Check (All PRs) +- [ ] `cargo fmt --check` on all PRs + +--- + +## Known Limitations & Future Work + +### Not Included (Future Enhancements) +- **Verified backend** — Currently only safe backend implemented +- **Hybrid backend** — Mix of safe and verified code +- **Temporal isolation** — Future feature +- **Contract-based verification** — Future feature +- **`--max-pages` CLI effect** — Not yet wired through transpiler +- **WASI traits** — Standard import traits not yet implemented + +### Performance Notes +- Two-pass structural cleanup may iterate more than necessary +- GVN dominator tree construction could be optimized +- LICM may be conservative in loop detection + +### Documentation +- Consider adding optimizer pass pipeline diagram to SPECIFICATION.md +- Document dataflow analysis assumptions (SSA form requirements) + +--- + +## Quick Reference: File Locations + +``` +Optimizer Passes: +├── crates/herkos-core/src/optimizer/ 
+│ ├── utils.rs [shared utilities, PR D] +│ ├── dead_blocks.rs [refactored, PR D] +│ ├── dead_instrs.rs [post-lowering, PR D] +│ ├── empty_blocks.rs [post-lowering, PR D] +│ ├── merge_blocks.rs [post-lowering, PR D] +│ ├── const_prop.rs [pre-lowering, PR E] +│ ├── algebraic.rs [pre-lowering, PR E] +│ ├── copy_prop.rs [pre + post, PR E] +│ ├── branch_fold.rs [post-lowering, PR F1] +│ ├── local_cse.rs [post-lowering, PR F2] +│ ├── gvn.rs [post-lowering, PR F3] +│ ├── licm.rs [post-lowering, PR F4] +│ └── mod.rs [pipeline coordination] + +Runtime: +├── crates/herkos-runtime/src/ +│ ├── ops.rs [wasm_min/max/nearest, PR A] +│ └── lib.rs [re-exports, PR A] + +Codegen (PR G): +├── crates/herkos/src/ +│ ├── backend/mod.rs [trait updates] +│ ├── backend/safe.rs [float op implementations] +│ └── codegen/ [instruction generation] +``` + +--- + +## Next Steps + +### Immediate (Create PR F1 — branch fold) +1. [ ] Create branch `pr-f1/branch-fold` from PR E tip +2. [ ] Cherry-pick only `branch_fold.rs` + `mod.rs` (branch_fold registration) from old `pr-f/redundancy-loop-passes` +3. [ ] Fix missing `MemoryFill`/`MemoryInit`/`DataDrop` pattern arms in `branch_fold.rs` tests +4. [ ] Run `cargo test -p herkos-core --lib` — confirm all pass +5. [ ] Run `cargo clippy` and `cargo fmt --check` +6. [ ] Push and open PR against PR E branch (or main if E is merged) + +### Next (PR F2 — local CSE) +1. [ ] Create branch `pr-f2/local-cse` from PR F1 tip +2. [ ] Cherry-pick only `local_cse.rs` + `mod.rs` changes +3. [ ] Fix pattern match arms in `local_cse.rs` tests +4. [ ] Run tests, clippy, fmt +5. [ ] Open PR + +### Then (PR F3 — GVN) +1. [ ] Create branch `pr-f3/gvn` from PR F2 tip +2. [ ] Cherry-pick only `gvn.rs` + `mod.rs` changes +3. [ ] Fix pattern match arms in `gvn.rs` tests +4. [ ] Verify dominator tree construction +5. [ ] Run tests, clippy, fmt +6. [ ] Open PR + +### Then (PR F4 — LICM, last) +1. [ ] Create branch `pr-f4/licm` from PR F3 tip +2. 
[ ] Cherry-pick only `licm.rs` + `mod.rs` changes (uncomment `licm::eliminate`) +3. [ ] Fix pattern match arms in `licm.rs` tests +4. [ ] Verify loop detection and preheader insertion +5. [ ] Run tests, clippy, fmt +6. [ ] Open PR + +### Short Term (Prepare PR G) +1. [ ] Extract codegen changes from PR #12 +2. [ ] Create `pr-g/backend-codegen` branch +3. [ ] Implement and test +4. [ ] Verify all E2E tests pass + +### Post-Merge (Validation) +1. [ ] Run full integration suite: `cargo test` +2. [ ] Run benchmarks: `cargo bench` +3. [ ] Profile optimization impact (e.g., % code size reduction) +4. [ ] Document lessons learned in SPECIFICATION.md + +--- + +## Contact & References + +- **Original PR**: GitHub #12 +- **Plan file**: `/home/vscode/.claude/plans/streamed-hatching-castle.md` +- **Branch tracking**: See PRs #28–#31 for current status +- **CLAUDE.md**: Project conventions and architecture + +--- + +*Last updated: 2026-03-27* +*Status: 3 of 8 PRs complete (A, D, E); PR F split into F1–F4 (one pass each, LICM last); PR G planned* diff --git a/crates/herkos-core/src/ir/types.rs b/crates/herkos-core/src/ir/types.rs index 09407bc..4bcda3d 100644 --- a/crates/herkos-core/src/ir/types.rs +++ b/crates/herkos-core/src/ir/types.rs @@ -447,7 +447,7 @@ impl fmt::Display for IrValue { } /// Binary operations. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum BinOp { // i32 operations I32Add, @@ -543,7 +543,7 @@ pub enum BinOp { } /// Unary operations. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum UnOp { // i32 unary I32Clz, // Count leading zeros diff --git a/crates/herkos-core/src/optimizer/branch_fold.rs b/crates/herkos-core/src/optimizer/branch_fold.rs new file mode 100644 index 0000000..3a7f978 --- /dev/null +++ b/crates/herkos-core/src/optimizer/branch_fold.rs @@ -0,0 +1,365 @@ +//! Branch condition folding. +//! +//! 
Simplifies `BranchIf` terminators by looking at the instruction that +//! defines the condition variable: +//! +//! - `Eqz(x)` as condition → swap branch targets, use `x` directly +//! - `Ne(x, 0)` as condition → use `x` directly +//! - `Eq(x, 0)` as condition → swap branch targets, use `x` directly +//! +//! After substitution, the defining instruction becomes dead (single use was +//! the branch) and is cleaned up by `dead_instrs`. + +use super::utils::{build_global_use_count, instr_dest}; +use crate::ir::{BinOp, IrFunction, IrInstr, IrTerminator, IrValue, UnOp, VarId}; +use std::collections::HashMap; + +pub fn eliminate(func: &mut IrFunction) { + loop { + let global_uses = build_global_use_count(func); + if !fold_one(func, &global_uses) { + break; + } + } +} + +/// Attempt a single branch fold across the function. Returns `true` if a +/// change was made. +fn fold_one(func: &mut IrFunction, global_uses: &HashMap) -> bool { + // Build a map of VarId → defining instruction info. + // We only care about single-use vars defined by Eqz, Ne(x,0), or Eq(x,0). + let mut var_defs: HashMap = HashMap::new(); + + // Also build a global constant map for checking if an operand is zero. + let global_consts = build_global_const_map(func); + + for block in &func.blocks { + let mut local_consts = global_consts.clone(); + for instr in &block.instructions { + if let IrInstr::Const { dest, value } = instr { + local_consts.insert(*dest, *value); + } + + if let Some(dest) = instr_dest(instr) { + match instr { + IrInstr::UnOp { + op: UnOp::I32Eqz | UnOp::I64Eqz, + operand, + .. + } => { + var_defs.insert(dest, VarDef::Eqz(*operand)); + } + IrInstr::BinOp { + op: BinOp::I32Ne | BinOp::I64Ne, + lhs, + rhs, + .. + } => { + if is_zero(rhs, &local_consts) { + var_defs.insert(dest, VarDef::NeZero(*lhs)); + } else if is_zero(lhs, &local_consts) { + var_defs.insert(dest, VarDef::NeZero(*rhs)); + } + } + IrInstr::BinOp { + op: BinOp::I32Eq | BinOp::I64Eq, + lhs, + rhs, + .. 
+ } => { + if is_zero(rhs, &local_consts) { + var_defs.insert(dest, VarDef::EqZero(*lhs)); + } else if is_zero(lhs, &local_consts) { + var_defs.insert(dest, VarDef::EqZero(*rhs)); + } + } + _ => {} + } + } + } + } + + // Now scan terminators for BranchIf with a foldable condition. + for block in &mut func.blocks { + let condition = match &block.terminator { + IrTerminator::BranchIf { condition, .. } => *condition, + _ => continue, + }; + + // Only fold if the condition has exactly one use (the BranchIf). + if global_uses.get(&condition).copied().unwrap_or(0) != 1 { + continue; + } + + let def = match var_defs.get(&condition) { + Some(d) => d, + None => continue, + }; + + match def { + VarDef::Eqz(inner) | VarDef::EqZero(inner) => { + // eqz(x) != 0 ≡ x == 0, so swap targets and use x + if let IrTerminator::BranchIf { + condition: cond, + if_true, + if_false, + } = &mut block.terminator + { + *cond = *inner; + std::mem::swap(if_true, if_false); + } + return true; + } + VarDef::NeZero(inner) => { + // ne(x, 0) != 0 ≡ x != 0, so just use x + if let IrTerminator::BranchIf { + condition: cond, .. 
+ } = &mut block.terminator + { + *cond = *inner; + } + return true; + } + } + } + + false +} + +#[derive(Clone, Copy)] +enum VarDef { + Eqz(VarId), + NeZero(VarId), + EqZero(VarId), +} + +fn is_zero(var: &VarId, consts: &HashMap) -> bool { + matches!( + consts.get(var), + Some(IrValue::I32(0)) | Some(IrValue::I64(0)) + ) +} + +fn build_global_const_map(func: &IrFunction) -> HashMap { + let mut total_defs: HashMap = HashMap::new(); + let mut const_defs: HashMap = HashMap::new(); + + for block in &func.blocks { + for instr in &block.instructions { + if let Some(dest) = instr_dest(instr) { + *total_defs.entry(dest).or_insert(0) += 1; + if let IrInstr::Const { dest, value } = instr { + const_defs.insert(*dest, *value); + } + } + } + } + + const_defs + .into_iter() + .filter(|(v, _)| total_defs.get(v).copied().unwrap_or(0) == 1) + .collect() +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{BlockId, IrBlock, TypeIdx}; + + fn make_func(blocks: Vec) -> IrFunction { + IrFunction { + params: vec![], + locals: vec![], + blocks, + entry_block: BlockId(0), + return_type: None, + type_idx: TypeIdx::new(0), + } + } + + #[test] + fn eqz_swaps_targets() { + // v1 = Eqz(v0); BranchIf(v1, B1, B2) → BranchIf(v0, B2, B1) + let mut func = make_func(vec![IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::UnOp { + dest: VarId(1), + op: UnOp::I32Eqz, + operand: VarId(0), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(1), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }]); + eliminate(&mut func); + match &func.blocks[0].terminator { + IrTerminator::BranchIf { + condition, + if_true, + if_false, + } => { + assert_eq!(*condition, VarId(0)); + assert_eq!(*if_true, BlockId(2), "targets should be swapped"); + assert_eq!(*if_false, BlockId(1)); + } + other => panic!("expected BranchIf, got {other:?}"), + } + } + + #[test] + fn ne_zero_simplifies() { + // v1 = 0; v2 
= Ne(v0, v1); BranchIf(v2, B1, B2) → BranchIf(v0, B1, B2) + let mut func = make_func(vec![IrBlock { + id: BlockId(0), + instructions: vec![ + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(0), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Ne, + lhs: VarId(0), + rhs: VarId(1), + }, + ], + terminator: IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }]); + eliminate(&mut func); + match &func.blocks[0].terminator { + IrTerminator::BranchIf { + condition, + if_true, + if_false, + } => { + assert_eq!(*condition, VarId(0)); + assert_eq!(*if_true, BlockId(1), "targets should NOT be swapped"); + assert_eq!(*if_false, BlockId(2)); + } + other => panic!("expected BranchIf, got {other:?}"), + } + } + + #[test] + fn eq_zero_swaps() { + // v1 = 0; v2 = Eq(v0, v1); BranchIf(v2, B1, B2) → BranchIf(v0, B2, B1) + let mut func = make_func(vec![IrBlock { + id: BlockId(0), + instructions: vec![ + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(0), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Eq, + lhs: VarId(0), + rhs: VarId(1), + }, + ], + terminator: IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }]); + eliminate(&mut func); + match &func.blocks[0].terminator { + IrTerminator::BranchIf { + condition, + if_true, + if_false, + } => { + assert_eq!(*condition, VarId(0)); + assert_eq!(*if_true, BlockId(2), "targets should be swapped"); + assert_eq!(*if_false, BlockId(1)); + } + other => panic!("expected BranchIf, got {other:?}"), + } + } + + #[test] + fn multi_use_not_folded() { + // v1 = Eqz(v0); use(v1) elsewhere → don't fold + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::UnOp { + dest: VarId(1), + op: UnOp::I32Eqz, + operand: VarId(0), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(1), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }, + IrBlock { + id: BlockId(1), + 
instructions: vec![], + terminator: IrTerminator::Return { + value: Some(VarId(1)), // second use of v1 + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + eliminate(&mut func); + // Should NOT fold — v1 has 2 uses + match &func.blocks[0].terminator { + IrTerminator::BranchIf { condition, .. } => { + assert_eq!(*condition, VarId(1), "should not have been folded"); + } + other => panic!("expected BranchIf, got {other:?}"), + } + } + + #[test] + fn cross_block_zero_const() { + // B0: v1 = 0; Jump(B1) + // B1: v2 = Ne(v0, v1); BranchIf(v2, B2, B3) → BranchIf(v0, B2, B3) + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(0), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Ne, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + ]); + eliminate(&mut func); + match &func.blocks[1].terminator { + IrTerminator::BranchIf { condition, .. } => { + assert_eq!(*condition, VarId(0)); + } + other => panic!("expected BranchIf, got {other:?}"), + } + } +} diff --git a/crates/herkos-core/src/optimizer/gvn.rs b/crates/herkos-core/src/optimizer/gvn.rs new file mode 100644 index 0000000..33a3efd --- /dev/null +++ b/crates/herkos-core/src/optimizer/gvn.rs @@ -0,0 +1,699 @@ +//! Global value numbering (GVN) — cross-block CSE using the dominator tree. +//! +//! Extends block-local CSE ([`super::local_cse`]) to work across basic blocks. +//! If block A dominates block B (every path to B passes through A), then any +//! pure computation defined in A with the same value key as one in B can be +//! reused in B instead of recomputing. +//! +//! ## Algorithm +//! +//! 1. 
Compute the immediate dominator of each block (Cooper/Harvey/Kennedy
+//!    iterative algorithm) to build the dominator tree.
+//! 2. Walk the dominator tree in preorder using a scoped value-number table.
+//!    On entry to a block, push a new scope; on exit, pop it.
+//! 3. For each pure instruction (`Const`, `BinOp`, `UnOp`) in the current
+//!    block, compute a value key. If the key already exists in any enclosing
+//!    scope (meaning it was computed in a dominating block), record a
+//!    replacement: `dest → first_var`. Otherwise insert the key into the
+//!    current scope.
+//! 4. After the walk, rewrite all recorded destinations to
+//!    `Assign { dest, src: first_var }` and let copy-propagation clean up.
+//!
+//! **Only pure instructions are eligible.** Loads, calls, and memory ops are
+//! never deduplicated (they may trap or have observable side effects).
+
+use super::utils::{build_predecessors, instr_dest, prune_dead_locals, terminator_successors};
+use crate::ir::{BinOp, BlockId, IrFunction, IrInstr, IrValue, UnOp, VarId};
+use std::collections::{HashMap, HashSet};
+
+// ── Value key ──────────────────────────────────────────────────────────────
+
+/// Hashable representation of a pure computation for deduplication.
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+enum ValueKey {
+    Const(ConstKey),
+    BinOp { op: BinOp, lhs: VarId, rhs: VarId },
+    UnOp { op: UnOp, operand: VarId },
+}
+
+/// Bit-level constant key that implements `Eq`/`Hash` correctly for floats.
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+enum ConstKey {
+    I32(i32),
+    I64(i64),
+    // Floats are stored as raw bits: `f32`/`f64` don't implement `Eq`/`Hash`
+    // (NaN != NaN), but bitwise identity is exactly what dedup needs.
+    F32(u32),
+    F64(u64),
+}
+
+// FIX: the source type parameter was lost in the patch (`impl From for
+// ConstKey` does not parse); restored to `From<IrValue>`.
+impl From<IrValue> for ConstKey {
+    fn from(v: IrValue) -> Self {
+        match v {
+            IrValue::I32(x) => ConstKey::I32(x),
+            IrValue::I64(x) => ConstKey::I64(x),
+            IrValue::F32(x) => ConstKey::F32(x.to_bits()),
+            IrValue::F64(x) => ConstKey::F64(x.to_bits()),
+        }
+    }
+}
+
+/// `true` for operators where `op(a, b) == op(b, a)`, so the value key can
+/// canonicalise operand order and catch swapped-operand duplicates.
+fn is_commutative(op: &BinOp) -> bool {
+    matches!(
+        op,
+        BinOp::I32Add
+            | BinOp::I32Mul
+            | BinOp::I32And
+            | BinOp::I32Or
+            | BinOp::I32Xor
+            | BinOp::I32Eq
+            | BinOp::I32Ne
+            | BinOp::I64Add
+            | BinOp::I64Mul
+            | BinOp::I64And
+            | BinOp::I64Or
+            | BinOp::I64Xor
+            | BinOp::I64Eq
+            | BinOp::I64Ne
+            | BinOp::F32Add
+            | BinOp::F32Mul
+            | BinOp::F32Eq
+            | BinOp::F32Ne
+            | BinOp::F64Add
+            | BinOp::F64Mul
+            | BinOp::F64Eq
+            | BinOp::F64Ne
+    )
+}
+
+/// Build the canonical `ValueKey` for a binary op, sorting the operands of
+/// commutative operators by `VarId` so `a+b` and `b+a` get the same key.
+fn binop_key(op: BinOp, lhs: VarId, rhs: VarId) -> ValueKey {
+    let (lhs, rhs) = if is_commutative(&op) && lhs.0 > rhs.0 {
+        (rhs, lhs)
+    } else {
+        (lhs, rhs)
+    };
+    ValueKey::BinOp { op, lhs, rhs }
+}
+
+// ── Multi-definition detection ─────────────────────────────────────────────
+
+/// Build the set of variables defined more than once across the function.
+///
+/// After phi lowering the code is no longer in strict SSA form: loop phi
+/// variables receive an initial assignment in the pre-loop block and a
+/// back-edge update at the end of each iteration. These variables carry
+/// different values at different program points, so any BinOp/UnOp that uses
+/// them cannot be safely hoisted or deduplicated across blocks.
+///
+/// `Const` instructions are always safe (they have no operands).
+// FIX: the generic parameters on the signatures/decls below were stripped in
+// the patch (e.g. `-> HashSet {`, `HashMap =`); restored so the file compiles.
+fn build_multi_def_vars(func: &IrFunction) -> HashSet<VarId> {
+    let mut def_count: HashMap<VarId, usize> = HashMap::new();
+    for block in &func.blocks {
+        for instr in &block.instructions {
+            if let Some(dest) = instr_dest(instr) {
+                *def_count.entry(dest).or_insert(0) += 1;
+            }
+        }
+    }
+    def_count
+        .into_iter()
+        .filter(|&(_, count)| count > 1)
+        .map(|(v, _)| v)
+        .collect()
+}
+
+// ── Dominator tree ─────────────────────────────────────────────────────────
+
+/// Compute the reverse-postorder traversal of the CFG starting from `entry`.
+fn compute_rpo(func: &IrFunction) -> Vec<BlockId> {
+    let block_idx: HashMap<BlockId, usize> = func
+        .blocks
+        .iter()
+        .enumerate()
+        .map(|(i, b)| (b.id, i))
+        .collect();
+
+    let mut visited = vec![false; func.blocks.len()];
+    let mut postorder = Vec::with_capacity(func.blocks.len());
+
+    dfs_postorder(
+        func,
+        func.entry_block,
+        &block_idx,
+        &mut visited,
+        &mut postorder,
+    );
+
+    postorder.reverse();
+    postorder
+}
+
+/// Depth-first walk that appends each block after all its successors
+/// (postorder). Blocks unreachable from the entry are never visited.
+fn dfs_postorder(
+    func: &IrFunction,
+    block_id: BlockId,
+    block_idx: &HashMap<BlockId, usize>,
+    visited: &mut Vec<bool>,
+    postorder: &mut Vec<BlockId>,
+) {
+    let idx = match block_idx.get(&block_id) {
+        Some(&i) => i,
+        None => return,
+    };
+    if visited[idx] {
+        return;
+    }
+    visited[idx] = true;
+
+    for succ in terminator_successors(&func.blocks[idx].terminator) {
+        dfs_postorder(func, succ, block_idx, visited, postorder);
+    }
+    postorder.push(block_id);
+}
+
+/// Compute the immediate dominator of each block using Cooper/Harvey/Kennedy.
+///
+/// Returns `idom[b] = immediate dominator of b`, with `idom[entry] = entry`.
+fn compute_idoms(func: &IrFunction) -> HashMap<BlockId, BlockId> {
+    let rpo = compute_rpo(func);
+    // rpo_num[b] = index in RPO order (entry = 0, smallest index = processed first)
+    let rpo_num: HashMap<BlockId, usize> = rpo.iter().enumerate().map(|(i, &b)| (b, i)).collect();
+
+    let preds = build_predecessors(func);
+    let entry = func.entry_block;
+
+    let mut idom: HashMap<BlockId, BlockId> = HashMap::new();
+    idom.insert(entry, entry);
+
+    let mut changed = true;
+    while changed {
+        changed = false;
+        // Process blocks in RPO order, skipping the entry.
+        for &b in rpo.iter().skip(1) {
+            let block_preds = &preds[&b];
+
+            // Start with the first predecessor that already has an idom assigned.
+            let mut new_idom = match block_preds
+                .iter()
+                .filter(|&&p| idom.contains_key(&p))
+                .min_by_key(|&&p| rpo_num[&p])
+            {
+                Some(&p) => p,
+                None => continue, // unreachable block — skip
+            };
+
+            // Intersect (walk up dom tree) with all other processed predecessors.
+            for &p in block_preds {
+                if p != new_idom && idom.contains_key(&p) {
+                    new_idom = intersect(p, new_idom, &idom, &rpo_num);
+                }
+            }
+
+            if idom.get(&b) != Some(&new_idom) {
+                idom.insert(b, new_idom);
+                changed = true;
+            }
+        }
+    }
+
+    idom
+}
+
+/// Walk up both fingers until they meet — the standard Cooper intersect.
+fn intersect(
+    mut a: BlockId,
+    mut b: BlockId,
+    idom: &HashMap<BlockId, BlockId>,
+    rpo_num: &HashMap<BlockId, usize>,
+) -> BlockId {
+    while a != b {
+        while rpo_num[&a] > rpo_num[&b] {
+            a = idom[&a];
+        }
+        while rpo_num[&b] > rpo_num[&a] {
+            b = idom[&b];
+        }
+    }
+    a
+}
+
+/// Build dominator-tree children from the `idom` map.
+fn build_dom_children(
+    idom: &HashMap<BlockId, BlockId>,
+    entry: BlockId,
+) -> HashMap<BlockId, Vec<BlockId>> {
+    let mut children: HashMap<BlockId, Vec<BlockId>> = HashMap::new();
+    for (&b, &d) in idom {
+        if b != entry {
+            children.entry(d).or_default().push(b);
+        }
+    }
+    // Sort children for deterministic output.
+    for v in children.values_mut() {
+        v.sort_unstable_by_key(|id| id.0);
+    }
+    children
+}
+
+// ── GVN walk ───────────────────────────────────────────────────────────────
+
+/// Recursively walk the dominator tree in preorder.
+///
+/// `value_map` is a flat map that acts as a scoped table: on entry we insert
+/// new keys (recording them in `frame_keys`), on exit we remove them, restoring
+/// the parent scope. Any key already present in `value_map` when we visit a
+/// block was computed in a dominating block — safe to reuse.
+// FIX: all map/set type parameters below were stripped in the patch; restored.
+fn collect_replacements(
+    func: &IrFunction,
+    block_id: BlockId,
+    dom_children: &HashMap<BlockId, Vec<BlockId>>,
+    block_idx: &HashMap<BlockId, usize>,
+    multi_def_vars: &HashSet<VarId>,
+    value_map: &mut HashMap<ValueKey, VarId>,
+    replacements: &mut HashMap<VarId, VarId>,
+) {
+    let idx = match block_idx.get(&block_id) {
+        Some(&i) => i,
+        None => return,
+    };
+
+    let mut frame_keys: Vec<ValueKey> = Vec::new();
+
+    for instr in &func.blocks[idx].instructions {
+        match instr {
+            IrInstr::Const { dest, value } => {
+                // A multiply-defined dest (loop phi var) must be skipped
+                // entirely: adding it to replacements would replace ALL of
+                // its definitions with Assign(first), clobbering back-edge
+                // updates; inserting it into value_map would let dominated
+                // blocks wrongly reuse a value that changes each iteration.
+                if multi_def_vars.contains(dest) {
+                    continue;
+                }
+                let key = ValueKey::Const(ConstKey::from(*value));
+                if let Some(&first) = value_map.get(&key) {
+                    replacements.insert(*dest, first);
+                } else {
+                    value_map.insert(key.clone(), *dest);
+                    frame_keys.push(key);
+                }
+            }
+
+            IrInstr::BinOp {
+                dest, op, lhs, rhs, ..
+            } => {
+                // Skip if dest is multiply-defined (same reason as Const).
+                // Also skip if any operand is multiply-defined: a loop phi
+                // var carries different values per iteration, so the same
+                // BinOp in two dominated blocks can produce different results.
+                if multi_def_vars.contains(dest)
+                    || multi_def_vars.contains(lhs)
+                    || multi_def_vars.contains(rhs)
+                {
+                    continue;
+                }
+                let key = binop_key(*op, *lhs, *rhs);
+                if let Some(&first) = value_map.get(&key) {
+                    replacements.insert(*dest, first);
+                } else {
+                    value_map.insert(key.clone(), *dest);
+                    frame_keys.push(key);
+                }
+            }
+
+            IrInstr::UnOp { dest, op, operand } => {
+                if multi_def_vars.contains(dest) || multi_def_vars.contains(operand) {
+                    continue;
+                }
+                let key = ValueKey::UnOp {
+                    op: *op,
+                    operand: *operand,
+                };
+                if let Some(&first) = value_map.get(&key) {
+                    replacements.insert(*dest, first);
+                } else {
+                    value_map.insert(key.clone(), *dest);
+                    frame_keys.push(key);
+                }
+            }
+
+            _ => {}
+        }
+    }
+
+    // Recurse into dominated children.
+    if let Some(children) = dom_children.get(&block_id) {
+        for &child in children {
+            collect_replacements(
+                func,
+                child,
+                dom_children,
+                block_idx,
+                multi_def_vars,
+                value_map,
+                replacements,
+            );
+        }
+    }
+
+    // Pop this block's scope.
+    for key in frame_keys {
+        value_map.remove(&key);
+    }
+}
+
+// ── Pass entry point ───────────────────────────────────────────────────────
+
+/// Eliminates common subexpressions across basic blocks using the dominator tree.
+// FIX: restored the stripped type parameters on `block_idx`, `value_map`,
+// `replacements`, and `make_func(blocks: Vec<IrBlock>)`.
+pub fn eliminate(func: &mut IrFunction) {
+    if func.blocks.len() < 2 {
+        return; // nothing to do for single-block functions (local_cse covers those)
+    }
+
+    let idom = compute_idoms(func);
+    let dom_children = build_dom_children(&idom, func.entry_block);
+    let block_idx: HashMap<BlockId, usize> = func
+        .blocks
+        .iter()
+        .enumerate()
+        .map(|(i, b)| (b.id, i))
+        .collect();
+
+    let multi_def_vars = build_multi_def_vars(func);
+    let mut value_map: HashMap<ValueKey, VarId> = HashMap::new();
+    let mut replacements: HashMap<VarId, VarId> = HashMap::new();
+
+    collect_replacements(
+        func,
+        func.entry_block,
+        &dom_children,
+        &block_idx,
+        &multi_def_vars,
+        &mut value_map,
+        &mut replacements,
+    );
+
+    if replacements.is_empty() {
+        return;
+    }
+
+    // Rewrite every replaced destination to a copy of the dominating value;
+    // copy-propagation (a later pass) removes the Assigns.
+    for block in &mut func.blocks {
+        for instr in &mut block.instructions {
+            if let Some(dest) = instr_dest(instr) {
+                if let Some(&src) = replacements.get(&dest) {
+                    *instr = IrInstr::Assign { dest, src };
+                }
+            }
+        }
+    }
+
+    prune_dead_locals(func);
+}
+
+// ── Tests ──────────────────────────────────────────────────────────────────
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ir::{IrBlock, IrTerminator, IrValue, TypeIdx, WasmType};
+
+    fn make_func(blocks: Vec<IrBlock>) -> IrFunction {
+        IrFunction {
+            params: vec![],
+            locals: vec![],
+            blocks,
+            entry_block: BlockId(0),
+            return_type: None,
+            type_idx: TypeIdx::new(0),
+        }
+    }
+
+    /// Entry (B0) → B1: const duplicated across the edge.
+    /// B0 dominates B1, so the duplicate in B1 should be replaced with Assign.
+ #[test] + fn cross_block_const_deduplication() { + let b0 = IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }; + let b1 = IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(42), + }], + terminator: IrTerminator::Return { + value: Some(VarId(1)), + }, + }; + let mut func = make_func(vec![b0, b1]); + func.locals = vec![(VarId(0), WasmType::I32), (VarId(1), WasmType::I32)]; + + eliminate(&mut func); + + assert!( + matches!( + func.blocks[0].instructions[0], + IrInstr::Const { dest: VarId(0), .. } + ), + "first definition should stay as Const" + ); + assert!( + matches!( + func.blocks[1].instructions[0], + IrInstr::Assign { + dest: VarId(1), + src: VarId(0) + } + ), + "dominated duplicate should become Assign" + ); + } + + /// Entry (B0) → B1: BinOp duplicated across the edge. + #[test] + fn cross_block_binop_deduplication() { + let b0 = IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }; + let b1 = IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::Return { + value: Some(VarId(3)), + }, + }; + let mut func = make_func(vec![b0, b1]); + func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; + + eliminate(&mut func); + + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::BinOp { .. } + )); + assert!( + matches!( + func.blocks[1].instructions[0], + IrInstr::Assign { + dest: VarId(3), + src: VarId(2) + } + ), + "dominated duplicate BinOp should become Assign" + ); + } + + /// B0 branches to B1 and B2 (diamond). 
B1 and B2 don't dominate each other, + /// so a const in B1 should NOT eliminate the same const in B2. + #[test] + fn sibling_blocks_not_deduplicated() { + // B0 → B1, B0 → B2, both converge to B3 + let b0 = IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::BranchIf { + condition: VarId(0), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }; + let b1 = IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(7), + }], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }; + let b2 = IrBlock { + id: BlockId(2), + instructions: vec![IrInstr::Const { + dest: VarId(2), + value: IrValue::I32(7), + }], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }; + let b3 = IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }; + let mut func = make_func(vec![b0, b1, b2, b3]); + func.locals = vec![(VarId(1), WasmType::I32), (VarId(2), WasmType::I32)]; + + eliminate(&mut func); + + // Both consts should remain — neither block dominates the other. + assert!( + matches!( + func.blocks[1].instructions[0], + IrInstr::Const { dest: VarId(1), .. } + ), + "const in B1 must not be eliminated" + ); + assert!( + matches!( + func.blocks[2].instructions[0], + IrInstr::Const { dest: VarId(2), .. } + ), + "const in B2 must not be eliminated" + ); + } + + /// A const defined in B0 (entry) should be reused in a deeply dominated block. 
+ #[test] + fn deep_domination_chain() { + // B0 → B1 → B2: const defined in B0, duplicated in B2 + let b0 = IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(99), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }; + let b1 = IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(2) }, + }; + let b2 = IrBlock { + id: BlockId(2), + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(99), + }], + terminator: IrTerminator::Return { + value: Some(VarId(1)), + }, + }; + let mut func = make_func(vec![b0, b1, b2]); + func.locals = vec![(VarId(0), WasmType::I32), (VarId(1), WasmType::I32)]; + + eliminate(&mut func); + + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { dest: VarId(0), .. } + )); + assert!( + matches!( + func.blocks[2].instructions[0], + IrInstr::Assign { + dest: VarId(1), + src: VarId(0) + } + ), + "deeply dominated duplicate should be eliminated" + ); + } + + /// Commutative BinOps with swapped operands in a dominated block should be deduped. 
+ #[test] + fn cross_block_commutative_deduplication() { + let b0 = IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Mul, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }; + let b1 = IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Mul, + lhs: VarId(1), // swapped + rhs: VarId(0), + }], + terminator: IrTerminator::Return { + value: Some(VarId(3)), + }, + }; + let mut func = make_func(vec![b0, b1]); + func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; + + eliminate(&mut func); + + assert!( + matches!( + func.blocks[1].instructions[0], + IrInstr::Assign { + dest: VarId(3), + src: VarId(2) + } + ), + "commutative cross-block BinOp should be deduplicated" + ); + } + + /// Single-block functions are skipped entirely (handled by local_cse). + #[test] + fn single_block_function_unchanged() { + let b0 = IrBlock { + id: BlockId(0), + instructions: vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(1), + }, + ], + terminator: IrTerminator::Return { value: None }, + }; + let mut func = make_func(vec![b0]); + func.locals = vec![(VarId(0), WasmType::I32), (VarId(1), WasmType::I32)]; + + eliminate(&mut func); + + // GVN skips single-block functions; duplicates remain (local_cse's job). + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { .. } + )); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::Const { .. } + )); + } +} diff --git a/crates/herkos-core/src/optimizer/licm.rs b/crates/herkos-core/src/optimizer/licm.rs new file mode 100644 index 0000000..4def516 --- /dev/null +++ b/crates/herkos-core/src/optimizer/licm.rs @@ -0,0 +1,1306 @@ +//! Loop-invariant code motion (LICM). +//! +//! Identifies instructions in loop headers whose operands don't change across +//! 
iterations, and moves them to a preheader block. +//! +//! ## Algorithm +//! +//! 1. Compute dominators (iterative algorithm) +//! 2. Find back edges: edge (src → tgt) where tgt dominates src +//! 3. Find natural loops: for each back edge, collect all blocks that reach +//! the source without going through the header +//! 4. For each loop, identify invariant instructions in the header (fixpoint): +//! - `Const` — trivially invariant +//! - `BinOp`, `UnOp`, `Assign`, `Select` — invariant if all operands are +//! defined outside the loop or by other invariant instructions +//! - Skip: `Load`, `Store`, `Call*`, `Global*`, `Memory*` +//! 5. Create or reuse a preheader block and move invariant instructions there +//! +//! **V1 simplification:** only hoists from the loop header block (which +//! dominates all loop blocks by definition). + +use super::utils::{ + build_predecessors, for_each_use, instr_dest, rewrite_terminator_target, terminator_successors, +}; +use crate::ir::{BlockId, IrBlock, IrFunction, IrInstr, IrTerminator, VarId}; +use std::collections::{HashMap, HashSet}; + +/// Run loop-invariant code motion on `func`. +pub fn eliminate(func: &mut IrFunction) { + if func.blocks.len() < 2 { + return; + } + + let preds = build_predecessors(func); + let dominators = compute_dominators(func, &preds); + let back_edges = find_back_edges(func, &dominators); + + if back_edges.is_empty() { + return; + } + + let loops = find_natural_loops(&back_edges, &preds); + + for (header, loop_blocks) in &loops { + hoist_invariants(func, *header, loop_blocks); + } +} + +// ── Dominator computation ──────────────────────────────────────────────────── + +/// Compute the dominator set for each block using the iterative algorithm. +/// +/// Returns a map from each block to the set of blocks that dominate it. 
+// FIX: the generic parameters on the signatures/decls in this section were
+// stripped in the patch (e.g. `-> HashMap>`, `&HashSet`); restored.
+fn compute_dominators(
+    func: &IrFunction,
+    preds: &HashMap<BlockId, Vec<BlockId>>,
+) -> HashMap<BlockId, HashSet<BlockId>> {
+    let entry = func.entry_block;
+    let all_block_ids: HashSet<BlockId> = func.blocks.iter().map(|b| b.id).collect();
+
+    // Standard initialisation: entry dominated only by itself, everything
+    // else starts at "all blocks" and shrinks to a fixed point.
+    let mut dom: HashMap<BlockId, HashSet<BlockId>> = HashMap::new();
+    dom.insert(entry, HashSet::from([entry]));
+
+    for block in &func.blocks {
+        if block.id != entry {
+            dom.insert(block.id, all_block_ids.clone());
+        }
+    }
+
+    loop {
+        let mut changed = false;
+        for block in &func.blocks {
+            if block.id == entry {
+                continue;
+            }
+            let pred_set = &preds[&block.id];
+            if pred_set.is_empty() {
+                continue;
+            }
+
+            // new_dom = {self} ∪ ∩(dom[p] for p in preds)
+            let mut new_dom: Option<HashSet<BlockId>> = None;
+            for p in pred_set {
+                if let Some(p_dom) = dom.get(p) {
+                    new_dom = Some(match new_dom {
+                        None => p_dom.clone(),
+                        Some(current) => current.intersection(p_dom).copied().collect(),
+                    });
+                }
+            }
+
+            let mut new_dom = new_dom.unwrap_or_default();
+            new_dom.insert(block.id);
+
+            if new_dom != dom[&block.id] {
+                dom.insert(block.id, new_dom);
+                changed = true;
+            }
+        }
+        if !changed {
+            break;
+        }
+    }
+
+    dom
+}
+
+// ── Back edge detection ────────────────────────────────────────────────────
+
+/// Find all back edges in the CFG.
+///
+/// A back edge is (src, tgt) where tgt dominates src.
+fn find_back_edges(
+    func: &IrFunction,
+    dominators: &HashMap<BlockId, HashSet<BlockId>>,
+) -> Vec<(BlockId, BlockId)> {
+    let mut back_edges = Vec::new();
+    for block in &func.blocks {
+        for succ in terminator_successors(&block.terminator) {
+            if dominators
+                .get(&block.id)
+                .is_some_and(|dom_set| dom_set.contains(&succ))
+            {
+                back_edges.push((block.id, succ));
+            }
+        }
+    }
+    back_edges
+}
+
+// ── Natural loop detection ─────────────────────────────────────────────────
+
+/// Find natural loops from back edges.
+///
+/// For each back edge (src → header), collects all blocks that can reach `src`
+/// without going through `header`. Multiple back edges with the same header
+/// are merged into one loop.
+fn find_natural_loops(
+    back_edges: &[(BlockId, BlockId)],
+    preds: &HashMap<BlockId, Vec<BlockId>>,
+) -> Vec<(BlockId, HashSet<BlockId>)> {
+    let mut loops: HashMap<BlockId, HashSet<BlockId>> = HashMap::new();
+
+    for &(src, header) in back_edges {
+        let loop_blocks = loops.entry(header).or_insert_with(|| {
+            let mut set = HashSet::new();
+            set.insert(header);
+            set
+        });
+
+        // Backward reachability from src; pre-seeding `header` stops the
+        // walk from escaping the loop through the header.
+        let mut worklist = vec![src];
+        while let Some(n) = worklist.pop() {
+            if loop_blocks.insert(n) {
+                if let Some(n_preds) = preds.get(&n) {
+                    for &p in n_preds {
+                        worklist.push(p);
+                    }
+                }
+            }
+        }
+    }
+
+    loops.into_iter().collect()
+}
+
+// ── Invariant identification & hoisting ────────────────────────────────────
+
+/// Returns `true` if the instruction type is eligible for LICM hoisting.
+///
+/// Only pure, side-effect-free computations are hoistable. Instructions that
+/// depend on mutable state (`Global*`, `Memory*`) or have side effects
+/// (`Load`, `Store`, `Call*`) are excluded.
+fn is_licm_hoistable(instr: &IrInstr) -> bool {
+    matches!(
+        instr,
+        IrInstr::Const { .. }
+            | IrInstr::BinOp { .. }
+            | IrInstr::UnOp { .. }
+            | IrInstr::Assign { .. }
+            | IrInstr::Select { .. }
+    )
+}
+
+/// Identify loop-invariant instructions in the header and hoist them to a preheader.
+fn hoist_invariants(func: &mut IrFunction, header: BlockId, loop_blocks: &HashSet<BlockId>) {
+    let header_idx = match func.blocks.iter().position(|b| b.id == header) {
+        Some(idx) => idx,
+        None => return,
+    };
+
+    // Collect all VarIds defined in any loop block.
+    let mut loop_defs: HashSet<VarId> = HashSet::new();
+    for block in &func.blocks {
+        if loop_blocks.contains(&block.id) {
+            for instr in &block.instructions {
+                if let Some(dest) = instr_dest(instr) {
+                    loop_defs.insert(dest);
+                }
+            }
+        }
+    }
+
+    // Fixpoint: identify invariant instructions in the header.
+    let mut invariant_dests: HashSet<VarId> = HashSet::new();
+    loop {
+        let mut changed = false;
+        for instr in &func.blocks[header_idx].instructions {
+            if !is_licm_hoistable(instr) {
+                continue;
+            }
+            let dest = match instr_dest(instr) {
+                Some(d) => d,
+                None => continue,
+            };
+            if invariant_dests.contains(&dest) {
+                continue;
+            }
+
+            // Invariant iff every operand is defined outside the loop or by
+            // an already-known invariant instruction.
+            let mut all_ops_invariant = true;
+            for_each_use(instr, |v| {
+                if loop_defs.contains(&v) && !invariant_dests.contains(&v) {
+                    all_ops_invariant = false;
+                }
+            });
+
+            if all_ops_invariant {
+                invariant_dests.insert(dest);
+                changed = true;
+            }
+        }
+        if !changed {
+            break;
+        }
+    }
+
+    if invariant_dests.is_empty() {
+        return;
+    }
+
+    // Find or create preheader.
+    let preheader_id = find_or_create_preheader(func, header, loop_blocks);
+
+    // Re-lookup indices after possible block insertion.
+    let header_idx = func.blocks.iter().position(|b| b.id == header).unwrap();
+    let preheader_idx = func
+        .blocks
+        .iter()
+        .position(|b| b.id == preheader_id)
+        .unwrap();
+
+    // Move invariant instructions from header to preheader (in order).
+    let mut hoisted = Vec::new();
+    let mut remaining = Vec::new();
+
+    for instr in func.blocks[header_idx].instructions.drain(..) {
+        if let Some(dest) = instr_dest(&instr) {
+            if invariant_dests.contains(&dest) {
+                hoisted.push(instr);
+                continue;
+            }
+        }
+        remaining.push(instr);
+    }
+
+    func.blocks[header_idx].instructions = remaining;
+    func.blocks[preheader_idx].instructions.extend(hoisted);
+}
+
+/// Allocate a fresh `BlockId` that doesn't conflict with existing blocks.
+fn fresh_block_id(func: &IrFunction) -> BlockId {
+    let max_id = func.blocks.iter().map(|b| b.id.0).max().unwrap_or(0);
+    BlockId(max_id + 1)
+}
+
+/// Find an existing preheader or create a new one.
+///
+/// A preheader is reused if it is the sole non-loop predecessor and ends
+/// with an unconditional jump to the header. Otherwise a new preheader
+/// block is created and non-loop predecessors are redirected to it.
+fn find_or_create_preheader(
+    func: &mut IrFunction,
+    header: BlockId,
+    loop_blocks: &HashSet<BlockId>,
+) -> BlockId {
+    let preds = build_predecessors(func);
+    let header_preds = &preds[&header];
+    let non_loop_preds: Vec<BlockId> = header_preds
+        .iter()
+        .filter(|p| !loop_blocks.contains(p))
+        .copied()
+        .collect();
+
+    if non_loop_preds.is_empty() {
+        // Header has no non-loop predecessors (entry block or unreachable from outside).
+        let preheader_id = fresh_block_id(func);
+        func.blocks.push(IrBlock {
+            id: preheader_id,
+            instructions: vec![],
+            terminator: IrTerminator::Jump { target: header },
+        });
+        if header == func.entry_block {
+            func.entry_block = preheader_id;
+        }
+        return preheader_id;
+    }
+
+    // Reuse if single non-loop predecessor with unconditional jump to header.
+    if non_loop_preds.len() == 1 {
+        let pred = non_loop_preds[0];
+        let pred_idx = func.blocks.iter().position(|b| b.id == pred).unwrap();
+        if matches!(func.blocks[pred_idx].terminator, IrTerminator::Jump { target } if target == header)
+        {
+            return pred;
+        }
+    }
+
+    // Create a new preheader and redirect non-loop predecessors.
+ let preheader_id = fresh_block_id(func); + func.blocks.push(IrBlock { + id: preheader_id, + instructions: vec![], + terminator: IrTerminator::Jump { target: header }, + }); + + for block in &mut func.blocks { + if non_loop_preds.contains(&block.id) { + rewrite_terminator_target(&mut block.terminator, header, preheader_id); + } + } + + preheader_id +} + +// ── Tests ──────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{ + BinOp, IrBlock, IrFunction, IrInstr, IrTerminator, IrValue, TypeIdx, VarId, WasmType, + }; + + fn make_func(blocks: Vec) -> IrFunction { + IrFunction { + params: vec![], + locals: vec![], + blocks, + entry_block: BlockId(0), + return_type: None, + type_idx: TypeIdx::new(0), + } + } + + // ── No loops → no changes ──────────────────────────────────────────── + + #[test] + fn no_loop_no_change() { + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Return { + value: Some(VarId(0)), + }, + }, + ]); + + eliminate(&mut func); + + // No loops, so the const stays in block 0. + assert_eq!(func.blocks[0].instructions.len(), 1); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { dest: VarId(0), .. 
} + )); + } + + // ── Simple loop: const in header → hoisted to preheader ────────────── + + #[test] + fn simple_loop_const_hoisted() { + // B0 (entry): Jump(B1) + // B1 (header): v0 = Const(42), BranchIf(v1, B2, B3) + // B2 (body): Jump(B1) ← back edge + // B3 (exit): Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(1), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // B0 is the sole non-loop predecessor with Jump → reused as preheader. + // v0 = Const(42) should be hoisted to B0. + assert_eq!(func.blocks[0].instructions.len(), 1); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + } + )); + + // B1 (header) should have no instructions. 
+ let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 0); + } + + // ── BinOp with operands from outside loop → hoisted ────────────────── + + #[test] + fn invariant_binop_hoisted() { + // B0 (entry): v0 = Const(10), v1 = Const(20), Jump(B1) + // B1 (header): v2 = BinOp::Add(v0, v1), BranchIf(v3, B2, B3) + // B2 (body): Jump(B1) + // B3 (exit): Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(10), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(20), + }, + ], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(3), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // v2 = BinOp should be hoisted to B0 (preheader). + assert_eq!(func.blocks[0].instructions.len(), 3); + assert!(matches!( + func.blocks[0].instructions[2], + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + .. + } + )); + + // Header should be empty. 
+ let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 0); + } + + // ── Chained invariants: const → binop using that const ─────────────── + + #[test] + fn chained_invariants_hoisted() { + // B0 (entry): v0 = Const(10), Jump(B1) + // B1 (header): v1 = Const(65536), v2 = BinOp::Add(v0, v1), BranchIf(v3, B2, B3) + // B2 (body): Jump(B1) + // B3 (exit): Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(10), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![ + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(65536), + }, + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + ], + terminator: IrTerminator::BranchIf { + condition: VarId(3), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // Both v1 = Const and v2 = BinOp should be hoisted to B0. + // B0 now has: v0 = Const(10), v1 = Const(65536), v2 = Add(v0, v1). + assert_eq!(func.blocks[0].instructions.len(), 3); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(65536), + } + )); + assert!(matches!( + func.blocks[0].instructions[2], + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + .. + } + )); + + // Header should be empty. 
+ let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 0); + } + + // ── Non-hoistable instructions stay in the header ──────────────────── + + #[test] + fn side_effectful_not_hoisted() { + use crate::ir::MemoryAccessWidth; + + // B0: Jump(B1) + // B1 (header): v0 = Const(0), v1 = Load(v0), BranchIf(v2, B2, B3) + // B2: Jump(B1) + // B3: Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(0), + }, + IrInstr::Load { + dest: VarId(1), + ty: WasmType::I32, + addr: VarId(0), + offset: 0, + width: MemoryAccessWidth::Full, + sign: None, + }, + ], + terminator: IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // v0 = Const is hoisted (invariant), but Load stays (not hoistable). + assert_eq!(func.blocks[0].instructions.len(), 1); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { dest: VarId(0), .. } + )); + + let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 1); + assert!(matches!(header.instructions[0], IrInstr::Load { .. 
})); + } + + // ── BinOp with operand from loop body → NOT hoisted ────────────────── + + #[test] + fn loop_dependent_not_hoisted() { + // B0: v0 = Const(1), Jump(B1) + // B1 (header): v2 = BinOp::Add(v0, v1), BranchIf(v3, B2, B3) + // v1 is defined in B2 (loop body) → v2 is NOT invariant + // B2: v1 = Const(5), Jump(B1) + // B3: Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(1), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(3), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(5), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // v2 = BinOp should NOT be hoisted because v1 is defined in B2 (loop body). + let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 1); + assert!(matches!(header.instructions[0], IrInstr::BinOp { .. 
})); + } + + // ── Preheader reuse: single non-loop predecessor with Jump ─────────── + + #[test] + fn preheader_reused_when_possible() { + // B0 (entry): v0 = Const(99), Jump(B1) + // B1 (header): v1 = Const(42), BranchIf(v2, B2, B3) + // B2: Jump(B1) + // B3: Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(99), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(42), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // B0 should be reused as preheader (sole non-loop pred with Jump). + // No new blocks should be created. 
+ assert_eq!(func.blocks.len(), 4); + assert_eq!(func.blocks[0].instructions.len(), 2); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(99), + } + )); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(42), + } + )); + } + + // ── Preheader creation: multiple non-loop predecessors ─────────────── + + #[test] + fn preheader_created_when_needed() { + // B0 (entry): BranchIf(v0, B1, B2) + // B1: Jump(B3) + // B2: Jump(B3) + // B3 (header): v1 = Const(42), BranchIf(v2, B4, B5) + // B4 (body): Jump(B3) ← back edge + // B5 (exit): Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::BranchIf { + condition: VarId(0), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(42), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(4), + if_false: BlockId(5), + }, + }, + IrBlock { + id: BlockId(4), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }, + IrBlock { + id: BlockId(5), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // A new preheader (B6) should be created. 
+ assert_eq!(func.blocks.len(), 7); + + let preheader = func.blocks.iter().find(|b| b.id == BlockId(6)).unwrap(); + assert_eq!(preheader.instructions.len(), 1); + assert!(matches!( + preheader.instructions[0], + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(42), + } + )); + assert!(matches!( + preheader.terminator, + IrTerminator::Jump { target: BlockId(3) } + )); + + // B1 and B2 should now jump to the preheader (B6). + let b1 = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert!(matches!( + b1.terminator, + IrTerminator::Jump { target: BlockId(6) } + )); + let b2 = func.blocks.iter().find(|b| b.id == BlockId(2)).unwrap(); + assert!(matches!( + b2.terminator, + IrTerminator::Jump { target: BlockId(6) } + )); + + // Header (B3) should be empty. + let header = func.blocks.iter().find(|b| b.id == BlockId(3)).unwrap(); + assert_eq!(header.instructions.len(), 0); + } + + // ── GlobalGet not hoisted (depends on mutable state) ───────────────── + + #[test] + fn global_get_not_hoisted() { + use crate::ir::GlobalIdx; + + // B0: Jump(B1) + // B1 (header): v0 = GlobalGet(0), BranchIf(v1, B2, B3) + // B2: Jump(B1) + // B3: Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::GlobalGet { + dest: VarId(0), + index: GlobalIdx::new(0), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(1), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // GlobalGet should NOT be hoisted (mutable global may change each iteration). 
+ let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 1); + assert!(matches!(header.instructions[0], IrInstr::GlobalGet { .. })); + } + + // ── Self-loop: header is also the back-edge source ─────────────────── + + #[test] + fn self_loop_const_hoisted() { + // B0: Jump(B1) + // B1: v0 = Const(42), BranchIf(v1, B1, B2) ← self-loop + // B2: Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(1), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // Const should be hoisted to B0 (preheader). + assert_eq!(func.blocks[0].instructions.len(), 1); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + } + )); + + let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 0); + } + + // ── No invariant instructions → no changes ─────────────────────────── + + #[test] + fn no_invariants_no_change() { + use crate::ir::MemoryAccessWidth; + + // B0: v0 = Const(0), Jump(B1) + // B1 (header): v1 = Load(v0), BranchIf(v2, B2, B3) + // B2: Jump(B1) + // B3: Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(0), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::Load { + dest: VarId(1), + ty: WasmType::I32, + addr: VarId(0), + offset: 0, + width: MemoryAccessWidth::Full, + sign: None, + }], + terminator: 
IrTerminator::BranchIf { + condition: VarId(2), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // No invariants to hoist — no new blocks, header unchanged. + assert_eq!(func.blocks.len(), 4); + let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 1); + assert!(matches!(header.instructions[0], IrInstr::Load { .. })); + } + + // ── Entry block as loop header ─────────────────────────────────────── + + #[test] + fn entry_block_loop_header() { + // B0 (entry/header): v0 = Const(42), BranchIf(v1, B1, B2) + // B1 (body): Jump(B0) ← back edge + // B2 (exit): Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }], + terminator: IrTerminator::BranchIf { + condition: VarId(1), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(0) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // A preheader should be created, and entry_block updated. + assert_eq!(func.blocks.len(), 4); + let preheader_id = func.entry_block; + assert_ne!(preheader_id, BlockId(0)); + + let preheader = func.blocks.iter().find(|b| b.id == preheader_id).unwrap(); + assert_eq!(preheader.instructions.len(), 1); + assert!(matches!( + preheader.instructions[0], + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + } + )); + assert!(matches!( + preheader.terminator, + IrTerminator::Jump { target: BlockId(0) } + )); + + // Original header (B0) should be empty. 
+ let header = func.blocks.iter().find(|b| b.id == BlockId(0)).unwrap(); + assert_eq!(header.instructions.len(), 0); + } + + // ── Mixed: some hoistable, some not ────────────────────────────────── + + #[test] + fn mixed_hoistable_and_non_hoistable() { + use crate::ir::MemoryAccessWidth; + + // B0: Jump(B1) + // B1 (header): v0 = Const(100), v1 = Load(v0), v2 = Const(200) + // BranchIf(v3, B2, B3) + // B2: Jump(B1) + // B3: Return + let mut func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(100), + }, + IrInstr::Load { + dest: VarId(1), + ty: WasmType::I32, + addr: VarId(0), + offset: 0, + width: MemoryAccessWidth::Full, + sign: None, + }, + IrInstr::Const { + dest: VarId(2), + value: IrValue::I32(200), + }, + ], + terminator: IrTerminator::BranchIf { + condition: VarId(3), + if_true: BlockId(2), + if_false: BlockId(3), + }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + eliminate(&mut func); + + // v0 and v2 (Consts) should be hoisted; Load stays. + assert_eq!(func.blocks[0].instructions.len(), 2); + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(100), + } + )); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::Const { + dest: VarId(2), + value: IrValue::I32(200), + } + )); + + let header = func.blocks.iter().find(|b| b.id == BlockId(1)).unwrap(); + assert_eq!(header.instructions.len(), 1); + assert!(matches!(header.instructions[0], IrInstr::Load { .. 
})); + } + + // ── Single-block function → no change ──────────────────────────────── + + #[test] + fn single_block_function_no_change() { + let mut func = make_func(vec![IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }], + terminator: IrTerminator::Return { + value: Some(VarId(0)), + }, + }]); + + eliminate(&mut func); + + assert_eq!(func.blocks.len(), 1); + assert_eq!(func.blocks[0].instructions.len(), 1); + } + + // ── Dominator computation tests ────────────────────────────────────── + + #[test] + fn dominators_linear_chain() { + // B0 → B1 → B2 + let func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(2) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + let preds = build_predecessors(&func); + let dom = compute_dominators(&func, &preds); + + assert_eq!(dom[&BlockId(0)], HashSet::from([BlockId(0)])); + assert_eq!(dom[&BlockId(1)], HashSet::from([BlockId(0), BlockId(1)])); + assert_eq!( + dom[&BlockId(2)], + HashSet::from([BlockId(0), BlockId(1), BlockId(2)]) + ); + } + + #[test] + fn dominators_diamond() { + // B0 → B1, B0 → B2, B1 → B3, B2 → B3 + let func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::BranchIf { + condition: VarId(0), + if_true: BlockId(1), + if_false: BlockId(2), + }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Return { value: None }, + }, + ]); + + let preds = build_predecessors(&func); + 
let dom = compute_dominators(&func, &preds); + + // B3 is dominated by B0 (only common dominator of B1 and B2). + assert_eq!(dom[&BlockId(3)], HashSet::from([BlockId(0), BlockId(3)])); + } + + #[test] + fn back_edges_detected() { + // B0 → B1 → B2 → B1 (back edge) + let func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(2) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + ]); + + let preds = build_predecessors(&func); + let dom = compute_dominators(&func, &preds); + let back_edges = find_back_edges(&func, &dom); + + assert_eq!(back_edges.len(), 1); + assert_eq!(back_edges[0], (BlockId(2), BlockId(1))); + } + + #[test] + fn natural_loop_blocks() { + // B0 → B1 → B2 → B3 → B1 (back edge) + // Loop = {B1, B2, B3} + let func = make_func(vec![ + IrBlock { + id: BlockId(0), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + IrBlock { + id: BlockId(1), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(2) }, + }, + IrBlock { + id: BlockId(2), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(3) }, + }, + IrBlock { + id: BlockId(3), + instructions: vec![], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }, + ]); + + let preds = build_predecessors(&func); + let dom = compute_dominators(&func, &preds); + let back_edges = find_back_edges(&func, &dom); + let loops = find_natural_loops(&back_edges, &preds); + + assert_eq!(loops.len(), 1); + let (header, loop_blocks) = &loops[0]; + assert_eq!(*header, BlockId(1)); + assert_eq!( + *loop_blocks, + HashSet::from([BlockId(1), BlockId(2), BlockId(3)]) + ); + } +} diff --git a/crates/herkos-core/src/optimizer/local_cse.rs b/crates/herkos-core/src/optimizer/local_cse.rs new 
file mode 100644 index 0000000..d431a19 --- /dev/null +++ b/crates/herkos-core/src/optimizer/local_cse.rs @@ -0,0 +1,574 @@ +//! Local common subexpression elimination (CSE) via value numbering. +//! +//! Within each block, identifies identical computations and replaces duplicates +//! with references to the first result. Only side-effect-free instructions are +//! considered (`BinOp`, `UnOp`, `Const`). Duplicates are replaced with +//! `Assign { dest, src: previous_result }`, which copy propagation cleans up. + +use crate::ir::{BinOp, IrFunction, IrInstr, IrValue, UnOp, VarId}; +use std::collections::HashMap; + +use super::utils::prune_dead_locals; + +// ── Value key ──────────────────────────────────────────────────────────────── + +/// Hashable representation of a pure computation for deduplication. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum ValueKey { + /// Constant value (using bit-level equality for floats). + Const(ConstKey), + + /// Binary operation with operand variable IDs. + BinOp { op: BinOp, lhs: VarId, rhs: VarId }, + + /// Unary operation with operand variable ID. + UnOp { op: UnOp, operand: VarId }, +} + +/// Bit-level constant key that implements Eq/Hash correctly for floats. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum ConstKey { + I32(i32), + I64(i64), + F32(u32), + F64(u64), +} + +impl From<IrValue> for ConstKey { + fn from(v: IrValue) -> Self { + match v { + IrValue::I32(x) => ConstKey::I32(x), + IrValue::I64(x) => ConstKey::I64(x), + IrValue::F32(x) => ConstKey::F32(x.to_bits()), + IrValue::F64(x) => ConstKey::F64(x.to_bits()), + } + } +} + +// ── Commutative op detection ───────────────────────────────────────────────── + +/// Returns true for operations where `op(a, b) == op(b, a)`. 
+fn is_commutative(op: &BinOp) -> bool { + matches!( + op, + BinOp::I32Add + | BinOp::I32Mul + | BinOp::I32And + | BinOp::I32Or + | BinOp::I32Xor + | BinOp::I32Eq + | BinOp::I32Ne + | BinOp::I64Add + | BinOp::I64Mul + | BinOp::I64And + | BinOp::I64Or + | BinOp::I64Xor + | BinOp::I64Eq + | BinOp::I64Ne + | BinOp::F32Add + | BinOp::F32Mul + | BinOp::F32Eq + | BinOp::F32Ne + | BinOp::F64Add + | BinOp::F64Mul + | BinOp::F64Eq + | BinOp::F64Ne + ) +} + +/// Build a `ValueKey` for a `BinOp`, normalizing operand order for commutative ops. +fn binop_key(op: BinOp, lhs: VarId, rhs: VarId) -> ValueKey { + let (lhs, rhs) = if is_commutative(&op) && lhs.0 > rhs.0 { + (rhs, lhs) + } else { + (lhs, rhs) + }; + ValueKey::BinOp { op, lhs, rhs } +} + +// ── Pass entry point ───────────────────────────────────────────────────────── + +/// Eliminates common subexpressions within each block of `func`. +pub fn eliminate(func: &mut IrFunction) { + let mut changed = false; + + for block in &mut func.blocks { + // Maps a pure computation to the first VarId that computed it. + let mut value_map: HashMap<ValueKey, VarId> = HashMap::new(); + + for instr in &mut block.instructions { + // In strict SSA form each variable is defined exactly once, so there + // is no need to invalidate cached CSE entries on redefinition. + match instr { + IrInstr::Const { dest, value } => { + let key = ValueKey::Const(ConstKey::from(*value)); + if let Some(&first) = value_map.get(&key) { + *instr = IrInstr::Assign { + dest: *dest, + src: first, + }; + changed = true; + } else { + value_map.insert(key, *dest); + } + } + + IrInstr::BinOp { + dest, op, lhs, rhs, .. 
+ } => { + let key = binop_key(*op, *lhs, *rhs); + if let Some(&first) = value_map.get(&key) { + *instr = IrInstr::Assign { + dest: *dest, + src: first, + }; + changed = true; + } else { + value_map.insert(key, *dest); + } + } + + IrInstr::UnOp { dest, op, operand } => { + let key = ValueKey::UnOp { + op: *op, + operand: *operand, + }; + if let Some(&first) = value_map.get(&key) { + *instr = IrInstr::Assign { + dest: *dest, + src: first, + }; + changed = true; + } else { + value_map.insert(key, *dest); + } + } + + // All other instructions are not eligible for CSE. + _ => {} + } + } + } + + if changed { + prune_dead_locals(func); + } +} + +// ── Tests ──────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::{BlockId, IrBlock, IrTerminator, TypeIdx, WasmType}; + + /// Helper: create a minimal IrFunction with the given blocks. + fn make_func(blocks: Vec<IrBlock>) -> IrFunction { + IrFunction { + params: vec![], + locals: vec![], + blocks, + entry_block: BlockId(0), + return_type: None, + type_idx: TypeIdx::new(0), + } + } + + /// Helper: create a block with given instructions and a simple return terminator. + fn make_block(id: u32, instructions: Vec<IrInstr>) -> IrBlock { + IrBlock { + id: BlockId(id), + instructions, + terminator: IrTerminator::Return { value: None }, + } + } + + #[test] + fn duplicate_binop_is_eliminated() { + let instrs = vec![ + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; + eliminate(&mut func); + + let block = &func.blocks[0]; + assert!(matches!(block.instructions[0], IrInstr::BinOp { .. 
})); + assert!( + matches!( + block.instructions[1], + IrInstr::Assign { + dest: VarId(3), + src: VarId(2) + } + ), + "Duplicate BinOp should be replaced with Assign" + ); + } + + #[test] + fn commutative_binop_is_deduplicated() { + // v2 = v0 + v1, v3 = v1 + v0 → v3 should become Assign from v2 + let instrs = vec![ + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Add, + lhs: VarId(1), + rhs: VarId(0), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; + eliminate(&mut func); + + assert!( + matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(3), + src: VarId(2) + } + ), + "Commutative BinOp with swapped operands should be deduplicated" + ); + } + + #[test] + fn non_commutative_binop_not_deduplicated() { + // v2 = v0 - v1, v3 = v1 - v0 → different computations, keep both + let instrs = vec![ + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Sub, + lhs: VarId(0), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Sub, + lhs: VarId(1), + rhs: VarId(0), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; + eliminate(&mut func); + + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::BinOp { .. } + )); + assert!( + matches!(func.blocks[0].instructions[1], IrInstr::BinOp { .. 
}), + "Non-commutative BinOp with swapped operands should NOT be deduplicated" + ); + } + + #[test] + fn duplicate_unop_is_eliminated() { + let instrs = vec![ + IrInstr::UnOp { + dest: VarId(1), + op: UnOp::I32Clz, + operand: VarId(0), + }, + IrInstr::UnOp { + dest: VarId(2), + op: UnOp::I32Clz, + operand: VarId(0), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(1), WasmType::I32), (VarId(2), WasmType::I32)]; + eliminate(&mut func); + + assert!( + matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(2), + src: VarId(1) + } + ), + "Duplicate UnOp should be replaced with Assign" + ); + } + + #[test] + fn duplicate_const_is_eliminated() { + let instrs = vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::I32(42), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::I32(42), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(0), WasmType::I32), (VarId(1), WasmType::I32)]; + eliminate(&mut func); + + assert!( + matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(1), + src: VarId(0) + } + ), + "Duplicate Const should be replaced with Assign" + ); + } + + #[test] + fn float_const_nan_bits_handled() { + // Two NaN constants with the same bit pattern should be deduplicated. 
+ let instrs = vec![ + IrInstr::Const { + dest: VarId(0), + value: IrValue::F32(f32::NAN), + }, + IrInstr::Const { + dest: VarId(1), + value: IrValue::F32(f32::NAN), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(0), WasmType::F32), (VarId(1), WasmType::F32)]; + eliminate(&mut func); + + assert!( + matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(1), + src: VarId(0) + } + ), + "NaN constants with same bit pattern should be deduplicated" + ); + } + + #[test] + fn different_ops_not_deduplicated() { + let instrs = vec![ + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Sub, + lhs: VarId(0), + rhs: VarId(1), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; + eliminate(&mut func); + + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::BinOp { .. } + )); + assert!( + matches!(func.blocks[0].instructions[1], IrInstr::BinOp { .. }), + "Different operations should not be deduplicated" + ); + } + + #[test] + fn cross_block_not_deduplicated() { + // Each block should have its own value map — no cross-block CSE. + let block0 = IrBlock { + id: BlockId(0), + instructions: vec![IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::Jump { target: BlockId(1) }, + }; + let block1 = IrBlock { + id: BlockId(1), + instructions: vec![IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }], + terminator: IrTerminator::Return { value: None }, + }; + + let mut func = make_func(vec![block0, block1]); + func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; + eliminate(&mut func); + + // Both should remain as BinOp (no cross-block elimination). 
+ assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::BinOp { .. } + )); + assert!( + matches!(func.blocks[1].instructions[0], IrInstr::BinOp { .. }), + "Cross-block duplicate should NOT be eliminated" + ); + } + + /// In strict SSA form every variable is defined exactly once within a block, + /// so (v0 + v1) always refers to the same computation and can be CSE'd. + #[test] + fn ssa_unique_defs_allow_cse() { + // v2 = v0 + v1 ← first occurrence + // v3 = v0 + v1 ← identical keys with same VarIds → should be eliminated + let instrs = vec![ + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Add, + lhs: VarId(0), + rhs: VarId(1), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(2), WasmType::I32), (VarId(3), WasmType::I32)]; + eliminate(&mut func); + + // v3 should be eliminated to Assign(v3, v2). + assert!( + matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(3), + src: VarId(2) + } + ), + "duplicate (v0 + v1) should be CSE'd to Assign in strict SSA" + ); + } + + #[test] + fn side_effect_instructions_not_eliminated() { + // Load, Store, Call, etc. should never be CSE'd. + use crate::ir::MemoryAccessWidth; + + let instrs = vec![ + IrInstr::Load { + dest: VarId(1), + ty: WasmType::I32, + addr: VarId(0), + offset: 0, + width: MemoryAccessWidth::Full, + sign: None, + }, + IrInstr::Load { + dest: VarId(2), + ty: WasmType::I32, + addr: VarId(0), + offset: 0, + width: MemoryAccessWidth::Full, + sign: None, + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![(VarId(1), WasmType::I32), (VarId(2), WasmType::I32)]; + eliminate(&mut func); + + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::Load { .. } + )); + assert!( + matches!(func.blocks[0].instructions[1], IrInstr::Load { .. 
}), + "Load instructions should not be CSE'd" + ); + } + + #[test] + fn triple_duplicate_eliminates_both() { + // Three identical BinOps: second and third should become Assigns to first. + let instrs = vec![ + IrInstr::BinOp { + dest: VarId(2), + op: BinOp::I32Mul, + lhs: VarId(0), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(3), + op: BinOp::I32Mul, + lhs: VarId(0), + rhs: VarId(1), + }, + IrInstr::BinOp { + dest: VarId(4), + op: BinOp::I32Mul, + lhs: VarId(0), + rhs: VarId(1), + }, + ]; + + let mut func = make_func(vec![make_block(0, instrs)]); + func.locals = vec![ + (VarId(2), WasmType::I32), + (VarId(3), WasmType::I32), + (VarId(4), WasmType::I32), + ]; + eliminate(&mut func); + + assert!(matches!( + func.blocks[0].instructions[0], + IrInstr::BinOp { .. } + )); + assert!(matches!( + func.blocks[0].instructions[1], + IrInstr::Assign { + dest: VarId(3), + src: VarId(2) + } + )); + assert!(matches!( + func.blocks[0].instructions[2], + IrInstr::Assign { + dest: VarId(4), + src: VarId(2) + } + )); + } +} diff --git a/crates/herkos-core/src/optimizer/mod.rs b/crates/herkos-core/src/optimizer/mod.rs index 06e0d84..d8decf0 100644 --- a/crates/herkos-core/src/optimizer/mod.rs +++ b/crates/herkos-core/src/optimizer/mod.rs @@ -20,8 +20,12 @@ mod copy_prop; mod dead_blocks; // ── Post-lowering passes ───────────────────────────────────────────────────── +mod branch_fold; mod dead_instrs; mod empty_blocks; +mod gvn; +mod licm; +mod local_cse; mod merge_blocks; /// Optimizes the pure SSA IR before phi lowering. @@ -46,16 +50,8 @@ pub fn optimize_ir(module_info: ModuleInfo, do_opt: bool) -> Result /// Optimizes the lowered IR after phi nodes have been eliminated. /// -/// Passes here operate on [`LoweredModuleInfo`] where all `IrInstr::Phi` -/// nodes have been replaced by `IrInstr::Assign` in predecessor blocks. -/// -/// ## Structural and copy passes -/// -/// Run multiple iterations of structural cleanup + copy propagation. 
-/// dead_instrs may leave empty blocks, which empty_blocks and merge_blocks then -/// eliminate, potentially exposing new dead instructions. copy_prop forwards the -/// assignments that lower_phis inserted. We repeat until reaching a fixed point -/// (typically 2 iterations). +/// Runs all post-lowering optimization passes: structural cleanup, redundancy +/// elimination (CSE/GVN), loop optimizations (LICM), and branch simplification. pub fn optimize_lowered_ir( module_info: LoweredModuleInfo, do_opt: bool, @@ -63,15 +59,18 @@ pub fn optimize_lowered_ir( let mut module_info = module_info; if do_opt { for func in &mut module_info.ir_functions { - // Two passes: dead_instrs may create empty blocks, and copy_prop - // may reveal new dead instrs. for _ in 0..2 { empty_blocks::eliminate(func); dead_blocks::eliminate(func)?; merge_blocks::eliminate(func); dead_blocks::eliminate(func)?; copy_prop::eliminate(func); + local_cse::eliminate(func); + gvn::eliminate(func); + dead_instrs::eliminate(func); + branch_fold::eliminate(func); dead_instrs::eliminate(func); + // licm::eliminate(func); } } }