diff --git a/conductor/CONDUCTOR_DESIGN.md b/conductor/CONDUCTOR_DESIGN.md new file mode 100644 index 000000000..2176e3da5 --- /dev/null +++ b/conductor/CONDUCTOR_DESIGN.md @@ -0,0 +1,421 @@ +# Conductor: LLM-Guided Instruction Scheduling for WaveASM + +## Context + +The WaveASM C++ backend has **no instruction scheduling pass**. Instructions are +emitted in the exact order they arrive from the upstream Python frontend. While +the Python frontend does tile-level scheduling (software pipelining, +double-buffering, stage assignment), it does not optimize **within** instruction +groups. This leaves performance on the table: + +- LDS reads (20-cycle latency) are clustered together, then MFMAs (32+ cycles) + are clustered together. Interleaving them would hide LDS latency behind MFMA + execution. +- Global loads (100-cycle latency) are not interleaved with independent compute. +- Register pressure from the instruction order can cause the linear-scan + allocator to fail (it has no spilling). + +The target latency data (`getMFMALatency()`, `getGlobalLoadLatency()`, +`getLDSLoadLatency()`) already exists in `WaveASMAttrs.td` but is **unused by +any optimization pass**. + +## Core Idea + +No intermediate DSL. The LLM sees **real WaveASM MLIR textual IR** with named +location tags on each instruction. It issues **move commands** ("move A after +B"). A deterministic engine validates structural invariants (dominance, +dependency), applies the moves, runs the rest of the WaveASM MLIR compilation +pipeline, and reports **metrics** back to the LLM. The loop repeats until the +LLM is satisfied or a budget is exhausted. + +``` +WaveASM IR (post-CSE/peephole, virtual registers) + | + v +[Tag] ── attach NameLoc("I0"), ("I1"), ... 
to each op + | + v +[Print] ── module.print() ──> textual MLIR + header with target/latency info + | + v +[LLM] ── reads IR, issues move commands ──> "move I5 after I2" + | "move I8 before I12" + | "swap I3 I7" + v +[Executor] ── parse commands + ── validate (dominance, pinned ops) + ── apply via Operation::moveBefore() + ── anchor pseudo-ops (pack/extract/const) + | + v +[Compile] ── run LinearScan + InsertWaitcnt + HazardMitigation on clone + | + v +[Metrics] ── collect from WaveASM pipeline passes ──> report to LLM + | + v +[LLM] ── sees metrics, decides: more moves or done + | + v +(repeat up to N rounds, then emit final assembly) +``` + +The key insight: **the LLM works on the actual IR**, not an abstraction. It can +see every instruction, every SSA value, every type. The system handles all the +mechanical work (validation, compilation, metric collection) while the LLM +handles the creative work (what to move where). + +## Instruction Tagging + +Before presenting IR to the LLM, each operation in schedulable regions gets a +stable `NameLoc` tag: + +```cpp +// During tagging pass. +int counter = 0; +for (auto &op : block.getOperations()) { + auto tag = NameLoc::get(builder.getStringAttr("I" + std::to_string(counter++))); + op.setLoc(tag); +} +``` + +In textual MLIR this appears as: + +```mlir +%0 = waveasm.buffer_load_dwordx4 %srd, %off offset:0 + : !waveasm.sreg<4>, !waveasm.vreg -> !waveasm.vreg<4> loc("I0") +%1 = waveasm.buffer_load_dwordx4 %srd, %off offset:64 + : !waveasm.sreg<4>, !waveasm.vreg -> !waveasm.vreg<4> loc("I1") +%2 = waveasm.ds_write_b128 %0, %addr + : !waveasm.vreg<4>, !waveasm.vreg loc("I2") +... +%10 = waveasm.v_mfma_f32_16x16x16_f16 %a, %b, %acc + : !waveasm.vreg, !waveasm.vreg, !waveasm.vreg<4> + -> !waveasm.vreg<4> loc("I10") +``` + +Tags are stable across rounds — they follow the operation, not the position. 
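Because the tags appear verbatim in the printed IR, downstream tooling can recover the tag-to-position mapping with a few lines of Python. A minimal sketch (the helper name is illustrative, not part of the tool):

```python
import re

# Matches the loc("I<n>") suffix the tagging pass attaches to each op.
TAG_RE = re.compile(r'loc\("(I\d+)"\)')

def index_tags(ir_text: str) -> dict:
    """Map each tag to the line of textual IR it appears on."""
    tags = {}
    for lineno, line in enumerate(ir_text.splitlines()):
        for tag in TAG_RE.findall(line):
            tags[tag] = lineno
    return tags

ir = '%0 = waveasm.buffer_load_dwordx4 ... loc("I0")\n%1 = waveasm.ds_write_b128 ... loc("I1")'
# index_tags(ir) -> {"I0": 0, "I1": 1}
```

This is all a command validator needs to reject moves that reference unknown tags before touching the IR.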
+ +## LLM Input + +Each round, the LLM receives: + +``` +=== WaveASM Scheduling Round {N} === +TARGET: gfx942 (wave64, 512 vgpr, 106 sgpr, 512 agpr) +LATENCY: vmem=100, lds=20, mfma_16x16=32, mfma_32x32=64 + +--- IR (loop body) --- +{textual MLIR of the loop body, with loc("Ixx") tags} + +--- Metrics (from previous round, or baseline) --- +peak_vgpr: 180 +peak_sgpr: 42 +peak_agpr: 128 +nops_inserted: 3 +waitcnts_inserted: 12 +total_instructions: 87 + +--- Warnings from previous round --- +(none, or e.g.: + "warn: move I8 after I12 — moved LDS read after LDS write, same base address + info: move I3 after I7 — address computation moved away from consumer I4") + +--- Error from previous round (if any) --- +(none, or e.g.: + "Applied successfully: move I5 after I1, move I8 before I3 + Failed: swap I6 I9 — would break dominance of %2 (defined by I6, used by I7) + All moves reverted.") + +GOAL: Minimize register pressure and hide memory latency. +Respond with move commands, one per line. +``` + +## LLM Output (Move Commands) + +The LLM responds with simple text commands: + +``` +move I5 after I1 +move I8 before I3 +swap I6 I9 +``` + +### Command Set + +| Command | Semantics | +|---------|-----------| +| `move Ix after Iy` | Move op tagged Ix to immediately after op tagged Iy. | +| `move Ix before Iy` | Move op tagged Ix to immediately before op tagged Iy. | +| `swap Ix Iy` | Exchange positions of ops Ix and Iy. | +| `done` | LLM is satisfied with current schedule. | + +Three move commands + a stop signal. No DSL to learn. + +## Validation + +Move commands are applied sequentially. Each is validated **before** application: + +1. **Tag resolution**: Both tags must exist. +2. **Pinned ops**: `ConditionOp` (loop terminator), `s_barrier`, `s_endpgm` + cannot be moved. +3. **Dominance check (pre-flight)**: Simulate the move and verify every use of + every SSA value defined by the moved op is still dominated by its def. 
Also
   check that the definitions of the moved op's operands still dominate its
   new position.
4. **Region boundary**: Cannot move ops across region boundaries (into/out of
   `LoopOp` or `IfOp`).

On the first invalid move, the system **aborts the entire round**: all moves
applied so far in this round are reverted, and the LLM receives a report
showing which moves succeeded before the failure and which move was invalid
(with the reason). For example:

```
--- Error ---
Applied successfully: move I5 after I1, move I8 before I3
Failed: swap I6 I9 — would break dominance of %2 (defined by I6, used by I7)
All moves reverted.
```

This gives the LLM a clear signal about what went wrong and a complete picture
of which commands were valid up to the failure point.

### Warnings

Some moves are structurally valid (dominance holds) but **potentially
dangerous**. These are applied but reported as warnings with severity levels,
so the LLM can decide whether to keep or reconsider them.

| Severity | Example | Meaning |
|----------|---------|---------|
| `info` | Moving address computation away from its consumer. | Unusual but harmless. |
| `warn` | Moving a memory read after a memory write. | May be valid (different addresses) but could introduce a WAR hazard. |
| `warn` | Moving an op past a barrier. | Barrier ordering is usually intentional. |
| `critical` | Moving a memory write after another write to the same buffer. | Likely WAW hazard; almost certainly wrong unless offsets are provably disjoint. |

The system performs lightweight alias analysis where possible (e.g. comparing
buffer SRD operands and constant offsets) to reduce false warnings. When it
cannot prove safety, it warns and lets the LLM proceed — the metrics from the
compilation round will reveal whether the move was actually beneficial.
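The "lightweight alias analysis" can be as simple as an interval check on constant offsets for same-SRD accesses. A sketch, assuming known access sizes in bytes (the function name and tuple shape are illustrative):

```python
def may_alias(srd_a, off_a, size_a, srd_b, off_b, size_b) -> bool:
    """Conservative overlap check used to downgrade warnings.

    Two buffer accesses are provably disjoint only when they share an SRD
    base and their constant [offset, offset + size) ranges do not overlap.
    Anything else is conservatively treated as possibly aliasing.
    """
    if srd_a != srd_b:
        return True   # cannot compare offsets across descriptors
    if off_a is None or off_b is None:
        return True   # non-constant offset: assume overlap
    return not (off_a + size_a <= off_b or off_b + size_b <= off_a)

# dwordx4 accesses are 16 bytes: offsets 0 and 64 on one SRD are disjoint,
# offsets 0 and 8 overlap.
```

Returning `True` on any uncertainty keeps the analysis sound: it can only silence a warning when disjointness is proved.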
+ +Warnings are reported alongside metrics in the next round: + +``` +--- Warnings --- +warn: move I8 after I12 — moved LDS read (I8) after LDS write (I12), same base address +info: move I3 after I7 — address computation moved away from consumer I4 +``` + +## Metrics Collection + +After applying a round of moves, the system runs the downstream WaveASM MLIR +pipeline on a **clone** of the IR and collects metrics: + +```cpp +struct SchedulingMetrics { + int64_t peakVGPRs; + int64_t peakSGPRs; + int64_t peakAGPRs; + int64_t nopsInserted; // from HazardMitigation. + int64_t waitcntsInserted; // from InsertWaitcnt. + int64_t totalInstructions; // post-pipeline instruction count. +}; +``` + +Sources (currently internal to passes, need to be exposed): +- **Peak registers**: `computeMaxPressure()` from `Liveness.h`, or + `AllocationStats` from `LinearScanRegAlloc`. +- **NOP count**: `HazardMitigation` tracks `numNopsInserted` internally. +- **Waitcnt count**: `InsertWaitcnt` tracks tickets internally. +- **Total instructions**: count ops after all passes. + +Additional metrics can be added as needed (e.g. estimated cycle count, LDS bank +conflict potential, critical path length). + +The pipeline runs on a cloned module so the original IR is preserved for the +next round of moves. + +### Optional: GPU Profiling Integration + +When a GPU is available and profiling is enabled, the system can run the +compiled kernel through `rocprof` and feed per-instruction profiling data back +to the LLM. `rocprof` reports hit count and min/max/mean latency for each +assembly instruction. Since the assembly emitter preserves instruction order +and the tags map to assembly locations, profiling data can be correlated back +to the tagged IR. 
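One way to fold such records into the prompt is a tiny parser keyed by tag. A sketch assuming whitespace-separated `key=value` fields; the real rocprof output format differs and would need its own adapter:

```python
def parse_profile(lines):
    """Parse 'I5 ds_read_b128 hits=1024 lat_min=18 ...' records into
    {tag: {field: value}} for prompt formatting. Illustrative only."""
    prof = {}
    for line in lines:
        parts = line.split()
        if not parts or not parts[0].startswith("I"):
            continue
        fields = {"opcode": parts[1]}
        for kv in parts[2:]:
            key, value = kv.split("=")
            fields[key] = int(value)
        prof[parts[0]] = fields
    return prof

sample = ["I5 ds_read_b128 hits=1024 lat_min=18 lat_mean=21 lat_max=34"]
# parse_profile(sample)["I5"]["lat_mean"] -> 21
```

Keying by tag means the same dictionary can annotate either the IR dump or the next round's metrics section.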
+ +The profiling section is appended to the LLM input when available: + +``` +--- Profiling (rocprof, previous run) --- +I0 buffer_load_dwordx4 hits=1024 lat_min=88 lat_mean=102 lat_max=145 +I1 buffer_load_dwordx4 hits=1024 lat_min=90 lat_mean=105 lat_max=148 +I5 ds_read_b128 hits=1024 lat_min=18 lat_mean=21 lat_max=34 +I10 v_mfma_f32_16x16x16 hits=1024 lat_min=32 lat_mean=33 lat_max=35 +I15 s_waitcnt hits=1024 lat_min=0 lat_mean=45 lat_max=120 +``` + +This lets the LLM see actual stall patterns — e.g. a `s_waitcnt` with high +mean latency indicates the preceding loads are not sufficiently hidden. This +data is strictly optional; the system works without it using only the static +pipeline metrics. + +## Pipeline Integration + +``` +TranslateFromMLIR + → ScopedCSE + → Peephole + → MemoryOffsetOpt + → Canonicalizer + ScopedCSE + → [NEW] Conductor (tag + LLM loop + apply final schedule) + → LinearScan + → InsertWaitcnt + → HazardMitigation + → EmitAssembly +``` + +The Conductor pass: +1. Tags all ops with `NameLoc`. +2. Runs baseline metrics (clone → compile → collect). +3. Enters the LLM loop (up to N rounds). +4. Applies the best schedule to the real IR. +5. Hands off to LinearScan. + +## Iterative Loop Detail + +``` +baseline_metrics = compile_and_measure(clone(IR)) +best_metrics = baseline_metrics +best_state = snapshot(IR) + +for round in 1..max_rounds: + text = print_ir(IR) + format_metrics(best_metrics) + format_error(error) + commands = llm.query(text) + + if commands == ["done"]: + break + + error = null + applied = [] + for cmd in commands: + result = validate_and_apply(IR, cmd) + if result.is_error: + error = { applied: applied, failed: cmd, reason: result.message } + restore(IR, best_state) // revert entire round. + break + applied.append(cmd) + + if error: + continue // next round, LLM sees the error report. + + anchor_pseudo_ops(IR) // move pack/extract/const before earliest user. 
+ new_metrics = compile_and_measure(clone(IR)) + + if is_better(new_metrics, best_metrics): + best_metrics = new_metrics + best_state = snapshot(IR) + else: + restore(IR, best_state) // revert if regression. + error = { applied: applied, reason: "round regressed metrics" } + +apply(IR, best_state) +``` + +`is_better()` compares metrics lexicographically: lower peak VGPRs > fewer +waitcnts > fewer nops. This ordering is configurable. + +## Handling Structured Control Flow + +- **`LoopOp` body**: The primary scheduling target. Block arguments are always + available (dominate the entire body). `ConditionOp` is pinned at end. +- **`IfOp`**: Treated as an atomic unit. The LLM can move the entire `IfOp` but + not its internal ops. The tag is on the `IfOp` itself. +- **`ProgramOp` body (prologue)**: Tagged and schedulable as a straight-line + block, but typically less benefit (SRD setup). +- **Nested regions**: Tags are scoped per region. The LLM sees one region at a + time (typically the innermost loop body). + +## Parallel Search with Checkpoints + +Individual LLM calls have high latency but the API supports many concurrent +requests. Instead of running a single sequential loop, launch **P parallel +agents**, each exploring a different scheduling trajectory from the same +starting IR. + +``` + ┌─ Agent 1 ─── round 1 ─── round 2 ─── ... + │ +tagged IR ─┼─ Agent 2 ─── round 1 ─── round 2 ─── ... + │ + ├─ Agent 3 ─── round 1 ─── round 2 ─── ... + │ + └─ Agent P ─── round 1 ─── round 2 ─── ... + │ + checkpoint +``` + +At **checkpoints** (every K rounds), collect metrics from all agents, rank +them, and keep the **top-k**. The surviving agents continue from their current +state; the rest are terminated. New agents can be spawned from the top-k states +to refill the pool, optionally with varied system prompts (e.g. "prioritize +register pressure" vs "prioritize latency hiding"). + +``` +for checkpoint in 1..num_checkpoints: + // all agents run K rounds in parallel. 
+ results = await all_agents(K rounds each) + + // rank by metrics, keep top-k. + ranked = sort(results, by=metrics) + survivors = ranked[:top_k] + + // respawn from survivors to refill pool. + agents = [] + for state in survivors: + agents.append(continue_agent(state)) + agents.append(fork_agent(state, varied_prompt)) // optional diversity. + +best = survivors[0] +``` + +This turns the scheduling search into a **beam search** over move sequences. +The validation + metrics infrastructure is unchanged — each agent is just an +independent instance of the iterative loop. The only coordination is at +checkpoints where we compare metrics and prune. + +Parallelism is free in terms of wall-clock time (bounded by the slowest agent +per checkpoint) and the compile-and-measure step is CPU-local and fast (~ms +per clone). + +## Caching + +Final move sequences are cached by SHA-256 of the **tagged IR text** (before +any moves): + +``` +~/.waveasm-cache/conductor/ + .json → { + moves: ["move I5 after I1", "swap I6 I9"], + metrics: { peakVGPRs: 160, nops: 1, waitcnts: 8 }, + rounds: 3, + timestamp: "..." + } +``` + +On cache hit, the moves are replayed directly (still validated) without querying +the LLM. + +## Key Design Properties + +1. **No abstraction gap.** The LLM sees actual MLIR IR, not a lossy summary. +2. **Zero new languages.** Three move commands, trivial to parse and validate. +3. **Fast feedback.** Invalid moves abort the round immediately with a clear + error report; the LLM learns from mistakes without silent corruption. +4. **Iterative refinement.** The LLM observes metrics after each round and + adjusts, rather than producing a one-shot schedule. +5. **Closed loop with hardware.** Optional `rocprof` integration feeds real + per-instruction latencies back, grounding the LLM in measured data. +6. **Transparent.** Every move is logged. Easy to inspect, replay, and debug. +7. 
**Extensible metrics.** Adding a new metric is just a field in + `SchedulingMetrics` — no protocol changes needed. diff --git a/conductor/README.md b/conductor/README.md new file mode 100644 index 000000000..b9b7338da --- /dev/null +++ b/conductor/README.md @@ -0,0 +1,117 @@ +# Conductor + +LLM-guided instruction scheduling for WaveASM. Takes pre-scheduling WaveASM IR, +lets an LLM reorder instructions via move/swap commands, then compiles and +measures the result. See +[CONDUCTOR_DESIGN.md](../wave_lang/kernel/wave/asm/wave_asm/include/waveasm/CONDUCTOR_DESIGN.md) +for the full design rationale. + +## Pipeline + +``` +Python frontend (GEMM kernel) + → Wave compiler → MLIR + → waveasm-translate (pre-scheduling: CSE, peephole, mem-offset-opt) + → Tag instructions (attach loc("tag_name") to each op) + → LLM loop: present IR + metrics → get move commands → apply → compile → measure + → Final assembly +``` + +## Modules + +| File | Purpose | +|------|---------| +| `extract_ir.py` | Captures MLIR from a Wave kernel, runs pre-scheduling pipeline, computes baseline metrics. | +| `conductor.py` | `Conductor` class: tag, apply_moves, compile_to_asm, evaluate, baseline. CLI entry point. | +| `llm.py` | OpenRouter client, prompt formatting, command parsing, iterative scheduling loop. | + +## Prerequisites + +- `waveasm-translate` and `waveasm-conductor` binaries (build from `wave_lang/kernel/wave/asm/wave_asm/`). +- Python `requests` package. +- `OPENROUTER_API_KEY` env var for LLM mode. + +## Usage + +All commands should be run from the wave project root with `WAVE_CACHE_ON=0`. 
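The three-command grammar accepted by `--moves` is small enough to parse with one regular expression. A sketch (hypothetical helper; `llm.py`'s actual parser may differ):

```python
import re

CMD_RE = re.compile(
    r"^(?:move\s+(\S+)\s+(before|after)\s+(\S+)|swap\s+(\S+)\s+(\S+)|done)$"
)

def parse_command(line: str):
    """Return a tuple form of one move/swap/done command, or raise."""
    m = CMD_RE.match(line.strip())
    if not m:
        raise ValueError(f"unrecognized command: {line!r}")
    if m.group(1):
        return ("move", m.group(1), m.group(2), m.group(3))
    if m.group(4):
        return ("swap", m.group(4), m.group(5))
    return ("done",)

# parse_command("move I5 after I1") -> ("move", "I5", "after", "I1")
```

Rejecting anything outside the grammar up front keeps malformed LLM output from ever reaching the executor.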
+ +### Baseline metrics (no LLM) + +```bash +python -m conductor.conductor --metrics +``` + +### Show tagged IR + +```bash +python -m conductor.conductor --tag-only +``` + +### Apply manual moves + +```bash +python -m conductor.conductor \ + --moves "swap buffer_load_dwordx4_0 v_mfma_f32_16x16x16_f16_0" \ + --metrics +``` + +### Read moves from a file + +```bash +python -m conductor.conductor --moves-file moves.txt --metrics +``` + +### LLM-guided scheduling + +```bash +OPENROUTER_API_KEY=sk-... python -m conductor.conductor \ + --llm --max-rounds 5 --model google/gemini-2.5-flash-preview --metrics +``` + +### Extract pre-scheduling IR only + +```bash +python -m conductor.extract_ir -o /tmp/conductor_ir.mlir +python -m conductor.extract_ir --metrics +``` + +## Move commands + +| Command | Example | +|---------|---------| +| `move TAG before TAG` | `move v_add_u32_1 before v_add_u32_0` | +| `move TAG after TAG` | `move buffer_load_dwordx4_0 after ds_read_b128_2` | +| `swap TAG TAG` | `swap v_mfma_f32_16x16x16_f16_0 ds_read_b128_1` | + +## Environment variables + +| Variable | Default | Purpose | +|----------|---------|---------| +| `OPENROUTER_API_KEY` | (none) | Required for `--llm` mode. | +| `WAVE_CACHE_ON` | `1` | Set to `0` to disable caching during development. | +| `WAVE_DEFAULT_ARCH` | `gfx942` | Target GPU architecture. | +| `WAVEASM_TRANSLATE` | (auto-detect) | Override path to `waveasm-translate` binary. | +| `WAVEASM_CONDUCTOR` | (auto-detect) | Override path to `waveasm-conductor` binary. | + +## Programmatic usage + +```python +from conductor import Conductor +from conductor.extract_ir import capture_kernel_mlir, run_pre_scheduling_pipeline +from conductor.llm import run_scheduling_loop + +mlir_text, wg_size = capture_kernel_mlir() +waveasm_ir = run_pre_scheduling_pipeline(mlir_text, wg_size) + +c = Conductor(waveasm_ir, wg_size) + +# Baseline. +print(c.baseline()) + +# Manual moves. 
+print(c.evaluate(["swap buffer_load_dwordx4_0 v_mfma_f32_16x16x16_f16_0"])) + +# LLM loop. +result = run_scheduling_loop(c, max_rounds=5, model="google/gemini-2.5-flash-preview") +print(result["metrics"], result["commands"]) +``` diff --git a/conductor/SCHEDULING_GUIDE.md b/conductor/SCHEDULING_GUIDE.md new file mode 100644 index 000000000..38c3f9d03 --- /dev/null +++ b/conductor/SCHEDULING_GUIDE.md @@ -0,0 +1,182 @@ +# CDNA4 Instruction Scheduling Guide + +Reference for the Conductor LLM scheduling loop. Based on the AMD Instinct +CDNA4 ISA Reference Guide (gfx950). + +## Architecture Overview + +- Wave size: 64 threads. +- Register files: 256 arch VGPRs (V0-V255) + 256 AccVGPRs/AGPRs (A0-A255) = 512 total per wave. +- SGPRs: 16-102 per wave (VCC occupies SGPR 106-107). +- LDS: 160 KiB per CU, 64 banks x 32-bit, 1280-byte allocation granularity. +- Allocation granularity: VGPRs in groups of 8, SGPRs in groups of 16. +- Issue model: one instruction per cycle per wavefront (latency hiding via interleaving wavefronts). 
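The register-file figures above determine occupancy directly: allocation rounds up to the granularity, and the pool is divided among waves. A sketch consistent with the occupancy table later in this guide:

```python
def waves_per_simd(vgprs_used: int, pool: int = 512, granularity: int = 8) -> int:
    """Waves resident per SIMD given arch-VGPR usage (capped at 8)."""
    if vgprs_used <= 0:
        return 8
    allocated = -(-vgprs_used // granularity) * granularity  # round up
    return min(8, pool // allocated)

# waves_per_simd(64) -> 8, waves_per_simd(65) -> 7 (65 rounds up to 72)
```

The same formula applies independently to the AGPR pool; final occupancy is the minimum across VGPR, AGPR, SGPR, and LDS limits.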
+ +## Instruction Latencies + +| Instruction class | Latency (cycles) | Counter | +|---|---|---| +| Global load (buffer_load, global_load) | ~100 | vmcnt | +| LDS read (ds_read) | ~20 | lgkmcnt | +| LDS write (ds_write) | ~20 | lgkmcnt | +| MFMA F16/BF16 16x16x16 | 16 | — | +| MFMA F16/BF16 32x32x8 | 32 | — | +| MFMA F32 16x16x4 | 32 | — | +| MFMA F32 32x32x2 | 64 | — | +| MFMA F8/F4 16x16x128 | 16 (FP4/FP6) or 32 (FP8) | — | +| MFMA F8/F4 32x32x64 | 32 (FP4/FP6) or 64 (FP8) | — | +| scaled_mfma (MXFP4) | Same as F8/F4 above | — | +| MFMA F64 16x16x4 | 64 | — | +| VALU (non-transcendental) | 1 | — | +| VALU transcendental (exp, log, rcp, rsq, sqrt, sin, cos) | 2 | — | +| SALU | 1 | — | +| Scalar memory read | variable | lgkmcnt | + +## Waitcnt Counters + +| Counter | Bits | Max outstanding | Tracked operations | +|---|---|---|---| +| vmcnt | 6 | 63 | Global/buffer loads and stores | +| lgkmcnt | 4 | 15 | LDS ops, scalar memory, sendmsg | + +- VMEM reads/writes return in issue order. +- Scalar memory reads can return **out of order** — only `lgkmcnt(0)` is safe for SMEM. +- FLAT instructions increment both vmcnt and lgkmcnt — only `s_waitcnt 0` is safe after FLAT. +- `s_endpgm` implicitly executes `s_waitcnt 0`. + +## MFMA Dependency Rules + +### Accumulator chains (SrcC = previous vDst, same opcode, same register range) + +**Zero software NOPs required.** The hardware provides 2 implicit wait cycles. +This is the intended use pattern for matrix accumulation loops — chain MFMAs +back-to-back on the same accumulator with no stall. 
+ +### Cross-dependencies (reading MFMA output as SrcA/SrcB or in VALU) + +These require the full output latency to elapse: + +| Scenario | Wait (NOPs) | +|---|---| +| XDL write → SrcA/B of any MFMA | 5 / 8 / 12 / 20 | +| XDL write → VALU read/write (RAW+WAW) | 5 / 8 / 12 / 20 | +| XDL write → VMEM/LDS/FLAT overlap | 5 / 8 / 12 / 20 | +| SGEMM write → SrcA/B of any MFMA | 4 / 6 / 10 / 18 | +| DGEMM 16x16x4 write → SrcA/B or VALU | 19 | + +The multiple values correspond to different output register overlap depths. + +### Cross-type forwarding + +| Scenario | Wait (NOPs) | +|---|---| +| SGEMM write → XDL SrcC (overlapping) | **0** (XDL reads SrcC 2x faster) | +| XDL write → SGEMM SrcC (overlapping) | 3 | +| v_cmpx (writes EXEC) → any V_MFMA | **4** (no exec forwarding to matrix core) | + +## Software Hazards (Table 11) + +The hardware does **not** detect these. The compiler must insert s_nop or +independent instructions. + +| First instruction | Second instruction | NOPs required | +|---|---|---| +| VALU writes SGPR | VMEM reads that SGPR | **5** | +| VALU sets VCC or EXEC | VALU reads EXECZ/VCCZ as data | **5** | +| VALU writes EXEC | VALU DPP op | **5** | +| VALU writes SGPR/VCC | v_readlane/v_writelane lane select | **4** | +| VALU writes VCC | v_div_fmas | **4** | +| v_cmpx writes EXEC | v_readlane/v_readfirstlane/v_writelane | **4** | +| VALU writes SGPR/VCC | VALU reads SGPR as constant | **2** | +| v_cmpx writes EXEC | VALU reads EXEC as constant | **2** | +| S_SETREG MODE.vskip | Any vector op | **2** | +| Trans op (exp, log, rcp, ...) | Non-trans VALU consuming result | **1** | +| VALU writes VGPR | VALU DPP reads that VGPR | **1** | +| SALU writes M0 | LDS add-TID / buffer_store_LDS | **1** | +| Mixed VCC alias access | VALU reads VCC as constant | **1** | +| FLAT/BUFFER_STORE x3/x4 | Write VGPRs holding writedata | **1** | + +Note: `s_nop N` inserts N+1 idle cycles. So `s_nop 4` = 5 NOPs. 
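Turning a required wait-state count into `s_nop` immediates is a one-liner plus a loop. A sketch; the 16-cycle cap per `s_nop` is an assumption about the encoding, so check the ISA guide for the actual field width:

```python
def nop_sequence(wait_states: int, max_per_nop: int = 16) -> list:
    """Emit s_nop immediates covering `wait_states` idle cycles.

    s_nop N waits N+1 cycles, so each immediate is its chunk's count
    minus one; counts beyond max_per_nop are split across instructions.
    """
    seq = []
    while wait_states > 0:
        chunk = min(wait_states, max_per_nop)
        seq.append(f"s_nop {chunk - 1}")
        wait_states -= chunk
    return seq

# nop_sequence(5) -> ["s_nop 4"], matching the note above.
```

In practice HazardMitigation would prefer filling these slots with independent instructions and only fall back to `s_nop`.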
+ +## Occupancy and Register Pressure + +Waves per SIMD as a function of VGPR usage (512-slot pool, 8-register granularity): + +| VGPRs used | Waves/SIMD | Waves/CU (4 SIMDs) | +|---|---|---| +| ≤64 | 8 | 32 | +| 65-72 | 7 | 28 | +| 73-80 | 6 | 24 | +| 81-96 | 5 | 20 | +| 97-128 | 4 | 16 | +| 129-168 | 3 | 12 | +| 169-256 | 2 | 8 | + +AGPRs use an independent pool with the same formula. Final occupancy is +`min(vgpr_waves, agpr_waves, sgpr_waves, lds_waves)`. + +**Key breakpoints:** Dropping below 128, 96, 80, or 64 VGPRs each adds one +wave per SIMD. For MFMA-dominant kernels, 2-4 waves/SIMD is often optimal. + +## Scheduling Strategy + +### 1. Issue global loads early + +Global loads have ~100 cycle latency. Move them as far above their consumers +as possible. Every MFMA (16-64 cycles) or LDS op placed between the load and +its `s_waitcnt vmcnt` hides latency for free. + +### 2. Interleave LDS reads with MFMA + +LDS reads take ~20 cycles. A single MFMA 16x16x16 (16 cycles) nearly covers +one LDS read. Interleaving `ds_read → mfma → ds_read → mfma` hides LDS +latency much better than `ds_read, ds_read, ..., mfma, mfma, ...`. + +### 3. Fill MFMA pipeline bubbles + +Between two MFMAs with an accumulator dependency (B.SrcC = A.vDst), the +hardware stalls until A completes. Fill this gap with: +- Independent LDS reads/writes (for the next iteration). +- Independent VALU (address computation, data formatting). +- Independent global loads (prefetch for next tile). + +### 4. Minimize live ranges + +Moving a producer closer to its consumer (or deferring a def until just +before use) reduces the number of simultaneously live VGPRs, potentially +lowering peak register pressure and enabling higher occupancy. + +### 5. 
Double-buffering pattern + +The ideal software-pipelined loop body: +``` +barrier +ds_read current tile from LDS +global_load next tile (prefetch) +mfma on current tile (hides global_load latency) +s_waitcnt vmcnt(0) +barrier +ds_write next tile to LDS +``` + +### 6. Metric priorities (lexicographic) + +1. **peak_vgpr** — lower = higher occupancy. +2. **s_waitcnt** — fewer waits = better latency hiding. +3. **s_nop** — fewer NOPs = fewer pipeline hazards. +4. **total_instructions** — fewer = less instruction cache pressure. + +## LDS Details + +- 160 KiB per CU, 64 banks, each 32-bit wide. +- Best case latency (no bank conflicts): 2 cycles. +- Worst case (64-way bank conflict): 64 cycles. +- A wavefront (64 threads) is dispatched over 4 sub-cycles (16 threads each). +- DS_READ2/DS_WRITE2 (64-bit extended) double per-op bandwidth. + +## Barrier Semantics + +- `s_barrier` synchronizes all wavefronts in a workgroup. +- It does NOT implicitly wait for memory — issue `s_waitcnt` before `s_barrier` + if protecting memory operations. +- Treat barriers as hard scheduling boundaries. Never move ops across barriers. diff --git a/conductor/__init__.py b/conductor/__init__.py new file mode 100644 index 000000000..0d106741a --- /dev/null +++ b/conductor/__init__.py @@ -0,0 +1,39 @@ +# Copyright 2025 The Wave Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Conductor: LLM-guided instruction scheduling for WaveASM.""" + +from conductor.conductor import Conductor, find_waveasm_conductor +from conductor.extract_ir import ( + find_waveasm_translate, + run_waveasm_translate, + run_pre_scheduling_pipeline, + run_full_pipeline, + count_asm_metrics, + capture_kernel_mlir, + capture_mxfp4_kernel_mlir, +) +from conductor.llm import run_scheduling_loop +from conductor.providers.openrouter import Stats, Counters +from conductor.tools import Param, ToolDef, ToolRegistry + +__all__ = [ + "Conductor", + "find_waveasm_conductor", + "find_waveasm_translate", + "run_waveasm_translate", + "run_pre_scheduling_pipeline", + "run_full_pipeline", + "count_asm_metrics", + "capture_kernel_mlir", + "capture_mxfp4_kernel_mlir", + "run_scheduling_loop", + "Stats", + "Counters", + "Param", + "ToolDef", + "ToolRegistry", +] diff --git a/conductor/conductor.py b/conductor/conductor.py new file mode 100644 index 000000000..fd190275e --- /dev/null +++ b/conductor/conductor.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python3 +""" +Conductor: Python driver for WaveASM instruction scheduling experiments. + +Ties together extract_ir → tag → apply moves → post-scheduling pipeline → metrics. + +Usage: + # Baseline (no moves): + python -m conductor.conductor --metrics + + # Apply specific moves: + python -m conductor.conductor \ + --moves "swap buffer_load_dwordx4_0 v_mfma_f32_16x16x16_f16_0" --metrics + + # Show tagged IR only: + python -m conductor.conductor --tag-only + + # Read moves from a file: + python -m conductor.conductor --moves-file moves.txt --metrics +""" + +import argparse +import os +import subprocess +import sys +import tempfile +from pathlib import Path + +# Add wave_lang to path. 
+wave_root = Path(__file__).parent.parent +sys.path.insert(0, str(wave_root)) + +from conductor.extract_ir import ( + run_waveasm_translate, + count_asm_metrics, + capture_kernel_mlir, + run_pre_scheduling_pipeline, +) + + +def find_waveasm_conductor() -> str: + """Find the waveasm-conductor binary.""" + env_path = os.environ.get("WAVEASM_CONDUCTOR") + if env_path and Path(env_path).exists(): + return env_path + + candidates = [ + wave_root + / "wave_lang" + / "kernel" + / "wave" + / "asm" + / "wave_asm" + / "build" + / "tools" + / "waveasm-conductor" + / "waveasm-conductor", + wave_root + / "wave_lang" + / "kernel" + / "wave" + / "asm" + / "wave_asm" + / "build" + / "bin" + / "waveasm-conductor", + ] + for p in candidates: + if p.exists(): + return str(p) + + raise FileNotFoundError( + "waveasm-conductor not found. Set WAVEASM_CONDUCTOR env var or build it." + ) + + +class Conductor: + """Encapsulates the full round-trip for instruction scheduling experiments.""" + + def __init__(self, waveasm_ir: str, workgroup_size: tuple, target: str = "gfx942"): + self.waveasm_ir = waveasm_ir + self.workgroup_size = workgroup_size + self.target = target + self._baseline_cache = None + + def tag(self) -> str: + """Run tag-instructions on the IR. Return tagged IR text.""" + flags = [ + "--waveasm-tag-instructions", + "--print-debug-locs-inline", + "--use-nameloc-as-prefix", + ] + stdout, stderr, rc = run_waveasm_translate( + self.waveasm_ir, self.workgroup_size, flags + ) + if rc != 0: + raise RuntimeError(f"tag-instructions failed:\n{stderr}") + return stdout + + def apply_moves(self, tagged_ir: str, commands: list) -> str: + """Apply CONDUCTOR move commands to tagged IR. Return reordered IR.""" + conductor = find_waveasm_conductor() + + # Prepend CONDUCTOR commands to the tagged IR. 
+ header = "\n".join(f"// CONDUCTOR: {cmd}" for cmd in commands) + full_input = header + "\n\n" + tagged_ir + + with tempfile.NamedTemporaryFile(mode="w", suffix=".mlir", delete=False) as f: + f.write(full_input) + input_path = f.name + + try: + cmd = [ + conductor, + "--print-debug-locs-inline", + "--use-nameloc-as-prefix", + input_path, + ] + result = subprocess.run(cmd, capture_output=True, text=True, timeout=60) + if result.returncode != 0: + raise RuntimeError( + f"waveasm-conductor failed (rc={result.returncode}):\n{result.stderr}" + ) + return result.stdout + finally: + os.unlink(input_path) + + def compile_to_asm(self, ir: str) -> str: + """Run post-scheduling pipeline on already-scheduled WaveASM IR.""" + flags = [ + "--waveasm-linear-scan", + "--max-vgprs=512", + "--max-agprs=512", + "--waveasm-insert-waitcnt", + "--waveasm-hazard-mitigation", + "--emit-assembly", + ] + stdout, stderr, rc = run_waveasm_translate(ir, self.workgroup_size, flags) + if rc != 0: + raise RuntimeError(f"compile_to_asm failed:\n{stderr}") + return stdout + + def get_metrics(self, asm: str) -> dict: + """Extract metrics from assembly text.""" + return count_asm_metrics(asm) + + def evaluate(self, commands: list) -> dict: + """ + Full round-trip: tag → apply_moves → compile_to_asm → get_metrics. + + This is the main entry point for a search algorithm. + """ + _, metrics = self.evaluate_with_ir(commands) + return metrics + + def evaluate_with_ir(self, commands: list) -> tuple[str, dict]: + """Like evaluate, but also returns the reordered tagged IR.""" + tagged = self.tag() + reordered = self.apply_moves(tagged, commands) + asm = self.compile_to_asm(reordered) + return reordered, self.get_metrics(asm) + + def baseline(self) -> dict: + """Evaluate with no moves (identity schedule). 
Caches result.""" + if self._baseline_cache is None: + asm = self.compile_to_asm(self.waveasm_ir) + self._baseline_cache = self.get_metrics(asm) + return self._baseline_cache + + +def main(): + parser = argparse.ArgumentParser( + description="Conductor: WaveASM instruction scheduling driver." + ) + parser.add_argument( + "--moves", + nargs="*", + default=None, + help="Move commands (e.g. 'swap A B', 'move X after Y').", + ) + parser.add_argument( + "--moves-file", + type=str, + default=None, + help="Read move commands from a file (one per line).", + ) + parser.add_argument( + "--metrics", + action="store_true", + help="Print assembly metrics after scheduling.", + ) + parser.add_argument( + "--tag-only", + action="store_true", + help="Only show tagged IR, then exit.", + ) + parser.add_argument( + "--llm", + action="store_true", + help="Run the LLM-guided scheduling loop.", + ) + parser.add_argument( + "--model", + type=str, + default=None, + help="OpenRouter model ID for --llm mode.", + ) + parser.add_argument( + "--max-rounds", + type=int, + default=5, + help="Maximum LLM scheduling rounds (default: 5).", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.7, + help="LLM sampling temperature (default: 0.7).", + ) + parser.add_argument( + "--reasoning-effort", + type=str, + default="low", + help="Reasoning effort for models that support it (default: low).", + ) + parser.add_argument( + "--provider", + type=str, + default="openrouter", + choices=["openrouter", "cursor"], + help="LLM provider (default: openrouter).", + ) + parser.add_argument( + "--kernel", + type=str, + default="gemm", + choices=["gemm", "mxfp4"], + help="Kernel to capture (default: gemm).", + ) + args = parser.parse_args() + + # Collect commands from both sources. 
+ commands = [] + if args.moves: + commands.extend(args.moves) + if args.moves_file: + commands.extend( + line.strip() + for line in Path(args.moves_file).read_text().splitlines() + if line.strip() and not line.strip().startswith("#") + ) + + from conductor.extract_ir import capture_mxfp4_kernel_mlir + + capture_fn = ( + capture_mxfp4_kernel_mlir if args.kernel == "mxfp4" else capture_kernel_mlir + ) + print(f"Capturing {args.kernel} kernel MLIR...", file=sys.stderr) + mlir_text, wg_size = capture_fn() + print(f" workgroup_size: {wg_size}", file=sys.stderr) + + print("Running pre-scheduling pipeline...", file=sys.stderr) + waveasm_ir = run_pre_scheduling_pipeline(mlir_text, wg_size) + print(f" WaveASM IR: {len(waveasm_ir)} chars", file=sys.stderr) + + from conductor.extract_ir import get_target + + conductor = Conductor(waveasm_ir, wg_size, target=get_target()) + + if args.tag_only: + print(conductor.tag()) + return + + if args.llm: + from conductor.llm import run_scheduling_loop + + print( + f"Running LLM scheduling loop (provider={args.provider})...", + file=sys.stderr, + ) + result = run_scheduling_loop( + conductor, + max_rounds=args.max_rounds, + model=args.model, + temperature=args.temperature, + reasoning_effort=args.reasoning_effort, + provider=args.provider, + ) + print("\n=== LLM Scheduling Result ===") + print(f" rounds: {result['rounds']}") + print(f" commands: {result['commands']}") + print(" baseline:") + for k, v in result["baseline_metrics"].items(): + print(f" {k}: {v}") + print(" best:") + for k, v in result["metrics"].items(): + print(f" {k}: {v}") + usage = result.get("usage") + if usage: + print( + f" tokens: {usage.tokens} (in={usage.input_tokens} out={usage.output_tokens})" + ) + print(f" cost: ${usage.cost:.4f}") + return + + if commands: + print(f"Applying {len(commands)} move(s)...", file=sys.stderr) + metrics = conductor.evaluate(commands) + else: + print("No moves specified, running baseline...", file=sys.stderr) + metrics = 
conductor.baseline() + + if args.metrics: + print("\n=== Metrics ===") + for k, v in metrics.items(): + print(f" {k}: {v}") + + +if __name__ == "__main__": + main() diff --git a/conductor/extract_ir.py b/conductor/extract_ir.py new file mode 100644 index 000000000..53c073f30 --- /dev/null +++ b/conductor/extract_ir.py @@ -0,0 +1,348 @@ +#!/usr/bin/env python3 +""" +Extract WaveASM MLIR IR at the Conductor scheduling stage. + +Defines a GEMM kernel inline, runs the Wave compilation pipeline to produce +input MLIR, then runs waveasm-translate with only the pre-scheduling passes +(CSE, peephole, memory-offset-opt). The output is the WaveASM IR that the +Conductor would see — post-optimization, pre-register-allocation. + +Usage: + python -m conductor.extract_ir -o /tmp/conductor_ir.mlir + python -m conductor.extract_ir --metrics + python -m conductor.extract_ir --dump-input-mlir + +Requires: + - waveasm-translate binary (auto-detected from build/) + - WAVE_CACHE_ON=0 recommended +""" + +import argparse +import os +import subprocess +import sys +import tempfile +from pathlib import Path + +# Add wave_lang to path. +wave_root = Path(__file__).parent.parent +sys.path.insert(0, str(wave_root)) + + +def find_waveasm_translate() -> str: + """Find the waveasm-translate binary.""" + env_path = os.environ.get("WAVEASM_TRANSLATE") + if env_path and Path(env_path).exists(): + return env_path + + candidates = [ + wave_root + / "wave_lang" + / "kernel" + / "wave" + / "asm" + / "wave_asm" + / "build" + / "tools" + / "waveasm-translate" + / "waveasm-translate", + wave_root + / "wave_lang" + / "kernel" + / "wave" + / "asm" + / "wave_asm" + / "build" + / "bin" + / "waveasm-translate", + ] + for p in candidates: + if p.exists(): + return str(p) + + raise FileNotFoundError( + "waveasm-translate not found. Set WAVEASM_TRANSLATE env var or build it." 
+ ) + + +def get_target() -> str: + """Get the target architecture.""" + return os.environ.get("WAVE_DEFAULT_ARCH", "gfx942") + + +def capture_mxfp4_kernel_mlir() -> tuple: + """Capture MLIR from a double-buffered 4-wave MXFP4 GEMM kernel. + + Returns (mlir_text, workgroup_size). + """ + from wave_lang.kernel.wave.templates import get_tagged_mxfp4_gemm_preshuffle_b + from wave_lang.kernel.wave.schedules import get_mxfp4_asymmetric_schedule + from wave_lang.kernel.wave.utils.run_utils import set_default_run_config + + # Same config as test_dbuf_4wave_mxfp4_gemm_cpp_backend. + gemm, options = get_tagged_mxfp4_gemm_preshuffle_b( + shape=(1024, 1024, 8192), + block_shape=(256, 256, 256), + wave_shape=(1, 4), + ) + schedule = get_mxfp4_asymmetric_schedule() + + options.backend = "asm" + options.wave_runtime = True + options.compile_to_mlir = False + options = set_default_run_config(options) + + # Reuse the test helper that handles opsel + canonicalize. + sys.path.insert( + 0, + str( + wave_root + / "wave_lang" + / "kernel" + / "wave" + / "asm" + / "wave_asm" + / "test" + / "e2e" + ), + ) + from waveasm_e2e import capture_wave_kernel_info + + info = capture_wave_kernel_info(options, gemm, schedule=schedule) + return info.mlir_text, info.workgroup_size + + +def capture_kernel_mlir() -> tuple: + """Capture MLIR from a manually-scheduled 8-wave GEMM kernel. + + Uses get_tagged_gemm + get_two_pp_cluster_schedule for a 2-stage + pipelined prefetch with cluster reordering and wave staggering. + + Returns (mlir_text, workgroup_size). 
+ """ + from wave_lang.kernel.wave.schedules import get_two_pp_cluster_schedule + from wave_lang.kernel.wave.schedules.gemm_two_pp_cluster import get_tagged_gemm + from wave_lang.kernel.wave.utils.run_utils import set_default_run_config + + gemm, options = get_tagged_gemm( + shape=(4096, 4096, 4096), + block_shape=(128, 256, 64), + ) + schedule = get_two_pp_cluster_schedule() + + options.backend = "asm" + options.wave_runtime = True + options.compile_to_mlir = False + options = set_default_run_config(options) + + sys.path.insert( + 0, + str( + wave_root + / "wave_lang" + / "kernel" + / "wave" + / "asm" + / "wave_asm" + / "test" + / "e2e" + ), + ) + from waveasm_e2e import capture_wave_kernel_info + + info = capture_wave_kernel_info(options, gemm, schedule=schedule) + return info.mlir_text, info.workgroup_size + + +def run_waveasm_translate( + mlir_text: str, workgroup_size: tuple, extra_flags: list = None +) -> tuple: + """ + Run waveasm-translate with given flags. + + Returns (stdout, stderr, returncode). + """ + translate = find_waveasm_translate() + target = get_target() + + with tempfile.NamedTemporaryFile(mode="w", suffix=".mlir", delete=False) as f: + f.write(mlir_text) + input_path = f.name + + try: + cmd = [ + translate, + f"--target={target}", + f"--workgroup-size-x={workgroup_size[0]}", + f"--workgroup-size-y={workgroup_size[1]}", + f"--workgroup-size-z={workgroup_size[2]}", + ] + if extra_flags: + cmd.extend(extra_flags) + cmd.append(input_path) + + result = subprocess.run(cmd, capture_output=True, text=True, timeout=60) + return result.stdout, result.stderr, result.returncode + finally: + os.unlink(input_path) + + +def run_pre_scheduling_pipeline(mlir_text: str, workgroup_size: tuple) -> str: + """ + Run waveasm-translate with only pre-scheduling passes. 
+ + Produces WaveASM IR after: + TranslateFromMLIR -> ScopedCSE -> Peephole -> MemoryOffsetOpt -> Canonicalizer -> ScopedCSE + + But before: + LinearScan, InsertWaitcnt, HazardMitigation, EmitAssembly + """ + flags = [ + "--mlir-cse", + "--waveasm-scoped-cse", + "--waveasm-peephole", + "--waveasm-memory-offset-opt", + # Stop here — no regalloc, no waitcnt, no hazard, no assembly. + ] + stdout, stderr, rc = run_waveasm_translate(mlir_text, workgroup_size, flags) + if rc != 0: + print(f"waveasm-translate (pre-scheduling) failed:\n{stderr}", file=sys.stderr) + sys.exit(1) + return stdout + + +def run_full_pipeline(mlir_text: str, workgroup_size: tuple) -> str: + """Run the full pipeline and return assembly text.""" + flags = [ + "--mlir-cse", + "--waveasm-scoped-cse", + "--waveasm-peephole", + "--waveasm-memory-offset-opt", + "--waveasm-linear-scan", + "--max-vgprs=512", + "--max-agprs=512", + "--waveasm-insert-waitcnt", + "--waveasm-hazard-mitigation", + "--emit-assembly", + ] + stdout, stderr, rc = run_waveasm_translate(mlir_text, workgroup_size, flags) + if rc != 0: + print(f"Full pipeline failed:\n{stderr}", file=sys.stderr) + return "" + return stdout + + +def count_asm_metrics(asm_text: str) -> dict: + """Extract basic metrics from assembly text.""" + lines = asm_text.split("\n") + metrics = { + "total_instructions": 0, + "s_waitcnt": 0, + "s_nop": 0, + "mfma": 0, + "buffer_load": 0, + "ds_read": 0, + "ds_write": 0, + } + for line in lines: + stripped = line.strip() + if not stripped or stripped.startswith(("//", ";", ".")): + continue + if stripped.endswith(":"): + continue + metrics["total_instructions"] += 1 + if "s_waitcnt" in stripped: + metrics["s_waitcnt"] += 1 + if "s_nop" in stripped: + metrics["s_nop"] += 1 + if "mfma" in stripped: + metrics["mfma"] += 1 + if "buffer_load" in stripped: + metrics["buffer_load"] += 1 + if "ds_read" in stripped: + metrics["ds_read"] += 1 + if "ds_write" in stripped: + metrics["ds_write"] += 1 + + # Extract register counts 
from kernel descriptor. + for line in lines: + if ".amdhsa_next_free_vgpr" in line: + metrics["peak_vgpr"] = int(line.split()[-1]) + if ".amdhsa_next_free_sgpr" in line: + metrics["peak_sgpr"] = int(line.split()[-1]) + + return metrics + + +def main(): + parser = argparse.ArgumentParser( + description="Extract WaveASM IR at the Conductor scheduling stage." + ) + parser.add_argument( + "--dump-input-mlir", + action="store_true", + help="Also dump the input MLIR (before waveasm-translate).", + ) + parser.add_argument( + "--output", + "-o", + type=str, + default=None, + help="Write WaveASM IR to file instead of stdout.", + ) + parser.add_argument( + "--metrics", + action="store_true", + help="Also run full pipeline and print baseline metrics.", + ) + parser.add_argument( + "--kernel", + type=str, + default="gemm", + choices=["gemm", "mxfp4"], + help="Kernel to capture (default: gemm).", + ) + args = parser.parse_args() + + capture_fn = ( + capture_mxfp4_kernel_mlir if args.kernel == "mxfp4" else capture_kernel_mlir + ) + print(f"Capturing {args.kernel} kernel MLIR...", file=sys.stderr) + mlir_text, wg_size = capture_fn() + print(f" workgroup_size: {wg_size}", file=sys.stderr) + print(f" input MLIR: {len(mlir_text)} chars", file=sys.stderr) + + if args.dump_input_mlir: + print("=== Input MLIR (before waveasm-translate) ===") + print(mlir_text) + print("=== End Input MLIR ===\n") + + print("Running pre-scheduling pipeline...", file=sys.stderr) + waveasm_ir = run_pre_scheduling_pipeline(mlir_text, wg_size) + print(f" WaveASM IR: {len(waveasm_ir)} chars", file=sys.stderr) + + op_count = sum( + 1 + for line in waveasm_ir.split("\n") + if "waveasm." 
in line and not line.strip().startswith("//") + ) + print(f" WaveASM ops: {op_count}", file=sys.stderr) + + if args.output: + Path(args.output).write_text(waveasm_ir) + print(f" Written to: {args.output}", file=sys.stderr) + else: + print(waveasm_ir) + + if args.metrics: + print("\nRunning full pipeline for baseline metrics...", file=sys.stderr) + asm_text = run_full_pipeline(mlir_text, wg_size) + if asm_text: + metrics = count_asm_metrics(asm_text) + print("\n=== Baseline Metrics ===") + for k, v in metrics.items(): + print(f" {k}: {v}") + + +if __name__ == "__main__": + main() diff --git a/conductor/llm.py b/conductor/llm.py new file mode 100644 index 000000000..8501e34f4 --- /dev/null +++ b/conductor/llm.py @@ -0,0 +1,337 @@ +"""Iterative LLM scheduling loop for Conductor.""" + +import difflib +import json +import sys +from collections.abc import Callable +from contextlib import nullcontext + +from conductor.providers.openrouter import ( + Counters, + Message, + Stats, + chat as openrouter_chat, +) +from conductor.tools import Param, ToolRegistry + + +def _default_log(msg: str) -> None: + """Default logger: print to stderr without trailing newline.""" + print(msg, file=sys.stderr, end="", flush=True) + + +SYSTEM_PROMPT = """\ +You are an expert GPU instruction scheduler for AMD CDNA architectures. + +You will receive WaveASM MLIR IR with tagged instructions (loc("tag_name")). +Your job is to reorder instructions to hide memory latency and reduce register pressure. + +Latencies (cycles): + global_load/buffer_load: ~100 LDS (ds_read/ds_write): ~20 + MFMA F16 16x16x16: 16 MFMA F16 32x32x8: 32 + MFMA F8/F4 16x16x128: 16-32 MFMA F8/F4 32x32x64: 32-64 + scaled_mfma (MXFP4): same as above + VALU: 1 (transcendentals: 2) SALU: 1 + +Scheduling strategy: +- Issue global loads as early as possible, defer s_waitcnt to just before use. +- Interleave LDS reads with MFMA: ~20 cycles of MFMA hides one LDS read. 
+- Fill cycles between dependent MFMAs with independent loads or VALU. +- MFMA accumulator chains (same opcode, SrcC=prev vDst) need 0 extra NOPs. +- Fewer s_waitcnt and s_nop in the final assembly = better schedule. +- Lower peak VGPRs = higher occupancy (key breakpoints: 128, 96, 80, 64). + +You have an `evaluate_moves` tool. Call it with a list of move command strings. +Each command is one of: + "move TAG_A after TAG_B" + "move TAG_A before TAG_B" + "swap TAG_A TAG_B" + +The tool will apply the moves, compile, and return metrics + updated IR. + +Constraints: +- All moves must stay within the same basic block. Never move across blocks. +- Moves must preserve SSA dominance: a value must be defined before all its uses. + Before proposing any move, trace the SSA def-use chains for the instruction + you want to move. Check: (1) every operand (%val) it consumes is still + defined above it after the move, and (2) every result it produces is still + above all its users after the move. If either check fails, the move is + illegal — pick a different one. +- Pinned ops (s_endpgm, s_barrier, condition) cannot be moved. + +Work incrementally: try 1-3 moves per tool call, read the resulting metrics, \ +then decide your next moves. You can call the tool multiple times. + +DO NOT try to build too long of a plan at once. Try 1-3 moves and see what changes. + +Keep your reasoning brief. For each thinking step, write at most one short \ +sentence. Do not re-analyze the entire IR — focus only on the specific move \ +you are considering and its immediate SSA neighbors. 
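For example, given this def-use chain (tags and types illustrative, other types elided):

```mlir
%3 = waveasm.ds_read_b64 %addr : !waveasm.vreg -> !waveasm.vreg<2> loc("I3")
%7 = waveasm.v_mfma_f32_16x16x16_f16 %3, %b, %acc : ... loc("I7")
```

"move I3 after I7" is illegal: %3 would then be defined below its use in I7.
"move I7 before I3" is illegal for the same reason. Moving I3 earlier, or moving
an instruction independent of both in between them, is legal.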
+ +When you are satisfied with the schedule or have no more ideas, call the `done` \ +tool to finish.\ +""" + + +def format_initial_prompt( + tagged_ir: str, + baseline_metrics: dict, + target: str = "gfx942", +) -> str: + """Format the initial user message with IR and baseline metrics.""" + parts = [ + f"TARGET: {target} (wave64, 256 vgpr + 256 agpr, 102 sgpr)", + "", + "--- Tagged IR ---", + tagged_ir.strip(), + "", + "--- Baseline Metrics ---", + ] + for k, v in baseline_metrics.items(): + parts.append(f" {k}: {v}") + parts.extend( + [ + "", + "GOAL: Reduce s_waitcnt and s_nop count (better latency hiding),", + "then reduce peak VGPRs (higher occupancy).", + "Use the evaluate_moves tool to try reorderings.", + ] + ) + return "\n".join(parts) + + +_NUDGE_NATIVE = ( + "You must use the evaluate_moves tool to test " + "scheduling ideas, or call done when finished. " + "Do not write tool calls as text." +) +_NUDGE_TEXT = ( + "You must output a ```json block to evaluate moves or call done. " + "Use the format described in the OUTPUT FORMAT section." +) + + +def run_scheduling_loop( + conductor, + max_rounds: int = 10, + model: str | None = None, + temperature: float = 0.7, + reasoning_effort: str | None = "high", + provider: str = "openrouter", + log: Callable[[str], None] = _default_log, +) -> dict: + """ + Run the iterative LLM scheduling loop with tool use. + + The LLM reasons in natural language and calls evaluate_moves to test + scheduling ideas. Conversation history (including reasoning) is preserved + across rounds. + + Args: + provider: "openrouter" (default) or "cursor". + + Returns dict with keys: metrics, commands, rounds, baseline_metrics, usage. + """ + # Provider setup. 
+ if provider == "cursor": + from conductor.providers import cursor_agent + + chat_fn = cursor_agent.chat + system_prompt = SYSTEM_PROMPT + cursor_agent.TOOL_CALL_FORMAT + nudge_msg = _NUDGE_TEXT + else: + chat_fn = openrouter_chat + system_prompt = SYSTEM_PROMPT + nudge_msg = _NUDGE_NATIVE + + log(f"Provider: {provider}, model: {model}\n") + log("Computing baseline metrics...\n") + baseline = conductor.baseline() + log(f" baseline: {baseline}\n") + + initial_ir = conductor.tag() + current_ir = initial_ir + log(f" --- Tagged IR ---\n{current_ir.strip()}\n --- End IR ---\n") + + best_metrics = dict(baseline) + all_commands: list[str] = [] + finished = False + + # Build tool registry with closures over loop state. + registry = ToolRegistry() + + def _evaluate_moves(moves: list[str], summary: str = "") -> str: + nonlocal best_metrics, all_commands, current_ir + if summary: + log(f" [reason] {summary}\n") + log(f" [tool] evaluate_moves({moves})\n") + try: + reordered_ir = conductor.apply_moves(current_ir, moves) + except RuntimeError as e: + log(f" [error] {e}\n") + return json.dumps({"error": str(e)}) + try: + asm = conductor.compile_to_asm(reordered_ir) + metrics = conductor.get_metrics(asm) + except RuntimeError as e: + log(f" [error] {e}\n") + return json.dumps({"error": str(e)}) + log(f" [result] {metrics}\n") + # Moves succeeded — update state. 
+ current_ir = reordered_ir + all_commands.extend(moves) + improved = _is_better(metrics, best_metrics) + if improved: + log(" Improvement!\n") + best_metrics = metrics + result = { + "metrics": metrics, + "improved": improved or metrics == best_metrics, + "updated_ir": reordered_ir.strip(), + } + if summary: + result["summary"] = summary + return json.dumps(result) + + def _done(summary: str) -> str: + nonlocal finished + log(f" [done] {summary}\n") + finished = True + return json.dumps({"status": "ok"}) + + registry.add( + name="evaluate_moves", + description=( + "Apply a list of move/swap commands to the tagged IR, " + "compile through the post-scheduling pipeline, and return " + "assembly metrics. Commands are applied in order." + ), + params=[ + Param( + name="moves", + description=( + "List of move commands, e.g. " + '["move tag_A after tag_B", "swap tag_C tag_D"].' + ), + type="array", + items={"type": "string"}, + ), + Param( + name="summary", + description=( + "Brief explanation of why these moves should help " + "(e.g. 'hide LDS latency by interleaving ds_read with mfma')." + ), + ), + ], + func=_evaluate_moves, + ) + registry.add( + name="done", + description=( + "Call this when you are finished scheduling. " + "Provide a short summary of what you tried." 
+ ), + params=[ + Param(name="summary", description="Brief summary of scheduling attempts."), + ], + func=_done, + ) + + messages: list[Message] = [ + {"role": "system", "content": system_prompt}, + { + "role": "user", + "content": format_initial_prompt( + current_ir, baseline, target=conductor.target + ), + }, + ] + + use_native_tools = provider != "cursor" + stats_ctx = Stats() if use_native_tools else nullcontext(None) + session = None # Opaque handle + + with stats_ctx as stats: + for round_num in range(1, max_rounds + 1): + log(f"\n--- Round {round_num}/{max_rounds} ---\n") + + response = chat_fn( + messages, + model=model, + temperature=temperature, + reasoning_effort=reasoning_effort, + tools=registry.definitions() if use_native_tools else None, + log=log, + session=session, + ) + session = response.pop("session", session) + messages.append(response) + + content = response.get("content", "") + if content: + log(f" [model] {content}\n") + + tool_calls = response.get("tool_calls") + if not tool_calls: + log(" [retry] No tool call, nudging model...\n") + messages.append({"role": "user", "content": nudge_msg}) + continue + + for tc in tool_calls: + result = registry.execute( + tc["function"]["name"], + tc["function"]["arguments"], + ) + messages.append( + { + "role": "tool", + "tool_call_id": tc["id"], + "content": result, + } + ) + + if finished: + break + + usage: Counters | None = stats.counters if stats else None + if usage: + log( + f"\n=== Usage ===\n" + f" tokens: {usage.tokens} (in={usage.input_tokens} out={usage.output_tokens})\n" + f" cost: ${usage.cost:.4f}\n" + ) + + ir_diff = _context_diff(initial_ir, current_ir) + if ir_diff: + log(f"\n=== IR Diff (original → final) ===\n{ir_diff}\n") + else: + log("\n=== No IR changes ===\n") + + return { + "metrics": best_metrics, + "commands": all_commands, + "rounds": round_num, + "baseline_metrics": baseline, + "usage": usage, + } + + +def _context_diff(before: str, after: str, n: int = 10) -> str: + """Return 
a unified diff of changed lines with ±n lines of context.""" + a = before.splitlines(keepends=True) + b = after.splitlines(keepends=True) + diff = difflib.unified_diff(a, b, fromfile="before", tofile="after", n=n) + return "".join(diff) + + +def _is_better(new: dict, old: dict) -> bool: + """Compare metrics: lower VGPRs > fewer waitcnts > fewer nops > fewer instructions.""" + for key in ("peak_vgpr", "s_waitcnt", "s_nop", "total_instructions"): + nv = new.get(key, 0) + ov = old.get(key, 0) + if nv < ov: + return True + if nv > ov: + return False + return False diff --git a/conductor/providers/__init__.py b/conductor/providers/__init__.py new file mode 100644 index 000000000..423b576ea --- /dev/null +++ b/conductor/providers/__init__.py @@ -0,0 +1,7 @@ +# Copyright 2025 The Wave Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""LLM provider backends for Conductor.""" diff --git a/conductor/providers/cursor_agent.py b/conductor/providers/cursor_agent.py new file mode 100644 index 000000000..0ce879eb9 --- /dev/null +++ b/conductor/providers/cursor_agent.py @@ -0,0 +1,214 @@ +"""Cursor Agent CLI provider for Conductor. + +Uses ``cursor-agent --print`` as a local LLM backend. Multi-turn +conversations are maintained via ``--resume SESSION_ID``. + +The model does not receive native tool schemas. Instead, the system +prompt instructs it to output fenced JSON blocks which we parse into +synthetic ``tool_calls`` compatible with the OpenRouter provider +interface. +""" + +import json +import re +import shutil +import subprocess +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +Message = dict[str, Any] +DEFAULT_MODEL = "sonnet-4.6" + +# Appended to the system prompt so the model outputs parseable JSON +# instead of relying on native function-calling. 
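The `_is_better` comparison above is lexicographic: the first metric key (in priority order) that differs decides, ties fall through to the next key, and a full tie counts as no improvement. Reimplemented in isolation for illustration:

```python
# Priority order mirrors the scheduling goals: occupancy first,
# then latency hiding, then hazard padding, then code size.
_PRIORITY = ("peak_vgpr", "s_waitcnt", "s_nop", "total_instructions")


def is_better(new: dict, old: dict) -> bool:
    """Lexicographic metric comparison; missing keys default to 0."""
    for key in _PRIORITY:
        nv, ov = new.get(key, 0), old.get(key, 0)
        if nv != ov:
            return nv < ov  # First differing key decides.
    return False  # Full tie is not an improvement.
```

Note the consequence: a regression in a higher-priority key (e.g. peak VGPRs) is never traded away for any improvement in a lower-priority one.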
+TOOL_CALL_FORMAT = """ + +OUTPUT FORMAT — you do NOT have native function-calling tools. +Instead, to invoke a tool output EXACTLY ONE fenced JSON block per response: + +To evaluate moves: +```json +{"action": "evaluate_moves", "moves": ["move TAG_A after TAG_B"], "summary": "brief reason"} +``` + +To finish: +```json +{"action": "done", "summary": "what you tried and the outcome"} +``` + +Rules: +- One JSON block per response, then STOP and wait for the result. +- Do NOT use Shell, Read, Write, Grep, or any other built-in tools. +- Do NOT output additional commentary after the JSON block. +""" + + +def _find_binary() -> str: + """Locate the cursor-agent binary.""" + path = shutil.which("cursor-agent") + if path: + return path + raise FileNotFoundError( + "cursor-agent not found in PATH. " + "Install: curl https://cursor.com/install -fsS | bash" + ) + + +def _build_text(messages: list[Message], start: int) -> str: + """Format messages[start:] as plain text for cursor-agent. + + Assistant messages are skipped (already in the session history). + Tool results are prefixed with ``[Tool Result]``. + """ + parts: list[str] = [] + for msg in messages[start:]: + role = msg["role"] + content = msg.get("content", "") + if role == "assistant": + continue + if role == "tool": + parts.append(f"[Tool Result]\n{content}") + else: + parts.append(content) + return "\n\n".join(parts) + + +def _parse_tool_calls(text: str) -> list[dict] | None: + """Extract tool calls from fenced JSON blocks in model output.""" + # Fenced ```json ... ``` blocks. + pattern = r"```(?:json)?\s*\n(\{[^`]*?\"action\"[^`]*?\})\s*\n```" + matches = re.findall(pattern, text, re.DOTALL) + if not matches: + # Bare JSON on its own line. 
+ matches = re.findall(r'^(\{"action"\s*:.+\})$', text, re.MULTILINE) + tool_calls: list[dict] = [] + for i, raw in enumerate(matches): + try: + obj = json.loads(raw) + except json.JSONDecodeError: + continue + action = obj.pop("action", None) + if not action: + continue + tool_calls.append( + { + "id": f"cursor_{i}", + "type": "function", + "function": { + "name": action, + "arguments": json.dumps(obj), + }, + } + ) + return tool_calls or None + + +def _run( + prompt: str, + model: str, + session_id: str | None, + log: Callable[[str], None], +) -> tuple[str, str]: + """Invoke cursor-agent in headless mode. Returns (session_id, content).""" + cmd = [ + _find_binary(), + "--print", + "--output-format", + "stream-json", + "--model", + model, + "--force", + "--trust", + ] + if session_id: + cmd.extend(["--resume", session_id]) + + log( + f" [cmd] cursor-agent --model {model} {'--resume ' + session_id if session_id else '--new'} ({len(prompt)} chars)\n" + ) + # Pipe prompt via stdin to avoid OS argument length limits. + proc = subprocess.run( + cmd, input=prompt, capture_output=True, text=True, timeout=300 + ) + + sid = session_id or "" + content = "" + for line in proc.stdout.strip().split("\n"): + line = line.strip() + if not line: + continue + try: + ev = json.loads(line) + except json.JSONDecodeError: + continue + etype = ev.get("type") + if etype == "system" and ev.get("subtype") == "init": + sid = ev.get("session_id", sid) + elif etype == "assistant": + # Last assistant message wins. 
+ for block in ev.get("message", {}).get("content", []): + if isinstance(block, dict): + content = block.get("text", content) + elif isinstance(block, str): + content = block + elif etype == "result": + if event_is_error(ev): + log(f" [cursor-agent error] {ev.get('result', '')}\n") + if not content: + content = ev.get("result", "") + return sid, content + + +def event_is_error(ev: dict) -> bool: + """Check if a stream-json event signals an error.""" + return ev.get("is_error", False) or ev.get("subtype") == "error" + + +@dataclass +class Session: + """Opaque session handle passed between caller and provider.""" + + id: str | None = None + sent: int = 0 + + +def chat( + messages: list[Message], + model: str | None = None, + temperature: float = 0.7, + max_tokens: int = 2048, + reasoning_effort: str | None = None, + tools: list[dict] | None = None, + log: Callable[[str], None] | None = None, + session: Session | None = None, +) -> Message: + """Send a chat turn via cursor-agent CLI. + + The caller owns the ``Session`` object and passes it back on each + call. Tool calls are extracted from fenced JSON blocks in the + model's text output. + """ + model = model or DEFAULT_MODEL + if log is None: + log = lambda _: None + if session is None: + session = Session() + + prompt = _build_text(messages, session.sent) + + sid, content = _run(prompt, model, session.id, log) + session.id = sid + session.sent = len(messages) + + if content: + preview = content[:300] + "..." 
if len(content) > 300 else content + log(f" [response] {preview}\n") + + tool_calls = _parse_tool_calls(content) + result: Message = {"role": "assistant", "content": content} + if tool_calls: + result["tool_calls"] = tool_calls + result["session"] = session + + return result diff --git a/conductor/providers/openrouter.py b/conductor/providers/openrouter.py new file mode 100644 index 000000000..62ae61bf8 --- /dev/null +++ b/conductor/providers/openrouter.py @@ -0,0 +1,311 @@ +"""OpenRouter API client for Conductor. + +Handles streaming chat completions, retries, and usage tracking. +""" + +import json +import os +import threading +import time +from collections import defaultdict +from collections.abc import Callable +from dataclasses import dataclass +from types import TracebackType +from typing import Any + +import requests + +API_KEY: str = os.environ.get("OPENROUTER_API_KEY", "") +BASE_URL: str = "https://openrouter.ai/api/v1" +DEFAULT_MODEL: str = "deepseek/deepseek-v3.2" + +_REQUEST_TIMEOUT = 120 +_MAX_RETRIES = 3 +_RETRY_BACKOFF = 2.0 + +Message = dict[str, Any] + + +# --- Monotonic API usage counters (thread-safe, per-model). --- + + +@dataclass +class Counters: + """API usage snapshot.""" + + tokens: int = 0 + input_tokens: int = 0 + output_tokens: int = 0 + cost: float = 0.0 + + +_counters_lock = threading.Lock() +_counters: defaultdict[str, Counters] = defaultdict(Counters) + + +def _record_usage(model: str, usage: dict[str, Any]) -> None: + """Accumulate token and cost from an API response.""" + tokens = int(usage.get("total_tokens", 0)) + input_tokens = int(usage.get("prompt_tokens", 0)) + output_tokens = int(usage.get("completion_tokens", 0)) + cost = usage.get("cost") + with _counters_lock: + c = _counters[model] + c.tokens += tokens + c.input_tokens += input_tokens + c.output_tokens += output_tokens + if cost is not None: + c.cost += float(cost) + + +class Stats: + """Context manager that captures API usage over a scope. 
+ + Snapshots the monotonic per-model counters on entry. The ``counters`` + property returns an aggregate delta; ``per_model`` returns per-model deltas. + """ + + def __init__(self) -> None: + self._start: dict[str, Counters] = {} + + def __enter__(self) -> "Stats": + with _counters_lock: + self._start = { + m: Counters(c.tokens, c.input_tokens, c.output_tokens, c.cost) + for m, c in _counters.items() + } + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + pass + + @property + def counters(self) -> Counters: + """Aggregate delta across all models since entering the context.""" + with _counters_lock: + tok = sum(c.tokens for c in _counters.values()) - sum( + c.tokens for c in self._start.values() + ) + inp = sum(c.input_tokens for c in _counters.values()) - sum( + c.input_tokens for c in self._start.values() + ) + out = sum(c.output_tokens for c in _counters.values()) - sum( + c.output_tokens for c in self._start.values() + ) + cst = sum(c.cost for c in _counters.values()) - sum( + c.cost for c in self._start.values() + ) + return Counters(tokens=tok, input_tokens=inp, output_tokens=out, cost=cst) + + @property + def per_model(self) -> defaultdict[str, Counters]: + """Per-model deltas since entering the context.""" + with _counters_lock: + result: defaultdict[str, Counters] = defaultdict(Counters) + for model, current in _counters.items(): + start = self._start.get(model, Counters()) + delta = Counters( + tokens=current.tokens - start.tokens, + input_tokens=current.input_tokens - start.input_tokens, + output_tokens=current.output_tokens - start.output_tokens, + cost=current.cost - start.cost, + ) + if delta.tokens or delta.cost: + result[model] = delta + return result + + +_TRANSIENT_ERRORS = ( + requests.exceptions.ChunkedEncodingError, + requests.exceptions.ConnectionError, + requests.exceptions.ReadTimeout, + requests.exceptions.ConnectTimeout, +) + + +def 
_with_retry( + func: Callable[..., Any], + log: Callable[[str], None], + *args: Any, + **kwargs: Any, +) -> Any: + """Call *func* with retries on transient network errors and 5xx.""" + for attempt in range(_MAX_RETRIES): + try: + return func(*args, **kwargs) + except requests.HTTPError as exc: + resp = exc.response + if resp is not None and resp.status_code >= 500: + if attempt < _MAX_RETRIES - 1: + wait = _RETRY_BACKOFF * (attempt + 1) + log( + f"\n [retry] server {resp.status_code}, waiting {wait:.0f}s...\n" + ) + time.sleep(wait) + continue + raise + except _TRANSIENT_ERRORS: + if attempt < _MAX_RETRIES - 1: + wait = _RETRY_BACKOFF * (attempt + 1) + log(f"\n [retry] connection error, waiting {wait:.0f}s...\n") + time.sleep(wait) + else: + raise + raise RuntimeError("Unreachable") + + +def _stream_request( + payload: dict[str, Any], + on_token: Callable[[str], None] | None, + on_thinking: Callable[[str], None] | None, +) -> Message: + """Execute a streaming chat request and assemble the response message.""" + resp = requests.post( + f"{BASE_URL}/chat/completions", + headers={"Authorization": f"Bearer {API_KEY}"}, + json=payload, + stream=True, + timeout=_REQUEST_TIMEOUT, + ) + if resp.status_code >= 400: + raise requests.HTTPError( + f"{resp.status_code} {resp.reason}: {resp.text}", + response=resp, + ) + + content_chunks: list[str] = [] + reasoning_chunks: list[str] = [] + tool_calls_by_index: dict[int, dict[str, Any]] = {} + usage: dict[str, Any] | None = None + + for line in resp.iter_lines(): + if not line or not line.startswith(b"data: "): + continue + data = line[6:] + if data == b"[DONE]": + break + chunk = json.loads(data) + if "usage" in chunk: + usage = chunk["usage"] + choices = chunk.get("choices") + if not choices: + continue + delta = choices[0].get("delta", {}) + + # Reasoning tokens (two OpenRouter formats). 
+ reasoning_texts: list[str] = [] + for detail in delta.get("reasoning_details", []): + text = detail.get("text", "") + if text: + reasoning_texts.append(text) + rc = delta.get("reasoning_content", "") + if rc: + reasoning_texts.append(rc) + for text in reasoning_texts: + if on_thinking is not None: + on_thinking(text) + reasoning_chunks.append(text) + + # Content tokens. + token = delta.get("content", "") + if token: + if on_token is not None: + on_token(token) + content_chunks.append(token) + + # Tool call deltas. + for tc_delta in delta.get("tool_calls", []): + idx = tc_delta["index"] + if idx not in tool_calls_by_index: + tool_calls_by_index[idx] = { + "id": tc_delta.get("id", ""), + "type": "function", + "function": {"name": "", "arguments": ""}, + } + tc = tool_calls_by_index[idx] + func_delta = tc_delta.get("function", {}) + if func_delta.get("name"): + tc["function"]["name"] += func_delta["name"] + if func_delta.get("arguments"): + tc["function"]["arguments"] += func_delta["arguments"] + + result: Message = {"role": "assistant", "content": "".join(content_chunks)} + if reasoning_chunks: + result["reasoning"] = "".join(reasoning_chunks) + if tool_calls_by_index: + result["tool_calls"] = [ + tool_calls_by_index[i] for i in sorted(tool_calls_by_index) + ] + if usage is not None: + result["usage"] = usage + return result + + +def chat( + messages: list[Message], + model: str | None = None, + temperature: float = 0.7, + max_tokens: int = 2048, + reasoning_effort: str | None = None, + tools: list[dict] | None = None, + log: Callable[[str], None] | None = None, + session: Any = None, +) -> Message: + """Send a chat completion request. Returns the full response message dict. + + Handles payload construction, retries on transient errors, usage + recording, and streaming log output. + """ + model = model or DEFAULT_MODEL + if not API_KEY: + raise RuntimeError( + "OPENROUTER_API_KEY not set. Export it before running the LLM loop." 
+    )
+
+    if log is None:
+        log = lambda _: None
+
+    payload: dict[str, Any] = {
+        "model": model,
+        "messages": messages,
+        "temperature": temperature,
+        "max_tokens": max_tokens,
+        "stream": True,
+        "stream_options": {"include_usage": True},
+    }
+    if reasoning_effort is not None:
+        payload["reasoning"] = {"enabled": True, "effort": reasoning_effort}
+    if tools:
+        payload["tools"] = [dict(t) for t in tools]
+
+    # Streaming callbacks that manage [thinking]/[/thinking] delimiters.
+    in_reasoning = False
+
+    def on_thinking(text: str) -> None:
+        nonlocal in_reasoning
+        if not in_reasoning:
+            log("\n [thinking] ")
+            in_reasoning = True
+        log(text)
+
+    def on_token(text: str) -> None:
+        nonlocal in_reasoning
+        if in_reasoning:
+            log("\n [/thinking]\n")
+            in_reasoning = False
+        # Stream the content token as well; previously only the delimiter
+        # was emitted and content tokens were silently dropped from the log.
+        log(text)
+
+    result: Message = _with_retry(_stream_request, log, payload, on_token, on_thinking)
+    if in_reasoning:
+        log("\n [/thinking]\n")
+
+    if "usage" in result:
+        _record_usage(model, result["usage"])
+        pt = result["usage"].get("prompt_tokens", "?")
+        ct = result["usage"].get("completion_tokens", "?")
+        log(f"\n [tokens] prompt={pt} completion={ct}\n")
+    return result
diff --git a/conductor/tools.py b/conductor/tools.py
new file mode 100644
index 000000000..b0f87b8c0
--- /dev/null
+++ b/conductor/tools.py
@@ -0,0 +1,87 @@
+"""Tool registry for LLM function calling."""
+
+import json
+from collections.abc import Callable
+from dataclasses import dataclass, field
+from typing import Any
+
+
+@dataclass
+class Param:
+    """Single tool parameter specification."""
+
+    name: str
+    description: str
+    type: str = "string"
+    required: bool = True
+    items: dict[str, str] | None = None
+
+
+@dataclass
+class ToolDef:
+    """Tool definition with schema and handler."""
+
+    name: str
+    description: str
+    params: list[Param] = field(default_factory=list)
+    func: Callable[..., str] | None = None
+
+    def to_api(self) -> dict[str, Any]:
+        """Convert to OpenAI function-calling format."""
+        properties: dict[str, Any] = {}
+        
for p in self.params: + prop: dict[str, Any] = {"type": p.type, "description": p.description} + if p.items is not None: + prop["items"] = p.items + properties[p.name] = prop + required = [p.name for p in self.params if p.required] + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": { + "type": "object", + "properties": properties, + "required": required, + }, + }, + } + + +class ToolRegistry: + """Central registry for tool definitions and dispatch.""" + + def __init__(self) -> None: + self._tools: dict[str, ToolDef] = {} + + def add( + self, + name: str, + description: str, + params: list[Param], + func: Callable[..., str], + ) -> None: + """Register a tool with its definition and handler.""" + self._tools[name] = ToolDef( + name=name, + description=description, + params=params, + func=func, + ) + + def definitions(self) -> list[dict[str, Any]]: + """Return all tool definitions in OpenAI API format.""" + return [tool.to_api() for tool in self._tools.values()] + + def execute(self, name: str, arguments: str) -> str: + """Execute a tool by name with JSON-encoded arguments.""" + name = name.strip() + tool = self._tools.get(name) + if tool is None or tool.func is None: + return json.dumps({"error": f"Unknown tool: {name}"}) + try: + kwargs: dict[str, Any] = json.loads(arguments) + return tool.func(**kwargs) + except Exception as e: + return json.dumps({"error": str(e)}) diff --git a/wave_lang/kernel/wave/.gitignore b/wave_lang/kernel/wave/.gitignore index fd68a7a03..6d973f16e 100644 --- a/wave_lang/kernel/wave/.gitignore +++ b/wave_lang/kernel/wave/.gitignore @@ -11,7 +11,7 @@ dist/ downloads/ eggs/ .eggs/ -lib/ +# lib/ lib64/ parts/ sdist/ diff --git a/wave_lang/kernel/wave/asm/wave_asm/CMakeLists.txt b/wave_lang/kernel/wave/asm/wave_asm/CMakeLists.txt index 3ea5c5316..bcffb1cbe 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/CMakeLists.txt +++ b/wave_lang/kernel/wave/asm/wave_asm/CMakeLists.txt @@ 
-51,6 +51,7 @@ add_definitions(${LLVM_DEFINITIONS})
 
 set(WAVEASM_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
 set(WAVEASM_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
 set(WAVEASM_TOOLS_DIR ${CMAKE_CURRENT_BINARY_DIR}/tools/waveasm-translate)
+set(WAVEASM_CONDUCTOR_DIR ${CMAKE_CURRENT_BINARY_DIR}/tools/waveasm-conductor)
 
 #-------------------------------------------------------------------------------
 # Directory Configuration
diff --git a/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/ApplyMoves.h b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/ApplyMoves.h
new file mode 100644
index 000000000..69d9b99a3
--- /dev/null
+++ b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/ApplyMoves.h
@@ -0,0 +1,66 @@
+// Copyright 2025 The Wave Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef WAVEASM_TRANSFORMS_APPLYMOVES_H
+#define WAVEASM_TRANSFORMS_APPLYMOVES_H
+
+#include "mlir/IR/BuiltinOps.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include <string>
+#include <variant>
+
+namespace waveasm {
+
+/// Move an op before a reference op.
+struct MoveBefore {
+  std::string tag;
+  std::string refTag;
+};
+
+/// Move an op after a reference op.
+struct MoveAfter {
+  std::string tag;
+  std::string refTag;
+};
+
+/// Swap two ops.
+struct Swap {
+  std::string tag1;
+  std::string tag2;
+};
+
+using Command = std::variant<MoveBefore, MoveAfter, Swap>;
+
+/// Result of applying a sequence of move commands to a module.
+struct MoveResult {
+  bool success;
+  std::string error;
+  unsigned failedCommand;
+};
+
+/// Result of parsing CONDUCTOR commands from raw text.
+struct ParseResult {
+  bool success;
+  std::string error;
+  unsigned failedLine; // 0-based index of the failing CONDUCTOR line.
+  llvm::SmallVector<Command> commands;
+};
+
+/// Parse CONDUCTOR command lines from raw file text.
+/// Scans for lines matching `// CONDUCTOR: <command>` and parses them
+/// into typed Command structs.
+ParseResult parseConductorCommands(llvm::StringRef text);
+
+/// Apply a sequence of parsed Conductor commands to a tagged module.
+///
+/// The module must already have NameLoc tags attached (via TagInstructions).
+/// Returns a MoveResult indicating success or the first error encountered.
+MoveResult applyMoves(mlir::ModuleOp module, llvm::ArrayRef<Command> commands);
+
+} // namespace waveasm
+
+#endif // WAVEASM_TRANSFORMS_APPLYMOVES_H
diff --git a/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/Passes.h b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/Passes.h
index 2d41e1d95..9d965561c 100644
--- a/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/Passes.h
+++ b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/Passes.h
@@ -57,6 +57,9 @@ std::unique_ptr<mlir::Pass> createWAVEASMPeepholePass();
 
 /// Create the memory offset optimization pass
 std::unique_ptr<mlir::Pass> createWAVEASMMemoryOffsetOptPass();
 
+/// Create the instruction tagging pass (Conductor infrastructure).

+std::unique_ptr<mlir::Pass> createWAVEASMTagInstructionsPass();
+
 //===----------------------------------------------------------------------===//
 // Pass Registration
 //===----------------------------------------------------------------------===//
diff --git a/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/Passes.td b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/Passes.td
index afaf96af1..9c1ec0bd9 100644
--- a/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/Passes.td
+++ b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/Passes.td
@@ -250,4 +250,27 @@ def WAVEASMEmitAssembly : Pass<"waveasm-emit-assembly"> {
   let dependentDialects = ["::waveasm::WaveASMDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// Instruction Tagging Pass (Conductor)
+//===----------------------------------------------------------------------===//
+
+def WAVEASMTagInstructions : Pass<"waveasm-tag-instructions", "mlir::ModuleOp"> {
+  let summary = "Attach stable NameLoc tags to WaveASM instructions";
+  let description = [{
+    Tags each WaveASM operation with a NameLoc of the form
+    "<mnemonic>_<N>" where N is a per-op-kind counter. For example:
+    buffer_load_dwordx4_0, v_mfma_f32_16x16x16_f16_0, ds_read_b64_2.
+
+    Tags are stable across scheduling rounds — they follow the operation,
+    not the position. This is the first step of the Conductor scheduling
+    infrastructure.

+  }];
+
+  let statistics = [
+    Statistic<"numOpsTagged", "Operations tagged", "count">
+  ];
+
+  let dependentDialects = ["::waveasm::WaveASMDialect"];
+}
+
 #endif // WaveASM_TRANSFORMS_PASSES
diff --git a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/ApplyMoves.cpp b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/ApplyMoves.cpp
new file mode 100644
index 000000000..513b70cb8
--- /dev/null
+++ b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/ApplyMoves.cpp
@@ -0,0 +1,256 @@
+// Copyright 2025 The Wave Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===----------------------------------------------------------------------===//
+// ApplyMoves — parse and apply Conductor move/swap commands on tagged IR.
+//===----------------------------------------------------------------------===//
+
+#include "waveasm/Transforms/ApplyMoves.h"
+
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Location.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/DebugLog.h"
+#include <cstring>
+
+#define DEBUG_TYPE "waveasm-apply-moves"
+
+using namespace mlir;
+
+namespace {
+
+/// Return true for ops that must never be moved, identified by mnemonic.
+bool isPinned(Operation *op) {
+  StringRef name = op->getName().stripDialect();
+  return name == "condition" || name == "s_endpgm" || name == "s_barrier";
+}
+
+/// Build a map from NameLoc tag string to Operation*.
+llvm::StringMap<Operation *> buildTagMap(ModuleOp module) {
+  llvm::StringMap<Operation *> map;
+  module.walk([&](Operation *op) {
+    if (auto nameLoc = dyn_cast<NameLoc>(op->getLoc()))
+      map[nameLoc.getName().strref()] = op;
+  });
+  return map;
+}
+
+/// Validate that an op can be moved (not pinned).
+std::string validateMovable(Operation *op, StringRef tag) {
+  if (isPinned(op))
+    return ("cannot move pinned op '" + tag + "'").str();
+  return "";
+}
+
+/// Check that two ops are in the same block.

+std::string validateSameBlock(Operation *a, StringRef tagA, Operation *b,
+                              StringRef tagB) {
+  if (a->getBlock() != b->getBlock())
+    return ("'" + tagA + "' and '" + tagB + "' are in different blocks").str();
+  return "";
+}
+
+/// Get a human-readable name for an operation (tag or mnemonic).
+std::string opName(Operation *op) {
+  if (auto nameLoc = dyn_cast<NameLoc>(op->getLoc()))
+    return nameLoc.getName().str();
+  return op->getName().getStringRef().str();
+}
+
+/// Check SSA dominance for a single op after it has been moved.
+/// Verifies: (1) all operands are defined above the op,
+/// (2) all results are defined before their users.
+std::string checkDominance(Operation *op) {
+  DominanceInfo domInfo(op->getParentOfType<ModuleOp>());
+
+  // Check that every operand still dominates this op.
+  for (Value operand : op->getOperands()) {
+    if (!domInfo.properlyDominates(operand, op)) {
+      std::string defName = "block-arg";
+      if (auto *defOp = operand.getDefiningOp())
+        defName = opName(defOp);
+      return "moving '" + opName(op) + "' breaks dominance: operand from '" +
+             defName + "' no longer defined before use";
+    }
+  }
+
+  // Check that every user of this op's results is still dominated.

+  for (Value result : op->getResults()) {
+    for (Operation *user : result.getUsers()) {
+      if (!domInfo.properlyDominates(op, user)) {
+        return "moving '" + opName(op) +
+               "' breaks dominance: result used by '" + opName(user) +
+               "' which now appears before the definition";
+      }
+    }
+  }
+
+  return "";
+}
+
+} // namespace
+
+namespace waveasm {
+
+ParseResult parseConductorCommands(llvm::StringRef text) {
+  ParseResult result;
+  result.success = true;
+  result.failedLine = 0;
+
+  llvm::SmallVector<StringRef> lines;
+  text.split(lines, '\n');
+
+  unsigned cmdIdx = 0;
+  for (StringRef line : lines) {
+    StringRef trimmed = line.ltrim();
+    if (!trimmed.starts_with("// CONDUCTOR:"))
+      continue;
+    StringRef raw = trimmed.drop_front(strlen("// CONDUCTOR:")).trim();
+    if (raw.empty())
+      continue;
+
+    if (raw.starts_with("move ")) {
+      StringRef rest = raw.drop_front(strlen("move "));
+      auto [tag, rest2] = rest.split(' ');
+      auto [direction, refTag] = rest2.split(' ');
+
+      if (tag.empty() || refTag.empty() || direction.empty()) {
+        result.success = false;
+        result.error = ("malformed move command: '" + raw + "'").str();
+        result.failedLine = cmdIdx;
+        return result;
+      }
+
+      if (direction == "before") {
+        result.commands.push_back(MoveBefore{tag.str(), refTag.str()});
+      } else if (direction == "after") {
+        result.commands.push_back(MoveAfter{tag.str(), refTag.str()});
+      } else {
+        result.success = false;
+        result.error =
+            ("expected 'after' or 'before', got '" + direction + "'").str();
+        result.failedLine = cmdIdx;
+        return result;
+      }
+
+    } else if (raw.starts_with("swap ")) {
+      StringRef rest = raw.drop_front(strlen("swap "));
+      auto [tag1, tag2] = rest.split(' ');
+
+      if (tag1.empty() || tag2.empty()) {
+        result.success = false;
+        result.error = ("malformed swap command: '" + raw + "'").str();
+        result.failedLine = cmdIdx;
+        return result;
+      }
+
+      result.commands.push_back(Swap{tag1.str(), tag2.str()});
+
+    } else {
+      result.success = false;
+      result.error = ("unknown command: '" +
 raw + "'").str();
+      result.failedLine = cmdIdx;
+      return result;
+    }
+
+    ++cmdIdx;
+  }
+
+  return result;
+}
+
+/// Resolve two tags, validate movability and same-block constraint.
+/// On success sets op1/op2 and returns empty string.
+static std::string
+resolveAndValidate(const llvm::StringMap<Operation *> &tagMap, StringRef tag,
+                   StringRef ref, bool checkRefMovable, Operation *&op1,
+                   Operation *&op2) {
+  auto it1 = tagMap.find(tag);
+  if (it1 == tagMap.end())
+    return ("unknown tag '" + tag + "'").str();
+  auto it2 = tagMap.find(ref);
+  if (it2 == tagMap.end())
+    return ("unknown tag '" + ref + "'").str();
+
+  op1 = it1->second;
+  op2 = it2->second;
+
+  std::string err = validateMovable(op1, tag);
+  if (!err.empty())
+    return err;
+  if (checkRefMovable) {
+    err = validateMovable(op2, ref);
+    if (!err.empty())
+      return err;
+  }
+  return validateSameBlock(op1, tag, op2, ref);
+}
+
+MoveResult applyMoves(ModuleOp module, llvm::ArrayRef<Command> commands) {
+  auto tagMap = buildTagMap(module);
+
+  for (auto [idx, cmd] : llvm::enumerate(commands)) {
+    auto fail = [&](const std::string &msg) -> MoveResult {
+      return {false, msg, static_cast<unsigned>(idx)};
+    };
+
+    Operation *op1 = nullptr, *op2 = nullptr;
+
+    if (auto *move = std::get_if<MoveBefore>(&cmd)) {
+      std::string err =
+          resolveAndValidate(tagMap, move->tag, move->refTag, false, op1, op2);
+      if (!err.empty())
+        return fail(err);
+      op1->moveBefore(op2);
+      LDBG() << "move " << move->tag << " before " << move->refTag;
+      err = checkDominance(op1);
+      if (!err.empty())
+        return fail(err);
+
+    } else if (auto *move = std::get_if<MoveAfter>(&cmd)) {
+      std::string err =
+          resolveAndValidate(tagMap, move->tag, move->refTag, false, op1, op2);
+      if (!err.empty())
+        return fail(err);
+      op1->moveAfter(op2);
+      LDBG() << "move " << move->tag << " after " << move->refTag;
+      err = checkDominance(op1);
+      if (!err.empty())
+        return fail(err);
+
+    } else if (auto *swap = std::get_if<Swap>(&cmd)) {
+      std::string err =
+          resolveAndValidate(tagMap, swap->tag1, swap->tag2, true, op1,
op2); + if (!err.empty()) + return fail(err); + + // Swap by considering adjacency cases. + Operation *op1Next = op1->getNextNode(); + if (op1Next == op2) { + op1->moveAfter(op2); + } else if (op2->getNextNode() == op1) { + op2->moveAfter(op1); + } else { + op1->moveAfter(op2); + if (op1Next) + op2->moveBefore(op1Next); + else + op2->moveBefore(op2->getBlock(), op2->getBlock()->end()); + } + LDBG() << "swap " << swap->tag1 << " " << swap->tag2; + err = checkDominance(op1); + if (!err.empty()) + return fail(err); + err = checkDominance(op2); + if (!err.empty()) + return fail(err); + } + } + + return {true, "", 0}; +} + +} // namespace waveasm diff --git a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/CMakeLists.txt b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/CMakeLists.txt index 5dcd41422..7094db564 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/CMakeLists.txt +++ b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/CMakeLists.txt @@ -27,6 +27,8 @@ add_mlir_dialect_library(MLIRWaveASMTransforms ScopedCSE.cpp Peephole.cpp MemoryOffsetOptimization.cpp + TagInstructions.cpp + ApplyMoves.cpp ${HANDLERS_FULL_PATHS} ADDITIONAL_HEADER_DIRS diff --git a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/TagInstructions.cpp b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/TagInstructions.cpp new file mode 100644 index 000000000..d8b875d7b --- /dev/null +++ b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/TagInstructions.cpp @@ -0,0 +1,67 @@ +// Copyright 2025 The Wave Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// Tag Instructions Pass - Attach stable NameLoc tags to WaveASM ops. +// +// Each operation gets a tag like loc("buffer_load_dwordx4_0"), +// loc("v_mfma_f32_16x16x16_f16_3"), etc. 
 Tags are stable across
+// scheduling rounds (they follow the operation, not the position).
+//===----------------------------------------------------------------------===//
+
+#include "waveasm/Dialect/WaveASMDialect.h"
+#include "waveasm/Dialect/WaveASMOps.h"
+#include "waveasm/Transforms/Passes.h"
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/DebugLog.h"
+
+#define DEBUG_TYPE "waveasm-tag-instructions"
+
+using namespace mlir;
+using namespace waveasm;
+
+namespace {
+
+struct TagInstructionsPass
+    : public PassWrapper<TagInstructionsPass, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TagInstructionsPass)
+
+  StringRef getArgument() const override { return "waveasm-tag-instructions"; }
+
+  StringRef getDescription() const override {
+    return "Attach stable NameLoc tags to WaveASM instructions";
+  }
+
+  void runOnOperation() override {
+    unsigned total = 0;
+    MLIRContext *ctx = &getContext();
+    llvm::DenseMap<StringRef, unsigned> counters;
+
+    getOperation().walk([&](Operation *op) {
+      if (op->getDialect() && op->getDialect()->getNamespace() == "waveasm") {
+        StringRef opName = op->getName().stripDialect();
+        unsigned idx = counters[opName]++;
+        std::string tag = (opName + "_" + llvm::Twine(idx)).str();
+        op->setLoc(NameLoc::get(StringAttr::get(ctx, tag)));
+        ++total;
+      }
+    });
+
+    LDBG() << "tagged " << total << " operations";
+  }
+};
+
+} // namespace
+
+namespace waveasm {
+
+std::unique_ptr<mlir::Pass> createWAVEASMTagInstructionsPass() {
+  return std::make_unique<TagInstructionsPass>();
+}
+
+} // namespace waveasm
diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/CMakeLists.txt b/wave_lang/kernel/wave/asm/wave_asm/test/CMakeLists.txt
index 2d78933d0..941777bf8 100644
--- a/wave_lang/kernel/wave/asm/wave_asm/test/CMakeLists.txt
+++ b/wave_lang/kernel/wave/asm/wave_asm/test/CMakeLists.txt
@@ -14,6 +14,7 @@ configure_lit_site_cfg(
 
 set(WAVEASM_TEST_DEPENDS
   waveasm-translate
+  waveasm-conductor
 )
 
 add_lit_testsuite(check-waveasm "Running the waveasm regression tests"
diff --git 
a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-after.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-after.mlir new file mode 100644 index 000000000..7712080ee --- /dev/null +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-after.mlir @@ -0,0 +1,19 @@ +// RUN: waveasm-conductor --print-debug-locs-inline %s | FileCheck %s + +// CONDUCTOR: move v_add_u32_0 after v_add_u32_1 + +// Two independent adds — reordering is safe. + +// CHECK-LABEL: waveasm.program @test_move_after +// CHECK: waveasm.v_add_u32{{.*}}loc("v_add_u32_1") +// CHECK: waveasm.v_add_u32{{.*}}loc("v_add_u32_0") +waveasm.program @test_move_after target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %c4 = waveasm.constant 4 : !waveasm.imm<4> + %c1 = waveasm.constant 1 : !waveasm.imm<1> + + %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg loc("v_add_u32_0") + %a1 = waveasm.v_add_u32 %v0, %c1 : !waveasm.pvreg<0>, !waveasm.imm<1> -> !waveasm.vreg loc("v_add_u32_1") + + waveasm.s_endpgm loc("s_endpgm_0") +} diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-dominance.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-dominance.mlir new file mode 100644 index 000000000..da6d452bc --- /dev/null +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-dominance.mlir @@ -0,0 +1,18 @@ +// RUN: not waveasm-conductor %s 2>&1 | FileCheck %s + +// CONDUCTOR: move v_add_u32_0 after v_add_u32_1 + +// v_add_u32_1 uses the result of v_add_u32_0, so moving _0 after _1 +// breaks dominance. 
+ +// CHECK: moving 'v_add_u32_0' breaks dominance: result used by 'v_add_u32_1' which now appears before the definition + +waveasm.program @test_dominance target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %c4 = waveasm.constant 4 : !waveasm.imm<4> + + %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg loc("v_add_u32_0") + %a1 = waveasm.v_add_u32 %a0, %c4 : !waveasm.vreg, !waveasm.imm<4> -> !waveasm.vreg loc("v_add_u32_1") + + waveasm.s_endpgm loc("s_endpgm_0") +} diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-pinned.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-pinned.mlir new file mode 100644 index 000000000..d0620cb6e --- /dev/null +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-pinned.mlir @@ -0,0 +1,13 @@ +// RUN: not waveasm-conductor %s 2>&1 | FileCheck %s + +// CONDUCTOR: move s_endpgm_0 before v_add_u32_0 + + +// CHECK: conductor: command 0: cannot move pinned op 's_endpgm_0' + +waveasm.program @test_pinned target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %c4 = waveasm.constant 4 : !waveasm.imm<4> + %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg loc("v_add_u32_0") + waveasm.s_endpgm loc("s_endpgm_0") +} diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-unknown-tag.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-unknown-tag.mlir new file mode 100644 index 000000000..9412e59f8 --- /dev/null +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-unknown-tag.mlir @@ -0,0 +1,13 @@ +// RUN: not waveasm-conductor %s 2>&1 | FileCheck %s + +// CONDUCTOR: move nonexistent_tag_42 after v_add_u32_0 + + +// CHECK: conductor: command 0: unknown tag 'nonexistent_tag_42' + 
+waveasm.program @test_unknown_tag target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %c4 = waveasm.constant 4 : !waveasm.imm<4> + %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg + waveasm.s_endpgm +} diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-swap.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-swap.mlir new file mode 100644 index 000000000..fb54ecd2f --- /dev/null +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-swap.mlir @@ -0,0 +1,20 @@ +// RUN: waveasm-conductor --print-debug-locs-inline %s | FileCheck %s + +// CONDUCTOR: swap v_add_u32_0 v_lshlrev_b32_0 + +// Three independent ops (all read from %v0/%c4) — swap is safe. + +// CHECK-LABEL: waveasm.program @test_swap +// CHECK: waveasm.v_lshlrev_b32{{.*}}loc("v_lshlrev_b32_0") +// CHECK: waveasm.v_add_u32{{.*}}loc("v_add_u32_1") +// CHECK: waveasm.v_add_u32{{.*}}loc("v_add_u32_0") +waveasm.program @test_swap target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %c4 = waveasm.constant 4 : !waveasm.imm<4> + + %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg loc("v_add_u32_0") + %a1 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg loc("v_add_u32_1") + %s0 = waveasm.v_lshlrev_b32 %c4, %v0 : !waveasm.imm<4>, !waveasm.pvreg<0> -> !waveasm.vreg loc("v_lshlrev_b32_0") + + waveasm.s_endpgm loc("s_endpgm_0") +} diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves.mlir new file mode 100644 index 000000000..d38773df1 --- /dev/null +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves.mlir @@ -0,0 +1,19 @@ +// RUN: waveasm-conductor --print-debug-locs-inline %s | FileCheck %s + +// CONDUCTOR: 
move v_add_u32_1 before v_add_u32_0 + +// Two independent adds from the same inputs — reordering is safe. + +// CHECK-LABEL: waveasm.program @test_move_before +// CHECK: waveasm.v_add_u32{{.*}}loc("v_add_u32_1") +// CHECK: waveasm.v_add_u32{{.*}}loc("v_add_u32_0") +waveasm.program @test_move_before target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %c4 = waveasm.constant 4 : !waveasm.imm<4> + %c1 = waveasm.constant 1 : !waveasm.imm<1> + + %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg loc("v_add_u32_0") + %a1 = waveasm.v_add_u32 %v0, %c1 : !waveasm.pvreg<0>, !waveasm.imm<1> -> !waveasm.vreg loc("v_add_u32_1") + + waveasm.s_endpgm loc("s_endpgm_0") +} diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/tag-instructions.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/tag-instructions.mlir new file mode 100644 index 000000000..c19183604 --- /dev/null +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/tag-instructions.mlir @@ -0,0 +1,48 @@ +// RUN: waveasm-translate --waveasm-tag-instructions --print-debug-locs-inline %s | FileCheck %s + +//===----------------------------------------------------------------------===// +// Test: Each op kind gets its own counter. +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: waveasm.program @per_kind_counters +waveasm.program @per_kind_counters target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4> + %c4 = waveasm.constant 4 : !waveasm.imm<4> + + // Two adds should get v_add_u32_0 and v_add_u32_1. 

+  // CHECK: waveasm.v_add_u32 {{.*}} loc("v_add_u32_0")
+  // CHECK: waveasm.v_add_u32 {{.*}} loc("v_add_u32_1")
+  %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg
+  %a1 = waveasm.v_add_u32 %a0, %c4 : !waveasm.vreg, !waveasm.imm<4> -> !waveasm.vreg
+
+  // A shift gets its own counter starting at 0.
+  // CHECK: waveasm.v_lshlrev_b32 {{.*}} loc("v_lshlrev_b32_0")
+  %s0 = waveasm.v_lshlrev_b32 %c4, %a1 : !waveasm.imm<4>, !waveasm.vreg -> !waveasm.vreg
+
+  // Load and store have independent counters.
+  // CHECK: waveasm.buffer_load_dwordx4 {{.*}} loc("buffer_load_dwordx4_0")
+  // CHECK: waveasm.buffer_store_dword {{.*}} loc("buffer_store_dword_0")
+  %ld = waveasm.buffer_load_dwordx4 %srd, %s0 : !waveasm.psreg<0, 4>, !waveasm.vreg -> !waveasm.vreg<4>
+  waveasm.buffer_store_dword %a1, %srd, %s0 : !waveasm.vreg, !waveasm.psreg<0, 4>, !waveasm.vreg
+
+  // CHECK: waveasm.s_endpgm loc("s_endpgm_0")
+  waveasm.s_endpgm
+}
+
+//===----------------------------------------------------------------------===//
+// Test: Counters are module-wide, not reset per program.
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: waveasm.program @second_program
+waveasm.program @second_program target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
+  %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0>
+  %c1 = waveasm.constant 1 : !waveasm.imm<1>
+
+  // Counter continues from first program (module-wide walk).

+  // CHECK: waveasm.v_add_u32 {{.*}} loc("v_add_u32_2")
+  %a0 = waveasm.v_add_u32 %v0, %c1 : !waveasm.pvreg<0>, !waveasm.imm<1> -> !waveasm.vreg
+
+  // CHECK: waveasm.s_endpgm loc("s_endpgm_1")
+  waveasm.s_endpgm
+}
diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/e2e/waveasm_e2e.py b/wave_lang/kernel/wave/asm/wave_asm/test/e2e/waveasm_e2e.py
index 035bf403b..92071487e 100644
--- a/wave_lang/kernel/wave/asm/wave_asm/test/e2e/waveasm_e2e.py
+++ b/wave_lang/kernel/wave/asm/wave_asm/test/e2e/waveasm_e2e.py
@@ -125,13 +125,6 @@ def get_waveasm_translate_path() -> Path:
 
 def get_amdclang_path() -> str:
     """Get path to amdclang++ for assembly compilation."""
-    # Check ROCM_PATH first
-    rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm")
-    amdclang = os.path.join(rocm_path, "bin", "amdclang++")
-
-    if os.path.exists(amdclang):
-        return amdclang
-
     # Try to find it in PATH
     try:
         result = subprocess.run(["which", "amdclang++"], capture_output=True, text=True)
@@ -140,6 +133,13 @@ def get_amdclang_path() -> str:
     except Exception:
         pass
 
+    # Do not check ROCM_PATH first; it breaks with TheRock.
+    rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm")
+    amdclang = os.path.join(rocm_path, "bin", "amdclang++")
+
+    if os.path.exists(amdclang):
+        return amdclang
+
     raise FileNotFoundError(
         "amdclang++ not found. Ensure ROCm is installed and in PATH."
     )
diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/lit.cfg.py b/wave_lang/kernel/wave/asm/wave_asm/test/lit.cfg.py
index 2ab708d36..12c3b8ceb 100644
--- a/wave_lang/kernel/wave/asm/wave_asm/test/lit.cfg.py
+++ b/wave_lang/kernel/wave/asm/wave_asm/test/lit.cfg.py
@@ -33,8 +33,12 @@ llvm_config.use_default_substitutions()
 
 # Add tools to the path
-tool_dirs = [config.waveasm_tools_dir, config.llvm_tools_dir]
-tools = ["waveasm-translate", "FileCheck", "count", "not"]
+tool_dirs = [
+    config.waveasm_tools_dir,
+    config.waveasm_conductor_dir,
+    config.llvm_tools_dir,
+]
+tools = ["waveasm-translate", "waveasm-conductor", "FileCheck", "count", "not"]
 llvm_config.add_tool_substitutions(tools, tool_dirs)
 
 # ROCm toolchain detection for integration tests
diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/lit.site.cfg.py.in b/wave_lang/kernel/wave/asm/wave_asm/test/lit.site.cfg.py.in
index 83304e850..9c6bd1ee0 100644
--- a/wave_lang/kernel/wave/asm/wave_asm/test/lit.site.cfg.py.in
+++ b/wave_lang/kernel/wave/asm/wave_asm/test/lit.site.cfg.py.in
@@ -20,6 +20,7 @@ config.llvm_shlib_ext = "@SHLIBEXT@"
 config.waveasm_obj_root = "@WAVEASM_BINARY_DIR@"
 config.waveasm_src_root = "@WAVEASM_SOURCE_DIR@"
 config.waveasm_tools_dir = "@WAVEASM_TOOLS_DIR@"
+config.waveasm_conductor_dir = "@WAVEASM_CONDUCTOR_DIR@"
 
 import lit.llvm
 lit.llvm.initialize(lit_config, config)
diff --git a/wave_lang/kernel/wave/asm/wave_asm/tools/CMakeLists.txt b/wave_lang/kernel/wave/asm/wave_asm/tools/CMakeLists.txt
index b8c29c2a5..c2ed906e7 100644
--- a/wave_lang/kernel/wave/asm/wave_asm/tools/CMakeLists.txt
+++ b/wave_lang/kernel/wave/asm/wave_asm/tools/CMakeLists.txt
@@ -5,3 +5,4 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 add_subdirectory(waveasm-translate)
+add_subdirectory(waveasm-conductor)
diff --git a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/CMakeLists.txt b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/CMakeLists.txt
new file mode 100644
index 000000000..2ddc6b4ef
--- /dev/null
+++ b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/CMakeLists.txt
@@ -0,0 +1,22 @@
+# Copyright 2025 The Wave Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+add_llvm_executable(waveasm-conductor
+  waveasm-conductor.cpp
+)
+
+llvm_update_compile_flags(waveasm-conductor)
+
+target_link_libraries(waveasm-conductor
+  PRIVATE
+  MLIRWaveASMDialect
+  MLIRWaveASMTransforms
+  MLIRParser
+  MLIRPass
+  MLIRSupport
+  MLIRIR
+  LLVMSupport
+)
diff --git a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp
new file mode 100644
index 000000000..8c0d6ccf2
--- /dev/null
+++ b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp
@@ -0,0 +1,147 @@
+// Copyright 2025 The Wave Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===----------------------------------------------------------------------===//
+// waveasm-conductor: CLI tool for applying Conductor move commands.
+//
+// Reads tagged WaveASM IR with embedded // CONDUCTOR: comments,
+// parses and applies the move/swap commands, then prints the result.
+//===----------------------------------------------------------------------===//
+
+#include "waveasm/Dialect/WaveASMDialect.h"
+#include "waveasm/Transforms/ApplyMoves.h"
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Parser/Parser.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Command Line Options
+//===----------------------------------------------------------------------===//
+
+static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
+                                                llvm::cl::desc("<input file>"),
+                                                llvm::cl::init("-"));
+
+static llvm::cl::opt<std::string>
+    outputFilename("o", llvm::cl::desc("Output filename"),
+                   llvm::cl::value_desc("filename"), llvm::cl::init("-"));
+
+static llvm::cl::opt<bool> printDebugLocsInline(
+    "print-debug-locs-inline",
+    llvm::cl::desc("Print location information inline (pretty form)"),
+    llvm::cl::init(false));
+
+static llvm::cl::opt<bool> dumpIROnFailure(
+    "dump-ir-on-failure",
+    llvm::cl::desc("Dump IR to stderr on move or verify failure"),
+    llvm::cl::init(false));
+
+static llvm::cl::opt<bool> useNameLocAsPrefix(
+    "use-nameloc-as-prefix",
+    llvm::cl::desc("Print SSA IDs using NameLocs as prefixes"),
+    llvm::cl::init(false));
+
+//===----------------------------------------------------------------------===//
+// Main Function
+//===----------------------------------------------------------------------===//
+
+int main(int argc, char **argv) {
+  llvm::InitLLVM y(argc, argv);
+  llvm::cl::ParseCommandLineOptions(argc, argv,
+                                    "WAVEASM Conductor move executor\n");
+
+  // Read the raw input file to extract CONDUCTOR commands before parsing.
+  auto inputFileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
+  if (auto ec = inputFileOrErr.getError()) {
+    llvm::errs() << "Error reading input file: " << ec.message() << "\n";
+    return 1;
+  }
+
+  llvm::StringRef rawText = (*inputFileOrErr)->getBuffer();
+  auto parseResult = waveasm::parseConductorCommands(rawText);
+
+  if (!parseResult.success) {
+    llvm::errs() << "conductor: parse error at command "
+                 << parseResult.failedLine << ": " << parseResult.error << "\n";
+    return 1;
+  }
+  if (parseResult.commands.empty()) {
+    llvm::errs() << "No CONDUCTOR commands found in input\n";
+    return 1;
+  }
+
+  // Set up MLIR context and parse the module.
+  DialectRegistry registry;
+  registry.insert<waveasm::WaveASMDialect>();
+
+  MLIRContext context(registry);
+  context.loadAllAvailableDialects();
+  context.allowUnregisteredDialects();
+
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(*inputFileOrErr), llvm::SMLoc());
+
+  OwningOpRef<ModuleOp> module = parseSourceFile<ModuleOp>(sourceMgr, &context);
+  if (!module) {
+    llvm::errs() << "Failed to parse input file\n";
+    return 1;
+  }
+
+  // Apply the move commands (IR is expected to have NameLoc tags already).
+  waveasm::MoveResult result =
+      waveasm::applyMoves(*module, parseResult.commands);
+  if (!result.success) {
+    llvm::errs() << "conductor: command " << result.failedCommand << ": "
+                 << result.error << "\n";
+    if (dumpIROnFailure) {
+      llvm::errs() << "--- IR after partial moves ---\n";
+      module->print(llvm::errs());
+      llvm::errs() << "--- end IR ---\n";
+    }
+    return 1;
+  }
+
+  // Verify the module after moves (catches broken dominance, etc.).
+  if (failed(mlir::verify(*module))) {
+    llvm::errs() << "conductor: verification failed after applying moves\n";
+    if (dumpIROnFailure) {
+      llvm::errs() << "--- IR at verification failure ---\n";
+      module->print(llvm::errs());
+      llvm::errs() << "--- end IR ---\n";
+    }
+    return 1;
+  }
+
+  // Print the result.
+  std::error_code ec;
+  llvm::raw_fd_ostream outputStream(outputFilename, ec);
+  if (ec) {
+    llvm::errs() << "Error opening output file: " << ec.message() << "\n";
+    return 1;
+  }
+
+  OpPrintingFlags flags;
+  if (printDebugLocsInline) {
+    flags.enableDebugInfo(/*prettyForm=*/true);
+    flags.useLocalScope();
+  }
+  if (useNameLocAsPrefix)
+    flags.printNameLocAsPrefix();
+  module->print(outputStream, flags);
+
+  return 0;
+}
diff --git a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-translate/waveasm-translate.cpp b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-translate/waveasm-translate.cpp
index c8a8578a4..3b30845e8 100644
--- a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-translate/waveasm-translate.cpp
+++ b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-translate/waveasm-translate.cpp
@@ -100,6 +100,26 @@ static llvm::cl::opt<bool>
                     llvm::cl::desc("Emit AMDGCN assembly instead of MLIR"),
                     llvm::cl::init(false));
 
+static llvm::cl::opt<bool> runTagInstructions(
+    "waveasm-tag-instructions",
+    llvm::cl::desc("Attach stable NameLoc tags to WaveASM instructions"),
+    llvm::cl::init(false));
+
+static llvm::cl::opt<bool>
+    printDebugLocs("print-debug-locs",
+                   llvm::cl::desc("Print location information in MLIR output"),
+                   llvm::cl::init(false));
+
+static llvm::cl::opt<bool> printDebugLocsInline(
+    "print-debug-locs-inline",
+    llvm::cl::desc("Print location information inline (pretty form)"),
+    llvm::cl::init(false));
+
+static llvm::cl::opt<bool> useNameLocAsPrefix(
+    "use-nameloc-as-prefix",
+    llvm::cl::desc("Print SSA IDs using NameLocs as prefixes"),
+    llvm::cl::init(false));
+
 static llvm::cl::opt<bool> runPreTranslationCSE(
     "mlir-cse", llvm::cl::desc("Run MLIR CSE before translation (reduces redundant index "
@@ -249,6 +269,11 @@ int main(int argc, char **argv) {
     }
   }
 
+  // Instruction tagging for Conductor scheduling infrastructure.
+  if (runTagInstructions) {
+    pm.addPass(waveasm::createWAVEASMTagInstructionsPass());
+  }
+
   // Register allocation must run before waitcnt/hazard so that those passes
   // see the final register assignments. Matches compare_backends.py order:
   // LinearScan -> Waitcnt -> Hazard.
@@ -330,8 +355,17 @@ int main(int argc, char **argv) {
     return success ? 0 : 1;
   }
 
-  // Print the translated module (MLIR format)
-  module->print(outputStream);
+  // Print the translated module (MLIR format).
+  OpPrintingFlags flags;
+  if (printDebugLocsInline) {
+    flags.enableDebugInfo(/*prettyForm=*/true);
+    flags.useLocalScope();
+  } else if (printDebugLocs) {
+    flags.enableDebugInfo();
+  }
+  if (useNameLocAsPrefix)
+    flags.printNameLocAsPrefix();
+  module->print(outputStream, flags);
 
   return 0;
 }