From 3cb0568ecd7dabb18c417fe4407cbc30597826cd Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 21 Feb 2026 18:20:09 +0100 Subject: [PATCH 01/49] draft Signed-off-by: Ivan Butygin --- .../include/waveasm/CONDUCTOR_DESIGN.md | 335 ++++++++++++++++++ 1 file changed, 335 insertions(+) create mode 100644 wave_lang/kernel/wave/asm/wave_asm/include/waveasm/CONDUCTOR_DESIGN.md diff --git a/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/CONDUCTOR_DESIGN.md b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/CONDUCTOR_DESIGN.md new file mode 100644 index 000000000..0a58a33f3 --- /dev/null +++ b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/CONDUCTOR_DESIGN.md @@ -0,0 +1,335 @@ +# Conductor: LLM-Guided Instruction Scheduling for WaveASM + +## Context + +The WaveASM C++ backend has **no instruction scheduling pass**. Instructions are +emitted in the exact order they arrive from the upstream Python frontend. While +the Python frontend does tile-level scheduling (software pipelining, +double-buffering, stage assignment), it does not optimize **within** instruction +groups. This leaves performance on the table: + +- LDS reads (20-cycle latency) are clustered together, then MFMAs (32+ cycles) + are clustered together. Interleaving them would hide LDS latency behind MFMA + execution. +- Global loads (100-cycle latency) are not interleaved with independent compute. +- Register pressure from the instruction order can cause the linear-scan + allocator to fail (it has no spilling). + +The target latency data (`getMFMALatency()`, `getGlobalLoadLatency()`, +`getLDSLoadLatency()`) already exists in `WaveASMAttrs.td` but is **unused by +any optimization pass**. + +## Core Idea + +No intermediate DSL. The LLM sees **real WaveASM MLIR textual IR** with named +location tags on each instruction. It issues **move commands** ("move A after +B"). 
A deterministic engine validates structural invariants (dominance, +dependency), applies the moves, runs the rest of the WaveASM MLIR compilation +pipeline, and reports **metrics** back to the LLM. The loop repeats until the +LLM is satisfied or a budget is exhausted. + +``` +WaveASM IR (post-CSE/peephole, virtual registers) + | + v +[Tag] ── attach NameLoc("I0"), ("I1"), ... to each op + | + v +[Print] ── module.print() ──> textual MLIR + header with target/latency info + | + v +[LLM] ── reads IR, issues move commands ──> "move I5 after I2" + | "move I8 before I12" + | "swap I3 I7" + v +[Executor] ── parse commands + ── validate (dominance, pinned ops) + ── apply via Operation::moveBefore() + ── anchor pseudo-ops (pack/extract/const) + | + v +[Compile] ── run LinearScan + InsertWaitcnt + HazardMitigation on clone + | + v +[Metrics] ── collect from WaveASM pipeline passes ──> report to LLM + | + v +[LLM] ── sees metrics, decides: more moves or done + | + v +(repeat up to N rounds, then emit final assembly) +``` + +The key insight: **the LLM works on the actual IR**, not an abstraction. It can +see every instruction, every SSA value, every type. The system handles all the +mechanical work (validation, compilation, metric collection) while the LLM +handles the creative work (what to move where). + +## Instruction Tagging + +Before presenting IR to the LLM, each operation in schedulable regions gets a +stable `NameLoc` tag: + +```cpp +// During tagging pass. 
+int counter = 0; +for (auto &op : block.getOperations()) { + auto tag = NameLoc::get(builder.getStringAttr("I" + std::to_string(counter++))); + op.setLoc(tag); +} +``` + +In textual MLIR this appears as: + +```mlir +%0 = waveasm.buffer_load_dwordx4 %srd, %off offset:0 + : !waveasm.sreg<4>, !waveasm.vreg -> !waveasm.vreg<4> loc("I0") +%1 = waveasm.buffer_load_dwordx4 %srd, %off offset:64 + : !waveasm.sreg<4>, !waveasm.vreg -> !waveasm.vreg<4> loc("I1") +%2 = waveasm.ds_write_b128 %0, %addr + : !waveasm.vreg<4>, !waveasm.vreg loc("I2") +... +%10 = waveasm.v_mfma_f32_16x16x16_f16 %a, %b, %acc + : !waveasm.vreg, !waveasm.vreg, !waveasm.vreg<4> + -> !waveasm.vreg<4> loc("I10") +``` + +Tags are stable across rounds — they follow the operation, not the position. + +## LLM Input + +Each round, the LLM receives: + +``` +=== WaveASM Scheduling Round {N} === +TARGET: gfx942 (wave64, 512 vgpr, 106 sgpr, 512 agpr) +LATENCY: vmem=100, lds=20, mfma_16x16=32, mfma_32x32=64 + +--- IR (loop body) --- +{textual MLIR of the loop body, with loc("Ixx") tags} + +--- Metrics (from previous round, or baseline) --- +peak_vgpr: 180 +peak_sgpr: 42 +peak_agpr: 128 +nops_inserted: 3 +waitcnts_inserted: 12 +total_instructions: 87 + +--- Error from previous round (if any) --- +(none, or e.g.: + "Applied successfully: move I5 after I1, move I8 before I3 + Failed: swap I6 I9 — would break dominance of %2 (defined by I6, used by I7) + All moves reverted.") + +GOAL: Minimize register pressure and hide memory latency. +Respond with move commands, one per line. +``` + +## LLM Output (Move Commands) + +The LLM responds with simple text commands: + +``` +move I5 after I1 +move I8 before I3 +swap I6 I9 +``` + +### Command Set + +| Command | Semantics | +|---------|-----------| +| `move Ix after Iy` | Move op tagged Ix to immediately after op tagged Iy. | +| `move Ix before Iy` | Move op tagged Ix to immediately before op tagged Iy. | +| `swap Ix Iy` | Exchange positions of ops Ix and Iy. 
| +| `done` | LLM is satisfied with current schedule. | + +Three move commands + a stop signal. No DSL to learn. + +## Validation + +Move commands are applied sequentially. Each is validated **before** application: + +1. **Tag resolution**: Both tags must exist. +2. **Pinned ops**: `ConditionOp` (loop terminator), `s_barrier`, `s_endpgm` + cannot be moved. +3. **Dominance check (pre-flight)**: Simulate the move and verify every use of + every SSA value defined by the moved op is still dominated by its def. Also + check that all operands of the moved op are still dominated. +4. **Region boundary**: Cannot move ops across region boundaries (into/out of + `LoopOp` or `IfOp`). + +On the first invalid move, the system **aborts the entire round**: all moves +applied so far in this round are reverted, and the LLM receives a report +showing which moves succeeded before the failure and which move was invalid +(with the reason). For example: + +``` +--- Error --- +Applied successfully: move I5 after I1, move I8 before I3 +Failed: swap I6 I9 — would break dominance of %2 (defined by I6, used by I7) +All moves reverted. +``` + +This gives the LLM a clear signal about what went wrong and a complete picture +of which commands were valid up to the failure point. + +## Metrics Collection + +After applying a round of moves, the system runs the downstream WaveASM MLIR +pipeline on a **clone** of the IR and collects metrics: + +```cpp +struct SchedulingMetrics { + int64_t peakVGPRs; + int64_t peakSGPRs; + int64_t peakAGPRs; + int64_t nopsInserted; // from HazardMitigation. + int64_t waitcntsInserted; // from InsertWaitcnt. + int64_t totalInstructions; // post-pipeline instruction count. +}; +``` + +Sources (currently internal to passes, need to be exposed): +- **Peak registers**: `computeMaxPressure()` from `Liveness.h`, or + `AllocationStats` from `LinearScanRegAlloc`. +- **NOP count**: `HazardMitigation` tracks `numNopsInserted` internally. 
+- **Waitcnt count**: `InsertWaitcnt` tracks tickets internally. +- **Total instructions**: count ops after all passes. + +Additional metrics can be added as needed (e.g. estimated cycle count, LDS bank +conflict potential, critical path length). + +The pipeline runs on a cloned module so the original IR is preserved for the +next round of moves. + +### Optional: GPU Profiling Integration + +When a GPU is available and profiling is enabled, the system can run the +compiled kernel through `rocprof` and feed per-instruction profiling data back +to the LLM. `rocprof` reports hit count and min/max/mean latency for each +assembly instruction. Since the assembly emitter preserves instruction order +and the tags map to assembly locations, profiling data can be correlated back +to the tagged IR. + +The profiling section is appended to the LLM input when available: + +``` +--- Profiling (rocprof, previous run) --- +I0 buffer_load_dwordx4 hits=1024 lat_min=88 lat_mean=102 lat_max=145 +I1 buffer_load_dwordx4 hits=1024 lat_min=90 lat_mean=105 lat_max=148 +I5 ds_read_b128 hits=1024 lat_min=18 lat_mean=21 lat_max=34 +I10 v_mfma_f32_16x16x16 hits=1024 lat_min=32 lat_mean=33 lat_max=35 +I15 s_waitcnt hits=1024 lat_min=0 lat_mean=45 lat_max=120 +``` + +This lets the LLM see actual stall patterns — e.g. a `s_waitcnt` with high +mean latency indicates the preceding loads are not sufficiently hidden. This +data is strictly optional; the system works without it using only the static +pipeline metrics. + +## Pipeline Integration + +``` +TranslateFromMLIR + → ScopedCSE + → Peephole + → MemoryOffsetOpt + → Canonicalizer + ScopedCSE + → [NEW] Conductor (tag + LLM loop + apply final schedule) + → LinearScan + → InsertWaitcnt + → HazardMitigation + → EmitAssembly +``` + +The Conductor pass: +1. Tags all ops with `NameLoc`. +2. Runs baseline metrics (clone → compile → collect). +3. Enters the LLM loop (up to N rounds). +4. Applies the best schedule to the real IR. +5. Hands off to LinearScan. 
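The three move commands plus `done` form a grammar small enough to parse with plain string splitting, which is what keeps step 3 of the Conductor loop mechanical. A minimal C++ sketch of such a parser — the enum and function names here are illustrative, not the actual `ApplyMoves` API:

```cpp
#include <cassert>
#include <optional>
#include <sstream>
#include <string>

// Illustrative sketch of the three-command grammar; the real parser
// lives in ApplyMoves.cpp and may differ in detail.
enum class MoveKind { Before, After, Swap, Done };

struct MoveCmd {
  MoveKind kind;
  std::string a; // First tag, e.g. "I5".
  std::string b; // Second tag (empty for "done").
};

std::optional<MoveCmd> parseMoveCmd(const std::string &line) {
  std::istringstream in(line);
  std::string word, a, rel, b;
  in >> word;
  if (word == "done")
    return MoveCmd{MoveKind::Done, "", ""};
  if (word == "swap") {
    if (!(in >> a >> b))
      return std::nullopt;
    return MoveCmd{MoveKind::Swap, a, b};
  }
  if (word == "move") {
    if (!(in >> a >> rel >> b))
      return std::nullopt;
    if (rel == "after")
      return MoveCmd{MoveKind::After, a, b};
    if (rel == "before")
      return MoveCmd{MoveKind::Before, a, b};
  }
  return std::nullopt; // Unknown or malformed command.
}
```

Malformed lines return `std::nullopt`, which the executor can report back to the LLM the same way an invalid move is reported.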
+ +## Iterative Loop Detail + +``` +baseline_metrics = compile_and_measure(clone(IR)) +best_metrics = baseline_metrics +best_state = snapshot(IR) + +for round in 1..max_rounds: + text = print_ir(IR) + format_metrics(best_metrics) + format_error(error) + commands = llm.query(text) + + if commands == ["done"]: + break + + error = null + applied = [] + for cmd in commands: + result = validate_and_apply(IR, cmd) + if result.is_error: + error = { applied: applied, failed: cmd, reason: result.message } + restore(IR, best_state) // revert entire round. + break + applied.append(cmd) + + if error: + continue // next round, LLM sees the error report. + + anchor_pseudo_ops(IR) // move pack/extract/const before earliest user. + new_metrics = compile_and_measure(clone(IR)) + + if is_better(new_metrics, best_metrics): + best_metrics = new_metrics + best_state = snapshot(IR) + else: + restore(IR, best_state) // revert if regression. + error = { applied: applied, reason: "round regressed metrics" } + +apply(IR, best_state) +``` + +`is_better()` compares metrics lexicographically: lower peak VGPRs > fewer +waitcnts > fewer nops. This ordering is configurable. + +## Handling Structured Control Flow + +- **`LoopOp` body**: The primary scheduling target. Block arguments are always + available (dominate the entire body). `ConditionOp` is pinned at end. +- **`IfOp`**: Treated as an atomic unit. The LLM can move the entire `IfOp` but + not its internal ops. The tag is on the `IfOp` itself. +- **`ProgramOp` body (prologue)**: Tagged and schedulable as a straight-line + block, but typically less benefit (SRD setup). +- **Nested regions**: Tags are scoped per region. The LLM sees one region at a + time (typically the innermost loop body). 
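The `is_better()` comparison in the loop above reduces to a lexicographic tuple comparison. A sketch assuming the default ordering (lower peak VGPRs, then fewer waitcnts, then fewer nops); the struct is trimmed to just the compared fields for illustration:

```cpp
#include <cassert>
#include <cstdint>
#include <tuple>

// Trimmed mirror of SchedulingMetrics: only the fields the default
// ordering compares.
struct SchedulingMetrics {
  int64_t peakVGPRs;
  int64_t waitcntsInserted;
  int64_t nopsInserted;
};

// Default lexicographic ordering: lower peak VGPRs beats everything,
// then fewer waitcnts, then fewer nops. The design leaves this
// ordering configurable.
bool isBetter(const SchedulingMetrics &a, const SchedulingMetrics &b) {
  return std::tie(a.peakVGPRs, a.waitcntsInserted, a.nopsInserted) <
         std::tie(b.peakVGPRs, b.waitcntsInserted, b.nopsInserted);
}
```

Extending the ordering with a new metric is a one-line change to each `std::tie`, which is what makes the "extensible metrics" property cheap in practice.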
+ +## Caching + +Final move sequences are cached by SHA-256 of the **tagged IR text** (before +any moves): + +``` +~/.waveasm-cache/conductor/ + .json → { + moves: ["move I5 after I1", "swap I6 I9"], + metrics: { peakVGPRs: 160, nops: 1, waitcnts: 8 }, + rounds: 3, + timestamp: "..." + } +``` + +On cache hit, the moves are replayed directly (still validated) without querying +the LLM. + +## Key Design Properties + +1. **No abstraction gap.** The LLM sees actual IR, not a lossy summary. +2. **Zero new languages.** Move commands are trivial to parse and validate. +3. **Graceful degradation.** Invalid moves are warnings, not errors. +4. **Iterative refinement.** The LLM can observe the effect of its moves and + adjust, rather than producing a one-shot schedule. +5. **Transparent.** Every move is logged. Easy to inspect, replay, and debug. +6. **Extensible metrics.** Adding a new metric is just adding a field to + `SchedulingMetrics` — no DSL changes needed. From 6b4ee2624e956a270582a69b1874f8c4c1eec0a1 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 21 Feb 2026 18:37:08 +0100 Subject: [PATCH 02/49] update Signed-off-by: Ivan Butygin --- .../include/waveasm/CONDUCTOR_DESIGN.md | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/CONDUCTOR_DESIGN.md b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/CONDUCTOR_DESIGN.md index 0a58a33f3..cf84cb02e 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/CONDUCTOR_DESIGN.md +++ b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/CONDUCTOR_DESIGN.md @@ -325,11 +325,14 @@ the LLM. ## Key Design Properties -1. **No abstraction gap.** The LLM sees actual IR, not a lossy summary. -2. **Zero new languages.** Move commands are trivial to parse and validate. -3. **Graceful degradation.** Invalid moves are warnings, not errors. -4. 
**Iterative refinement.** The LLM can observe the effect of its moves and - adjust, rather than producing a one-shot schedule. -5. **Transparent.** Every move is logged. Easy to inspect, replay, and debug. -6. **Extensible metrics.** Adding a new metric is just adding a field to - `SchedulingMetrics` — no DSL changes needed. +1. **No abstraction gap.** The LLM sees actual MLIR IR, not a lossy summary. +2. **Zero new languages.** Three move commands, trivial to parse and validate. +3. **Fast feedback.** Invalid moves abort the round immediately with a clear + error report; the LLM learns from mistakes without silent corruption. +4. **Iterative refinement.** The LLM observes metrics after each round and + adjusts, rather than producing a one-shot schedule. +5. **Closed loop with hardware.** Optional `rocprof` integration feeds real + per-instruction latencies back, grounding the LLM in measured data. +6. **Transparent.** Every move is logged. Easy to inspect, replay, and debug. +7. **Extensible metrics.** Adding a new metric is just a field in + `SchedulingMetrics` — no protocol changes needed. From 7f5e72fde5daec38e0ec9bd3b177f5d8f99d798e Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 21 Feb 2026 18:49:57 +0100 Subject: [PATCH 03/49] parallelism Signed-off-by: Ivan Butygin --- .../include/waveasm/CONDUCTOR_DESIGN.md | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/CONDUCTOR_DESIGN.md b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/CONDUCTOR_DESIGN.md index cf84cb02e..277069e7e 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/CONDUCTOR_DESIGN.md +++ b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/CONDUCTOR_DESIGN.md @@ -305,6 +305,58 @@ waitcnts > fewer nops. This ordering is configurable. - **Nested regions**: Tags are scoped per region. The LLM sees one region at a time (typically the innermost loop body). 
+## Parallel Search with Checkpoints + +Individual LLM calls have high latency but the API supports many concurrent +requests. Instead of running a single sequential loop, launch **P parallel +agents**, each exploring a different scheduling trajectory from the same +starting IR. + +``` + ┌─ Agent 1 ─── round 1 ─── round 2 ─── ... + │ +tagged IR ─┼─ Agent 2 ─── round 1 ─── round 2 ─── ... + │ + ├─ Agent 3 ─── round 1 ─── round 2 ─── ... + │ + └─ Agent P ─── round 1 ─── round 2 ─── ... + │ + checkpoint +``` + +At **checkpoints** (every K rounds), collect metrics from all agents, rank +them, and keep the **top-k**. The surviving agents continue from their current +state; the rest are terminated. New agents can be spawned from the top-k states +to refill the pool, optionally with varied system prompts (e.g. "prioritize +register pressure" vs "prioritize latency hiding"). + +``` +for checkpoint in 1..num_checkpoints: + // all agents run K rounds in parallel. + results = await all_agents(K rounds each) + + // rank by metrics, keep top-k. + ranked = sort(results, by=metrics) + survivors = ranked[:top_k] + + // respawn from survivors to refill pool. + agents = [] + for state in survivors: + agents.append(continue_agent(state)) + agents.append(fork_agent(state, varied_prompt)) // optional diversity. + +best = survivors[0] +``` + +This turns the scheduling search into a **beam search** over move sequences. +The validation + metrics infrastructure is unchanged — each agent is just an +independent instance of the iterative loop. The only coordination is at +checkpoints where we compare metrics and prune. + +Parallelism is free in terms of wall-clock time (bounded by the slowest agent +per checkpoint) and the compile-and-measure step is CPU-local and fast (~ms +per clone). 
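The checkpoint step above is an ordinary top-k selection over agent states. A sketch with each agent reduced to a single comparable score (lower is better, e.g. lexicographically encoded metrics); the struct and function names are illustrative, not part of the actual implementation:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// One agent's state at a checkpoint, reduced to a comparable score.
struct AgentState {
  int id;
  long score; // Lower is better.
};

// Keep the best `topK` agents; the caller respawns forks from the
// survivors (optionally with varied prompts) to refill the pool.
std::vector<AgentState> pruneAtCheckpoint(std::vector<AgentState> agents,
                                          size_t topK) {
  std::sort(agents.begin(), agents.end(),
            [](const AgentState &a, const AgentState &b) {
              return a.score < b.score;
            });
  if (agents.size() > topK)
    agents.resize(topK);
  return agents;
}
```

The survivors' IR snapshots, not just their scores, must be retained so forked agents can continue from those states.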
+ ## Caching Final move sequences are cached by SHA-256 of the **tagged IR text** (before From 7d143a96a59c4fcb2b96642d2f2715fc095de68d Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 21 Feb 2026 19:01:37 +0100 Subject: [PATCH 04/49] warnings Signed-off-by: Ivan Butygin --- .../include/waveasm/CONDUCTOR_DESIGN.md | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/CONDUCTOR_DESIGN.md b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/CONDUCTOR_DESIGN.md index 277069e7e..2176e3da5 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/CONDUCTOR_DESIGN.md +++ b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/CONDUCTOR_DESIGN.md @@ -116,6 +116,11 @@ nops_inserted: 3 waitcnts_inserted: 12 total_instructions: 87 +--- Warnings from previous round --- +(none, or e.g.: + "warn: move I8 after I12 — moved LDS read after LDS write, same base address + info: move I3 after I7 — address computation moved away from consumer I4") + --- Error from previous round (if any) --- (none, or e.g.: "Applied successfully: move I5 after I1, move I8 before I3 @@ -175,6 +180,32 @@ All moves reverted. This gives the LLM a clear signal about what went wrong and a complete picture of which commands were valid up to the failure point. +### Warnings + +Some moves are structurally valid (dominance holds) but **potentially +dangerous**. These are applied but reported as warnings with severity levels, +so the LLM can decide whether to keep or reconsider them. + +| Severity | Example | Meaning | +|----------|---------|---------| +| `info` | Moving address computation away from its consumer. | Unusual but harmless. | +| `warn` | Moving a memory read after a memory write. | May be valid (different addresses) but could introduce a WAR hazard. | +| `warn` | Moving an op past a barrier. | Barrier ordering is usually intentional. | +| `critical` | Moving a memory write after another write to the same buffer. 
| Likely WAW hazard; almost certainly wrong unless offsets are provably disjoint. | + +The system performs lightweight alias analysis where possible (e.g. comparing +buffer SRD operands and constant offsets) to reduce false warnings. When it +cannot prove safety, it warns and lets the LLM proceed — the metrics from the +compilation round will reveal whether the move was actually beneficial. + +Warnings are reported alongside metrics in the next round: + +``` +--- Warnings --- +warn: move I8 after I12 — moved LDS read (I8) after LDS write (I12), same base address +info: move I3 after I7 — address computation moved away from consumer I4 +``` + ## Metrics Collection After applying a round of moves, the system runs the downstream WaveASM MLIR From 36c0414658f598b998fdf75a99ba30fd66a1a4c8 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 21 Feb 2026 22:48:41 +0100 Subject: [PATCH 05/49] fix gitignore --- wave_lang/kernel/wave/.gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wave_lang/kernel/wave/.gitignore b/wave_lang/kernel/wave/.gitignore index fd68a7a03..6d973f16e 100644 --- a/wave_lang/kernel/wave/.gitignore +++ b/wave_lang/kernel/wave/.gitignore @@ -11,7 +11,7 @@ dist/ downloads/ eggs/ .eggs/ -lib/ +# lib/ lib64/ parts/ sdist/ From 5bbc36008c12d827d89dd490ad95adfc10512355 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 21 Feb 2026 22:50:32 +0100 Subject: [PATCH 06/49] tag instructions pass Signed-off-by: Ivan Butygin --- .../include/waveasm/Transforms/Passes.h | 3 + .../include/waveasm/Transforms/Passes.td | 23 +++++++ .../wave_asm/lib/Transforms/CMakeLists.txt | 1 + .../lib/Transforms/TagInstructions.cpp | 67 +++++++++++++++++++ .../waveasm-translate/waveasm-translate.cpp | 31 ++++++++- 5 files changed, 123 insertions(+), 2 deletions(-) create mode 100644 wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/TagInstructions.cpp diff --git a/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/Passes.h 
b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/Passes.h index 2d41e1d95..9d965561c 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/Passes.h +++ b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/Passes.h @@ -57,6 +57,9 @@ std::unique_ptr createWAVEASMPeepholePass(); /// Create the memory offset optimization pass std::unique_ptr createWAVEASMMemoryOffsetOptPass(); +/// Create the instruction tagging pass (Conductor infrastructure). +std::unique_ptr createWAVEASMTagInstructionsPass(); + //===----------------------------------------------------------------------===// // Pass Registration //===----------------------------------------------------------------------===// diff --git a/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/Passes.td b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/Passes.td index afaf96af1..9c1ec0bd9 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/Passes.td +++ b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/Passes.td @@ -250,4 +250,27 @@ def WAVEASMEmitAssembly : Pass<"waveasm-emit-assembly"> { let dependentDialects = ["::waveasm::WaveASMDialect"]; } +//===----------------------------------------------------------------------===// +// Instruction Tagging Pass (Conductor) +//===----------------------------------------------------------------------===// + +def WAVEASMTagInstructions : Pass<"waveasm-tag-instructions", "mlir::ModuleOp"> { + let summary = "Attach stable NameLoc tags to WaveASM instructions"; + let description = [{ + Tags each WaveASM operation with a NameLoc of the form + "_" where N is a per-op-kind counter. For example: + buffer_load_dwordx4_0, v_mfma_f32_16x16x16_f16_0, ds_read_b64_2. + + Tags are stable across scheduling rounds — they follow the operation, + not the position. This is the first step of the Conductor scheduling + infrastructure. 
+ }]; + + let statistics = [ + Statistic<"numOpsTagged", "Operations tagged", "count"> + ]; + + let dependentDialects = ["::waveasm::WaveASMDialect"]; +} + #endif // WaveASM_TRANSFORMS_PASSES diff --git a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/CMakeLists.txt b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/CMakeLists.txt index 5dcd41422..8d945a87d 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/CMakeLists.txt +++ b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/CMakeLists.txt @@ -27,6 +27,7 @@ add_mlir_dialect_library(MLIRWaveASMTransforms ScopedCSE.cpp Peephole.cpp MemoryOffsetOptimization.cpp + TagInstructions.cpp ${HANDLERS_FULL_PATHS} ADDITIONAL_HEADER_DIRS diff --git a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/TagInstructions.cpp b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/TagInstructions.cpp new file mode 100644 index 000000000..d8b875d7b --- /dev/null +++ b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/TagInstructions.cpp @@ -0,0 +1,67 @@ +// Copyright 2025 The Wave Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// Tag Instructions Pass - Attach stable NameLoc tags to WaveASM ops. +// +// Each operation gets a tag like loc("buffer_load_dwordx4_0"), +// loc("v_mfma_f32_16x16x16_f16_3"), etc. Tags are stable across +// scheduling rounds (they follow the operation, not the position). 
+//===----------------------------------------------------------------------===// + +#include "waveasm/Dialect/WaveASMDialect.h" +#include "waveasm/Dialect/WaveASMOps.h" +#include "waveasm/Transforms/Passes.h" + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/DebugLog.h" + +#define DEBUG_TYPE "waveasm-tag-instructions" + +using namespace mlir; +using namespace waveasm; + +namespace { + +struct TagInstructionsPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TagInstructionsPass) + + StringRef getArgument() const override { return "waveasm-tag-instructions"; } + + StringRef getDescription() const override { + return "Attach stable NameLoc tags to WaveASM instructions"; + } + + void runOnOperation() override { + unsigned total = 0; + MLIRContext *ctx = &getContext(); + llvm::DenseMap counters; + + getOperation().walk([&](Operation *op) { + if (op->getDialect() && op->getDialect()->getNamespace() == "waveasm") { + StringRef opName = op->getName().stripDialect(); + unsigned idx = counters[opName]++; + std::string tag = (opName + "_" + llvm::Twine(idx)).str(); + op->setLoc(NameLoc::get(StringAttr::get(ctx, tag))); + ++total; + } + }); + + LDBG() << "tagged " << total << " operations"; + } +}; + +} // namespace + +namespace waveasm { + +std::unique_ptr createWAVEASMTagInstructionsPass() { + return std::make_unique(); +} + +} // namespace waveasm diff --git a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-translate/waveasm-translate.cpp b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-translate/waveasm-translate.cpp index c8a8578a4..0d7b68f3e 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-translate/waveasm-translate.cpp +++ b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-translate/waveasm-translate.cpp @@ -100,6 +100,21 @@ static llvm::cl::opt llvm::cl::desc("Emit AMDGCN assembly instead of MLIR"), llvm::cl::init(false)); +static llvm::cl::opt runTagInstructions( + 
"waveasm-tag-instructions", + llvm::cl::desc("Attach stable NameLoc tags to WaveASM instructions"), + llvm::cl::init(false)); + +static llvm::cl::opt + printDebugLocs("print-debug-locs", + llvm::cl::desc("Print location information in MLIR output"), + llvm::cl::init(false)); + +static llvm::cl::opt printDebugLocsInline( + "print-debug-locs-inline", + llvm::cl::desc("Print location information inline (pretty form)"), + llvm::cl::init(false)); + static llvm::cl::opt runPreTranslationCSE( "mlir-cse", llvm::cl::desc("Run MLIR CSE before translation (reduces redundant index " @@ -249,6 +264,11 @@ int main(int argc, char **argv) { } } + // Instruction tagging for Conductor scheduling infrastructure. + if (runTagInstructions) { + pm.addPass(waveasm::createWAVEASMTagInstructionsPass()); + } + // Register allocation must run before waitcnt/hazard so that those passes // see the final register assignments. Matches compare_backends.py order: // LinearScan -> Waitcnt -> Hazard. @@ -330,8 +350,15 @@ int main(int argc, char **argv) { return success ? 0 : 1; } - // Print the translated module (MLIR format) - module->print(outputStream); + // Print the translated module (MLIR format). 
+ OpPrintingFlags flags; + if (printDebugLocsInline) { + flags.enableDebugInfo(/*prettyForm=*/true); + flags.useLocalScope(); + } else if (printDebugLocs) { + flags.enableDebugInfo(); + } + module->print(outputStream, flags); return 0; } From b2f5f1fcd72fdce521daf514586f12f62aa98ce4 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 21 Feb 2026 22:51:08 +0100 Subject: [PATCH 07/49] test Signed-off-by: Ivan Butygin --- .../test/Transforms/tag-instructions.mlir | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 wave_lang/kernel/wave/asm/wave_asm/test/Transforms/tag-instructions.mlir diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/tag-instructions.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/tag-instructions.mlir new file mode 100644 index 000000000..c19183604 --- /dev/null +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/tag-instructions.mlir @@ -0,0 +1,48 @@ +// RUN: waveasm-translate --waveasm-tag-instructions --print-debug-locs-inline %s | FileCheck %s + +//===----------------------------------------------------------------------===// +// Test: Each op kind gets its own counter. +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: waveasm.program @per_kind_counters +waveasm.program @per_kind_counters target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4> + %c4 = waveasm.constant 4 : !waveasm.imm<4> + + // Two adds should get v_add_u32_0 and v_add_u32_1. + // CHECK: waveasm.v_add_u32 {{.*}} loc("v_add_u32_0") + // CHECK: waveasm.v_add_u32 {{.*}} loc("v_add_u32_1") + %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg + %a1 = waveasm.v_add_u32 %a0, %c4 : !waveasm.vreg, !waveasm.imm<4> -> !waveasm.vreg + + // A shift gets its own counter starting at 0. 
+ // CHECK: waveasm.v_lshlrev_b32 {{.*}} loc("v_lshlrev_b32_0") + %s0 = waveasm.v_lshlrev_b32 %c4, %a1 : !waveasm.imm<4>, !waveasm.vreg -> !waveasm.vreg + + // Load and store have independent counters. + // CHECK: waveasm.buffer_load_dwordx4 {{.*}} loc("buffer_load_dwordx4_0") + // CHECK: waveasm.buffer_store_dword {{.*}} loc("buffer_store_dword_0") + %ld = waveasm.buffer_load_dwordx4 %srd, %s0 : !waveasm.psreg<0, 4>, !waveasm.vreg -> !waveasm.vreg<4> + waveasm.buffer_store_dword %a1, %srd, %s0 : !waveasm.vreg, !waveasm.psreg<0, 4>, !waveasm.vreg + + // CHECK: waveasm.s_endpgm loc("s_endpgm_0") + waveasm.s_endpgm +} + +//===----------------------------------------------------------------------===// +// Test: Counters reset per module, not per program. +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: waveasm.program @second_program +waveasm.program @second_program target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %c1 = waveasm.constant 1 : !waveasm.imm<1> + + // Counter continues from first program (module-wide walk). 
+ // CHECK: waveasm.v_add_u32 {{.*}} loc("v_add_u32_2") + %a0 = waveasm.v_add_u32 %v0, %c1 : !waveasm.pvreg<0>, !waveasm.imm<1> -> !waveasm.vreg + + // CHECK: waveasm.s_endpgm loc("s_endpgm_1") + waveasm.s_endpgm +} From fd892580080f7f2436c3337bc302d42c1a17ca87 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 21 Feb 2026 23:18:34 +0100 Subject: [PATCH 08/49] conductor Signed-off-by: Ivan Butygin --- .../kernel/wave/asm/wave_asm/CMakeLists.txt | 1 + .../include/waveasm/Transforms/ApplyMoves.h | 42 ++++ .../wave_asm/lib/Transforms/ApplyMoves.cpp | 187 ++++++++++++++++++ .../wave_asm/lib/Transforms/CMakeLists.txt | 1 + .../wave/asm/wave_asm/test/CMakeLists.txt | 1 + .../test/Transforms/apply-moves-after.mlir | 22 +++ .../Transforms/apply-moves-error-pinned.mlir | 13 ++ .../apply-moves-error-unknown-tag.mlir | 13 ++ .../test/Transforms/apply-moves-swap.mlir | 22 +++ .../wave_asm/test/Transforms/apply-moves.mlir | 23 +++ .../kernel/wave/asm/wave_asm/test/lit.cfg.py | 4 +- .../wave/asm/wave_asm/test/lit.site.cfg.py.in | 1 + .../wave/asm/wave_asm/tools/CMakeLists.txt | 1 + .../tools/waveasm-conductor/CMakeLists.txt | 22 +++ .../waveasm-conductor/waveasm-conductor.cpp | 122 ++++++++++++ 15 files changed, 473 insertions(+), 2 deletions(-) create mode 100644 wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/ApplyMoves.h create mode 100644 wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/ApplyMoves.cpp create mode 100644 wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-after.mlir create mode 100644 wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-pinned.mlir create mode 100644 wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-unknown-tag.mlir create mode 100644 wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-swap.mlir create mode 100644 wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves.mlir create mode 100644 
wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/CMakeLists.txt
 create mode 100644 wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp

diff --git a/wave_lang/kernel/wave/asm/wave_asm/CMakeLists.txt b/wave_lang/kernel/wave/asm/wave_asm/CMakeLists.txt
index 3ea5c5316..bcffb1cbe 100644
--- a/wave_lang/kernel/wave/asm/wave_asm/CMakeLists.txt
+++ b/wave_lang/kernel/wave/asm/wave_asm/CMakeLists.txt
@@ -51,6 +51,7 @@ add_definitions(${LLVM_DEFINITIONS})
 set(WAVEASM_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
 set(WAVEASM_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
 set(WAVEASM_TOOLS_DIR ${CMAKE_CURRENT_BINARY_DIR}/tools/waveasm-translate)
+set(WAVEASM_CONDUCTOR_DIR ${CMAKE_CURRENT_BINARY_DIR}/tools/waveasm-conductor)

 #-------------------------------------------------------------------------------
 # Directory Configuration
diff --git a/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/ApplyMoves.h b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/ApplyMoves.h
new file mode 100644
index 000000000..5317b18bc
--- /dev/null
+++ b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/ApplyMoves.h
@@ -0,0 +1,42 @@
+// Copyright 2025 The Wave Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef WAVEASM_TRANSFORMS_APPLYMOVES_H
+#define WAVEASM_TRANSFORMS_APPLYMOVES_H
+
+#include "mlir/IR/BuiltinOps.h"
+#include "llvm/ADT/ArrayRef.h"
+#include <string>
+
+namespace waveasm {
+
+/// Result of applying a sequence of move commands to a module.
+struct MoveResult {
+  bool success;
+  std::string error;
+  unsigned failedCommand;
+};
+
+/// Apply a sequence of Conductor move commands to a tagged module.
+///
+/// Commands are strings of the form:
+///   move <tag> after <ref_tag>
+///   move <tag> before <ref_tag>
+///   swap <tag1> <tag2>
+///
+/// The module must already have NameLoc tags attached (via TagInstructions).
+/// Returns a MoveResult indicating success or the first error encountered.
+MoveResult applyMoves(mlir::ModuleOp module,
+                      llvm::ArrayRef<std::string> commands);
+
+/// Parse CONDUCTOR command lines from raw file text.
+/// Scans for lines matching `// CONDUCTOR: <command>` and collects them
+/// until a `done` command is found or input is exhausted.
+llvm::SmallVector<std::string> parseConductorCommands(llvm::StringRef text);
+
+} // namespace waveasm
+
+#endif // WAVEASM_TRANSFORMS_APPLYMOVES_H
diff --git a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/ApplyMoves.cpp b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/ApplyMoves.cpp
new file mode 100644
index 000000000..798be779f
--- /dev/null
+++ b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/ApplyMoves.cpp
@@ -0,0 +1,187 @@
+// Copyright 2025 The Wave Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===----------------------------------------------------------------------===//
+// ApplyMoves — parse and apply Conductor move/swap commands on tagged IR.
+//===----------------------------------------------------------------------===//
+
+#include "waveasm/Transforms/ApplyMoves.h"
+
+#include "mlir/IR/Location.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/DebugLog.h"
+
+#define DEBUG_TYPE "waveasm-apply-moves"
+
+using namespace mlir;
+
+namespace {
+
+/// Op mnemonics that must never be moved.
+bool isPinned(Operation *op) {
+  StringRef name = op->getName().stripDialect();
+  return name == "condition" || name == "s_endpgm" || name == "s_barrier";
+}
+
+/// Build a map from NameLoc tag string to Operation*.
+llvm::StringMap<Operation *> buildTagMap(ModuleOp module) {
+  llvm::StringMap<Operation *> map;
+  module.walk([&](Operation *op) {
+    if (auto nameLoc = dyn_cast<NameLoc>(op->getLoc()))
+      map[nameLoc.getName().strref()] = op;
+  });
+  return map;
+}
+
+/// Validate that an op can be moved (not pinned, resolves to a tag).
+std::string validateMovable(Operation *op, StringRef tag) {
+  if (isPinned(op))
+    return ("cannot move pinned op '" + tag + "'").str();
+  return "";
+}
+
+/// Check that two ops are in the same block.
+std::string validateSameBlock(Operation *a, StringRef tagA, Operation *b,
+                              StringRef tagB) {
+  if (a->getBlock() != b->getBlock())
+    return ("'" + tagA + "' and '" + tagB + "' are in different blocks").str();
+  return "";
+}
+
+} // namespace
+
+namespace waveasm {
+
+llvm::SmallVector<std::string> parseConductorCommands(llvm::StringRef text) {
+  llvm::SmallVector<std::string> commands;
+  llvm::SmallVector<StringRef> lines;
+  text.split(lines, '\n');
+
+  for (StringRef line : lines) {
+    StringRef trimmed = line.ltrim();
+    if (!trimmed.starts_with("// CONDUCTOR:"))
+      continue;
+    StringRef cmd = trimmed.drop_front(strlen("// CONDUCTOR:")).trim();
+    if (cmd == "done")
+      break;
+    if (!cmd.empty())
+      commands.push_back(cmd.str());
+  }
+  return commands;
+}
+
+MoveResult applyMoves(ModuleOp module, llvm::ArrayRef<std::string> commands) {
+  auto tagMap = buildTagMap(module);
+
+  for (auto [idx, cmd] : llvm::enumerate(commands)) {
+    StringRef line(cmd);
+    StringRef rest;
+
+    auto fail = [&](const std::string &msg) -> MoveResult {
+      return {false, msg, static_cast<unsigned>(idx)};
+    };
+
+    if (line.starts_with("move ")) {
+      rest = line.drop_front(strlen("move "));
+
+      // Parse: <tag1> (after|before) <tag2>.
+ StringRef tag1, direction, tag2; + std::tie(tag1, rest) = rest.split(' '); + std::tie(direction, tag2) = rest.split(' '); + + if (tag1.empty() || tag2.empty() || direction.empty()) + return fail("malformed move command: '" + cmd + "'"); + + if (direction != "after" && direction != "before") + return fail("expected 'after' or 'before', got '" + direction.str() + + "'"); + + auto it1 = tagMap.find(tag1); + if (it1 == tagMap.end()) + return fail("unknown tag '" + tag1.str() + "'"); + auto it2 = tagMap.find(tag2); + if (it2 == tagMap.end()) + return fail("unknown tag '" + tag2.str() + "'"); + + Operation *op1 = it1->second; + Operation *op2 = it2->second; + + std::string err = validateMovable(op1, tag1); + if (!err.empty()) + return fail(err); + + err = validateSameBlock(op1, tag1, op2, tag2); + if (!err.empty()) + return fail(err); + + if (direction == "after") + op1->moveAfter(op2); + else + op1->moveBefore(op2); + + LDBG() << "move " << tag1 << " " << direction << " " << tag2; + + } else if (line.starts_with("swap ")) { + rest = line.drop_front(strlen("swap ")); + + StringRef tag1, tag2; + std::tie(tag1, tag2) = rest.split(' '); + + if (tag1.empty() || tag2.empty()) + return fail("malformed swap command: '" + cmd + "'"); + + auto it1 = tagMap.find(tag1); + if (it1 == tagMap.end()) + return fail("unknown tag '" + tag1.str() + "'"); + auto it2 = tagMap.find(tag2); + if (it2 == tagMap.end()) + return fail("unknown tag '" + tag2.str() + "'"); + + Operation *op1 = it1->second; + Operation *op2 = it2->second; + + std::string err = validateMovable(op1, tag1); + if (!err.empty()) + return fail(err); + err = validateMovable(op2, tag2); + if (!err.empty()) + return fail(err); + + err = validateSameBlock(op1, tag1, op2, tag2); + if (!err.empty()) + return fail(err); + + // Swap: move op1 after op2, then move op2 to op1's original position. + Operation *op1Next = op1->getNextNode(); + if (op1Next == op2) { + // Adjacent: just swap order. 
+ op1->moveAfter(op2); + } else { + Operation *op2Next = op2->getNextNode(); + if (op2Next == op1) { + op2->moveAfter(op1); + } else { + // Non-adjacent: use a stable reference point. + op1->moveAfter(op2); + if (op1Next) + op2->moveBefore(op1Next); + else + op2->moveBefore(op2->getBlock(), op2->getBlock()->end()); + } + } + + LDBG() << "swap " << tag1 << " " << tag2; + + } else { + return fail("unknown command: '" + cmd + "'"); + } + } + + return {true, "", 0}; +} + +} // namespace waveasm diff --git a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/CMakeLists.txt b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/CMakeLists.txt index 8d945a87d..7094db564 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/CMakeLists.txt +++ b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/CMakeLists.txt @@ -28,6 +28,7 @@ add_mlir_dialect_library(MLIRWaveASMTransforms Peephole.cpp MemoryOffsetOptimization.cpp TagInstructions.cpp + ApplyMoves.cpp ${HANDLERS_FULL_PATHS} ADDITIONAL_HEADER_DIRS diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/CMakeLists.txt b/wave_lang/kernel/wave/asm/wave_asm/test/CMakeLists.txt index 2d78933d0..941777bf8 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/test/CMakeLists.txt +++ b/wave_lang/kernel/wave/asm/wave_asm/test/CMakeLists.txt @@ -14,6 +14,7 @@ configure_lit_site_cfg( set(WAVEASM_TEST_DEPENDS waveasm-translate + waveasm-conductor ) add_lit_testsuite(check-waveasm "Running the waveasm regression tests" diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-after.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-after.mlir new file mode 100644 index 000000000..5ce3375e5 --- /dev/null +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-after.mlir @@ -0,0 +1,22 @@ +// RUN: waveasm-conductor --print-debug-locs-inline %s | FileCheck %s + +// CONDUCTOR: move v_add_u32_0 after v_lshlrev_b32_0 +// CONDUCTOR: done + +// Original order: v_add_u32_0, v_add_u32_1, v_lshlrev_b32_0. 
+// After move: v_add_u32_1, v_lshlrev_b32_0, v_add_u32_0. + +// CHECK: sym_name = "test_move_after" +// CHECK: waveasm.v_add_u32{{.*}}loc("v_add_u32_1") +// CHECK: waveasm.v_lshlrev_b32{{.*}}loc("v_lshlrev_b32_0") +// CHECK: waveasm.v_add_u32{{.*}}loc("v_add_u32_0") +waveasm.program @test_move_after target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %c4 = waveasm.constant 4 : !waveasm.imm<4> + + %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg + %a1 = waveasm.v_add_u32 %a0, %c4 : !waveasm.vreg, !waveasm.imm<4> -> !waveasm.vreg + %s0 = waveasm.v_lshlrev_b32 %c4, %a1 : !waveasm.imm<4>, !waveasm.vreg -> !waveasm.vreg + + waveasm.s_endpgm +} diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-pinned.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-pinned.mlir new file mode 100644 index 000000000..60c050e3d --- /dev/null +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-pinned.mlir @@ -0,0 +1,13 @@ +// RUN: not waveasm-conductor %s 2>&1 | FileCheck %s + +// CONDUCTOR: move s_endpgm_0 before v_add_u32_0 +// CONDUCTOR: done + +// CHECK: conductor: command 0: cannot move pinned op 's_endpgm_0' + +waveasm.program @test_pinned target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %c4 = waveasm.constant 4 : !waveasm.imm<4> + %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg + waveasm.s_endpgm +} diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-unknown-tag.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-unknown-tag.mlir new file mode 100644 index 000000000..ef40ecbb5 --- /dev/null +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-unknown-tag.mlir @@ -0,0 +1,13 @@ +// RUN: not waveasm-conductor %s 
2>&1 | FileCheck %s + +// CONDUCTOR: move nonexistent_tag_42 after v_add_u32_0 +// CONDUCTOR: done + +// CHECK: conductor: command 0: unknown tag 'nonexistent_tag_42' + +waveasm.program @test_unknown_tag target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %c4 = waveasm.constant 4 : !waveasm.imm<4> + %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg + waveasm.s_endpgm +} diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-swap.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-swap.mlir new file mode 100644 index 000000000..479981190 --- /dev/null +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-swap.mlir @@ -0,0 +1,22 @@ +// RUN: waveasm-conductor --print-debug-locs-inline %s | FileCheck %s + +// CONDUCTOR: swap v_add_u32_0 v_lshlrev_b32_0 +// CONDUCTOR: done + +// Original order: v_add_u32_0, v_add_u32_1, v_lshlrev_b32_0. +// After swap(0, 2): v_lshlrev_b32_0, v_add_u32_1, v_add_u32_0. 
+ +// CHECK: sym_name = "test_swap" +// CHECK: waveasm.v_lshlrev_b32{{.*}}loc("v_lshlrev_b32_0") +// CHECK: waveasm.v_add_u32{{.*}}loc("v_add_u32_1") +// CHECK: waveasm.v_add_u32{{.*}}loc("v_add_u32_0") +waveasm.program @test_swap target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %c4 = waveasm.constant 4 : !waveasm.imm<4> + + %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg + %a1 = waveasm.v_add_u32 %a0, %c4 : !waveasm.vreg, !waveasm.imm<4> -> !waveasm.vreg + %s0 = waveasm.v_lshlrev_b32 %c4, %a1 : !waveasm.imm<4>, !waveasm.vreg -> !waveasm.vreg + + waveasm.s_endpgm +} diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves.mlir new file mode 100644 index 000000000..8d83f071b --- /dev/null +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves.mlir @@ -0,0 +1,23 @@ +// RUN: waveasm-conductor --print-debug-locs-inline %s | FileCheck %s + +// CONDUCTOR: move v_add_u32_1 before v_add_u32_0 +// CONDUCTOR: done + +// Original order: v_add_u32_0, v_add_u32_1, v_lshlrev_b32_0. +// After move: v_add_u32_1, v_add_u32_0, v_lshlrev_b32_0. +// Note: output is in generic form because moves may create use-before-def. 
+ +// CHECK: sym_name = "test_move_before" +// CHECK: waveasm.v_add_u32{{.*}}loc("v_add_u32_1") +// CHECK: waveasm.v_add_u32{{.*}}loc("v_add_u32_0") +// CHECK: waveasm.v_lshlrev_b32{{.*}}loc("v_lshlrev_b32_0") +waveasm.program @test_move_before target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %c4 = waveasm.constant 4 : !waveasm.imm<4> + + %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg + %a1 = waveasm.v_add_u32 %a0, %c4 : !waveasm.vreg, !waveasm.imm<4> -> !waveasm.vreg + %s0 = waveasm.v_lshlrev_b32 %c4, %a1 : !waveasm.imm<4>, !waveasm.vreg -> !waveasm.vreg + + waveasm.s_endpgm +} diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/lit.cfg.py b/wave_lang/kernel/wave/asm/wave_asm/test/lit.cfg.py index 2ab708d36..2c16266de 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/test/lit.cfg.py +++ b/wave_lang/kernel/wave/asm/wave_asm/test/lit.cfg.py @@ -33,8 +33,8 @@ llvm_config.use_default_substitutions() # Add tools to the path -tool_dirs = [config.waveasm_tools_dir, config.llvm_tools_dir] -tools = ["waveasm-translate", "FileCheck", "count", "not"] +tool_dirs = [config.waveasm_tools_dir, config.waveasm_conductor_dir, config.llvm_tools_dir] +tools = ["waveasm-translate", "waveasm-conductor", "FileCheck", "count", "not"] llvm_config.add_tool_substitutions(tools, tool_dirs) # ROCm toolchain detection for integration tests diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/lit.site.cfg.py.in b/wave_lang/kernel/wave/asm/wave_asm/test/lit.site.cfg.py.in index 83304e850..9c6bd1ee0 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/test/lit.site.cfg.py.in +++ b/wave_lang/kernel/wave/asm/wave_asm/test/lit.site.cfg.py.in @@ -20,6 +20,7 @@ config.llvm_shlib_ext = "@SHLIBEXT@" config.waveasm_obj_root = "@WAVEASM_BINARY_DIR@" config.waveasm_src_root = "@WAVEASM_SOURCE_DIR@" config.waveasm_tools_dir = "@WAVEASM_TOOLS_DIR@" +config.waveasm_conductor_dir = 
"@WAVEASM_CONDUCTOR_DIR@" import lit.llvm lit.llvm.initialize(lit_config, config) diff --git a/wave_lang/kernel/wave/asm/wave_asm/tools/CMakeLists.txt b/wave_lang/kernel/wave/asm/wave_asm/tools/CMakeLists.txt index b8c29c2a5..c2ed906e7 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/tools/CMakeLists.txt +++ b/wave_lang/kernel/wave/asm/wave_asm/tools/CMakeLists.txt @@ -5,3 +5,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception add_subdirectory(waveasm-translate) +add_subdirectory(waveasm-conductor) diff --git a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/CMakeLists.txt b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/CMakeLists.txt new file mode 100644 index 000000000..2ddc6b4ef --- /dev/null +++ b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/CMakeLists.txt @@ -0,0 +1,22 @@ +# Copyright 2025 The Wave Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +add_llvm_executable(waveasm-conductor + waveasm-conductor.cpp +) + +llvm_update_compile_flags(waveasm-conductor) + +target_link_libraries(waveasm-conductor + PRIVATE + MLIRWaveASMDialect + MLIRWaveASMTransforms + MLIRParser + MLIRPass + MLIRSupport + MLIRIR + LLVMSupport +) diff --git a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp new file mode 100644 index 000000000..f18ae4371 --- /dev/null +++ b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp @@ -0,0 +1,122 @@ +// Copyright 2025 The Wave Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===----------------------------------------------------------------------===//
+// waveasm-conductor: CLI tool for applying Conductor move commands.
+//
+// Reads tagged WaveASM IR with embedded // CONDUCTOR: comments,
+// parses and applies the move/swap commands, then prints the result.
+//===----------------------------------------------------------------------===//
+
+#include "waveasm/Dialect/WaveASMDialect.h"
+#include "waveasm/Transforms/ApplyMoves.h"
+#include "waveasm/Transforms/Passes.h"
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/PassManager.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Command Line Options
+//===----------------------------------------------------------------------===//
+
+static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
+                                                llvm::cl::desc("<input file>"),
+                                                llvm::cl::init("-"));
+
+static llvm::cl::opt<std::string>
+    outputFilename("o", llvm::cl::desc("Output filename"),
+                   llvm::cl::value_desc("filename"), llvm::cl::init("-"));
+
+static llvm::cl::opt<bool> printDebugLocsInline(
+    "print-debug-locs-inline",
+    llvm::cl::desc("Print location information inline (pretty form)"),
+    llvm::cl::init(false));
+
+//===----------------------------------------------------------------------===//
+// Main Function
+//===----------------------------------------------------------------------===//
+
+int main(int argc, char **argv) {
+  llvm::InitLLVM y(argc, argv);
+  llvm::cl::ParseCommandLineOptions(argc, argv,
+                                    "WAVEASM Conductor move executor\n");
+
+  // Read the raw input file to extract CONDUCTOR commands before parsing.
+  auto inputFileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
+  if (auto ec = inputFileOrErr.getError()) {
+    llvm::errs() << "Error reading input file: " << ec.message() << "\n";
+    return 1;
+  }
+
+  llvm::StringRef rawText = (*inputFileOrErr)->getBuffer();
+  auto commands = waveasm::parseConductorCommands(rawText);
+
+  if (commands.empty()) {
+    llvm::errs() << "No CONDUCTOR commands found in input\n";
+    return 1;
+  }
+
+  // Set up MLIR context and parse the module.
+  DialectRegistry registry;
+  registry.insert<waveasm::WaveASMDialect>();
+
+  MLIRContext context(registry);
+  context.loadAllAvailableDialects();
+  context.allowUnregisteredDialects();
+
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(*inputFileOrErr), llvm::SMLoc());
+
+  OwningOpRef<ModuleOp> module = parseSourceFile<ModuleOp>(sourceMgr, &context);
+  if (!module) {
+    llvm::errs() << "Failed to parse input file\n";
+    return 1;
+  }
+
+  // Run tag-instructions pass to attach NameLoc tags.
+  PassManager pm(&context);
+  pm.addPass(waveasm::createWAVEASMTagInstructionsPass());
+  if (failed(pm.run(*module))) {
+    llvm::errs() << "Tag-instructions pass failed\n";
+    return 1;
+  }
+
+  // Apply the move commands.
+  waveasm::MoveResult result = waveasm::applyMoves(*module, commands);
+  if (!result.success) {
+    llvm::errs() << "conductor: command " << result.failedCommand << ": "
+                 << result.error << "\n";
+    return 1;
+  }
+
+  // Print the result.
+ std::error_code ec; + llvm::raw_fd_ostream outputStream(outputFilename, ec); + if (ec) { + llvm::errs() << "Error opening output file: " << ec.message() << "\n"; + return 1; + } + + OpPrintingFlags flags; + if (printDebugLocsInline) { + flags.enableDebugInfo(/*prettyForm=*/true); + flags.useLocalScope(); + } + module->print(outputStream, flags); + + return 0; +} From 537e3e949c27589b19184f8d346de81e84b64433 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 21 Feb 2026 23:19:39 +0100 Subject: [PATCH 09/49] update lit Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/asm/wave_asm/test/lit.cfg.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/lit.cfg.py b/wave_lang/kernel/wave/asm/wave_asm/test/lit.cfg.py index 2c16266de..12c3b8ceb 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/test/lit.cfg.py +++ b/wave_lang/kernel/wave/asm/wave_asm/test/lit.cfg.py @@ -33,7 +33,11 @@ llvm_config.use_default_substitutions() # Add tools to the path -tool_dirs = [config.waveasm_tools_dir, config.waveasm_conductor_dir, config.llvm_tools_dir] +tool_dirs = [ + config.waveasm_tools_dir, + config.waveasm_conductor_dir, + config.llvm_tools_dir, +] tools = ["waveasm-translate", "waveasm-conductor", "FileCheck", "count", "not"] llvm_config.add_tool_substitutions(tools, tool_dirs) From e9f40d3b6b041c838b0d6ee45df4b606f9a2ba97 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 21 Feb 2026 23:28:08 +0100 Subject: [PATCH 10/49] update parsing Signed-off-by: Ivan Butygin --- .../include/waveasm/Transforms/ApplyMoves.h | 50 +++-- .../wave_asm/lib/Transforms/ApplyMoves.cpp | 196 ++++++++++-------- .../waveasm-conductor/waveasm-conductor.cpp | 12 +- 3 files changed, 154 insertions(+), 104 deletions(-) diff --git a/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/ApplyMoves.h b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/ApplyMoves.h index 5317b18bc..5395d06e3 100644 --- 
a/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/ApplyMoves.h
+++ b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/ApplyMoves.h
@@ -9,10 +9,32 @@

 #include "mlir/IR/BuiltinOps.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
 #include <string>
+#include <variant>

 namespace waveasm {

+/// Move an op before a reference op.
+struct MoveBefore {
+  std::string tag;
+  std::string refTag;
+};
+
+/// Move an op after a reference op.
+struct MoveAfter {
+  std::string tag;
+  std::string refTag;
+};
+
+/// Swap two ops.
+struct Swap {
+  std::string tag1;
+  std::string tag2;
+};
+
+using Command = std::variant<MoveBefore, MoveAfter, Swap>;
+
 /// Result of applying a sequence of move commands to a module.
 struct MoveResult {
   bool success;
@@ -20,22 +42,24 @@ struct MoveResult {
   unsigned failedCommand;
 };

-/// Apply a sequence of Conductor move commands to a tagged module.
-///
-/// Commands are strings of the form:
-///   move <tag> after <ref_tag>
-///   move <tag> before <ref_tag>
-///   swap <tag1> <tag2>
+/// Result of parsing CONDUCTOR commands from raw text.
+struct ParseResult {
+  bool success;
+  std::string error;
+  unsigned failedLine; // 0-based index of the failing CONDUCTOR line.
+  llvm::SmallVector<Command> commands;
+};
+
+/// Parse CONDUCTOR command lines from raw file text.
+/// Scans for lines matching `// CONDUCTOR: <command>` and parses them
+/// into typed Command structs until `done` or end of input.
+ParseResult parseConductorCommands(llvm::StringRef text);
+
+/// Apply a sequence of parsed Conductor commands to a tagged module.
 ///
 /// The module must already have NameLoc tags attached (via TagInstructions).
 /// Returns a MoveResult indicating success or the first error encountered.
-MoveResult applyMoves(mlir::ModuleOp module,
-                      llvm::ArrayRef<std::string> commands);
-
-/// Parse CONDUCTOR command lines from raw file text.
-/// Scans for lines matching `// CONDUCTOR: <command>` and collects them
-/// until a `done` command is found or input is exhausted.
-llvm::SmallVector<std::string> parseConductorCommands(llvm::StringRef text);
+MoveResult applyMoves(mlir::ModuleOp module, llvm::ArrayRef<Command> commands);

 } // namespace waveasm

diff --git a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/ApplyMoves.cpp b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/ApplyMoves.cpp
index 798be779f..54603135a 100644
--- a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/ApplyMoves.cpp
+++ b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/ApplyMoves.cpp
@@ -37,7 +37,7 @@ llvm::StringMap<Operation *> buildTagMap(ModuleOp module) {
   return map;
 }

-/// Validate that an op can be moved (not pinned, resolves to a tag).
+/// Validate that an op can be moved (not pinned).
 std::string validateMovable(Operation *op, StringRef tag) {
   if (isPinned(op))
     return ("cannot move pinned op '" + tag + "'").str();
@@ -56,128 +56,148 @@ std::string validateSameBlock(Operation *a, StringRef tagA, Operation *b,

 namespace waveasm {

-llvm::SmallVector<std::string> parseConductorCommands(llvm::StringRef text) {
-  llvm::SmallVector<std::string> commands;
+ParseResult parseConductorCommands(llvm::StringRef text) {
+  ParseResult result;
+  result.success = true;
+  result.failedLine = 0;
+
   llvm::SmallVector<StringRef> lines;
   text.split(lines, '\n');

+  unsigned cmdIdx = 0;
   for (StringRef line : lines) {
     StringRef trimmed = line.ltrim();
     if (!trimmed.starts_with("// CONDUCTOR:"))
       continue;
-    StringRef cmd = trimmed.drop_front(strlen("// CONDUCTOR:")).trim();
-    if (cmd == "done")
+    StringRef raw = trimmed.drop_front(strlen("// CONDUCTOR:")).trim();
+    if (raw == "done")
       break;
-    if (!cmd.empty())
-      commands.push_back(cmd.str());
-  }
-  return commands;
-}
-
-MoveResult applyMoves(ModuleOp module, llvm::ArrayRef<std::string> commands) {
-  auto tagMap = buildTagMap(module);
-
-  for (auto [idx, cmd] : llvm::enumerate(commands)) {
-    StringRef line(cmd);
-    StringRef rest;
-
-    auto fail = [&](const std::string &msg) -> MoveResult {
-      return {false, msg, static_cast<unsigned>(idx)};
-    };
-
-    if (line.starts_with("move ")) {
-      rest = line.drop_front(strlen("move "));
+    if (raw.empty())
+      continue;

-      // Parse: <tag1> (after|before) <tag2>.
-      StringRef tag1, direction, tag2;
-      std::tie(tag1, rest) = rest.split(' ');
-      std::tie(direction, tag2) = rest.split(' ');
+    if (raw.starts_with("move ")) {
+      StringRef rest = raw.drop_front(strlen("move "));
+      auto [tag, rest2] = rest.split(' ');
+      auto [direction, refTag] = rest2.split(' ');

-      if (tag1.empty() || tag2.empty() || direction.empty())
-        return fail("malformed move command: '" + cmd + "'");
+      if (tag.empty() || refTag.empty() || direction.empty()) {
+        result.success = false;
+        result.error = ("malformed move command: '" + raw + "'").str();
+        result.failedLine = cmdIdx;
+        return result;
+      }

-      if (direction != "after" && direction != "before")
-        return fail("expected 'after' or 'before', got '" + direction.str() +
-                    "'");
+      if (direction == "before") {
+        result.commands.push_back(MoveBefore{tag.str(), refTag.str()});
+      } else if (direction == "after") {
+        result.commands.push_back(MoveAfter{tag.str(), refTag.str()});
+      } else {
+        result.success = false;
+        result.error =
+            ("expected 'after' or 'before', got '" + direction + "'").str();
+        result.failedLine = cmdIdx;
+        return result;
+      }

-      auto it1 = tagMap.find(tag1);
-      if (it1 == tagMap.end())
-        return fail("unknown tag '" + tag1.str() + "'");
-      auto it2 = tagMap.find(tag2);
-      if (it2 == tagMap.end())
-        return fail("unknown tag '" + tag2.str() + "'");
+    } else if (raw.starts_with("swap ")) {
+      StringRef rest = raw.drop_front(strlen("swap "));
+      auto [tag1, tag2] = rest.split(' ');

-      Operation *op1 = it1->second;
-      Operation *op2 = it2->second;
+      if (tag1.empty() || tag2.empty()) {
+        result.success = false;
+        result.error = ("malformed swap command: '" + raw + "'").str();
+        result.failedLine = cmdIdx;
+        return result;
+      }

-      std::string err = validateMovable(op1, tag1);
-      if (!err.empty())
-        return fail(err);
+      result.commands.push_back(Swap{tag1.str(), tag2.str()});

-      err = validateSameBlock(op1, tag1, op2, tag2);
-      if (!err.empty())
-        return fail(err);
+    } else {
+      result.success = false;
+      result.error = ("unknown command: '" + raw + "'").str();
+      result.failedLine = cmdIdx;
+      return result;
+    }

-      if (direction == "after")
-        op1->moveAfter(op2);
-      else
-        op1->moveBefore(op2);
+    ++cmdIdx;
+  }

-      LDBG() << "move " << tag1 << " " << direction << " " << tag2;
+  return result;
+}

-    } else if (line.starts_with("swap ")) {
-      rest = line.drop_front(strlen("swap "));
+/// Resolve two tags, validate movability and same-block constraint.
+/// On success sets op1/op2 and returns empty string.
+static std::string
+resolveAndValidate(const llvm::StringMap<Operation *> &tagMap, StringRef tag,
+                   StringRef ref, bool checkRefMovable, Operation *&op1,
+                   Operation *&op2) {
+  auto it1 = tagMap.find(tag);
+  if (it1 == tagMap.end())
+    return ("unknown tag '" + tag + "'").str();
+  auto it2 = tagMap.find(ref);
+  if (it2 == tagMap.end())
+    return ("unknown tag '" + ref + "'").str();
+
+  op1 = it1->second;
+  op2 = it2->second;
+
+  std::string err = validateMovable(op1, tag);
+  if (!err.empty())
+    return err;
+  if (checkRefMovable) {
+    err = validateMovable(op2, ref);
+    if (!err.empty())
+      return err;
+  }
+  return validateSameBlock(op1, tag, op2, ref);
+}

-      StringRef tag1, tag2;
-      std::tie(tag1, tag2) = rest.split(' ');
+MoveResult applyMoves(ModuleOp module, llvm::ArrayRef<Command> commands) {
+  auto tagMap = buildTagMap(module);

-      if (tag1.empty() || tag2.empty())
-        return fail("malformed swap command: '" + cmd + "'");
+  for (auto [idx, cmd] : llvm::enumerate(commands)) {
+    auto fail = [&](const std::string &msg) -> MoveResult {
+      return {false, msg, static_cast<unsigned>(idx)};
+    };

-      auto it1 = tagMap.find(tag1);
-      if (it1 == tagMap.end())
-        return fail("unknown tag '" + tag1.str() + "'");
-      auto it2 = tagMap.find(tag2);
-      if (it2 == tagMap.end())
-        return fail("unknown tag '" + tag2.str() + "'");
+    Operation *op1 = nullptr, *op2 = nullptr;

-      Operation *op1 = it1->second;
-      Operation *op2 = it2->second;
-
-      std::string err = validateMovable(op1, tag1);
+    if (auto *move = std::get_if<MoveBefore>(&cmd)) {
+      std::string err =
+          resolveAndValidate(tagMap, move->tag, move->refTag, false, op1, op2);
       if (!err.empty())
         return fail(err);
-      err = validateMovable(op2, tag2);
+      op1->moveBefore(op2);
+      LDBG() << "move " << move->tag << " before " << move->refTag;
+
+    } else if (auto *move = std::get_if<MoveAfter>(&cmd)) {
+      std::string err =
+          resolveAndValidate(tagMap, move->tag, move->refTag, false, op1, op2);
       if (!err.empty())
         return fail(err);
+      op1->moveAfter(op2);
+      LDBG() << "move " << move->tag << " after " << move->refTag;

-      err = validateSameBlock(op1, tag1, op2, tag2);
+    } else if (auto *swap = std::get_if<Swap>(&cmd)) {
+      std::string err =
+          resolveAndValidate(tagMap, swap->tag1, swap->tag2, true, op1, op2);
       if (!err.empty())
         return fail(err);

-      // Swap: move op1 after op2, then move op2 to op1's original position.
+      // Swap by considering adjacency cases.
       Operation *op1Next = op1->getNextNode();
       if (op1Next == op2) {
-        // Adjacent: just swap order.
         op1->moveAfter(op2);
+      } else if (op2->getNextNode() == op1) {
+        op2->moveAfter(op1);
       } else {
-        Operation *op2Next = op2->getNextNode();
-        if (op2Next == op1) {
-          op2->moveAfter(op1);
-        } else {
-          // Non-adjacent: use a stable reference point.
- op1->moveAfter(op2); - if (op1Next) - op2->moveBefore(op1Next); - else - op2->moveBefore(op2->getBlock(), op2->getBlock()->end()); - } + op1->moveAfter(op2); + if (op1Next) + op2->moveBefore(op1Next); + else + op2->moveBefore(op2->getBlock(), op2->getBlock()->end()); } - - LDBG() << "swap " << tag1 << " " << tag2; - - } else { - return fail("unknown command: '" + cmd + "'"); + LDBG() << "swap " << swap->tag1 << " " << swap->tag2; } } diff --git a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp index f18ae4371..9dc7d28b9 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp +++ b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp @@ -63,9 +63,14 @@ int main(int argc, char **argv) { } llvm::StringRef rawText = (*inputFileOrErr)->getBuffer(); - auto commands = waveasm::parseConductorCommands(rawText); + auto parseResult = waveasm::parseConductorCommands(rawText); - if (commands.empty()) { + if (!parseResult.success) { + llvm::errs() << "conductor: parse error at command " + << parseResult.failedLine << ": " << parseResult.error << "\n"; + return 1; + } + if (parseResult.commands.empty()) { llvm::errs() << "No CONDUCTOR commands found in input\n"; return 1; } @@ -96,7 +101,8 @@ int main(int argc, char **argv) { } // Apply the move commands. 
- waveasm::MoveResult result = waveasm::applyMoves(*module, commands); + waveasm::MoveResult result = + waveasm::applyMoves(*module, parseResult.commands); if (!result.success) { llvm::errs() << "conductor: command " << result.failedCommand << ": " << result.error << "\n"; From 6b72fa07642a094168cd3715a1152eba3b3f243a Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 21 Feb 2026 23:51:55 +0100 Subject: [PATCH 11/49] remove "done" --- .../wave/asm/wave_asm/include/waveasm/Transforms/ApplyMoves.h | 2 +- .../kernel/wave/asm/wave_asm/lib/Transforms/ApplyMoves.cpp | 2 -- .../wave/asm/wave_asm/test/Transforms/apply-moves-after.mlir | 2 +- .../asm/wave_asm/test/Transforms/apply-moves-error-pinned.mlir | 2 +- .../wave_asm/test/Transforms/apply-moves-error-unknown-tag.mlir | 2 +- .../wave/asm/wave_asm/test/Transforms/apply-moves-swap.mlir | 2 +- .../kernel/wave/asm/wave_asm/test/Transforms/apply-moves.mlir | 2 +- 7 files changed, 6 insertions(+), 8 deletions(-) diff --git a/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/ApplyMoves.h b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/ApplyMoves.h index 5395d06e3..69d9b99a3 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/ApplyMoves.h +++ b/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/Transforms/ApplyMoves.h @@ -52,7 +52,7 @@ struct ParseResult { /// Parse CONDUCTOR command lines from raw file text. /// Scans for lines matching `// CONDUCTOR: ` and parses them -/// into typed Command structs until `done` or end of input. +/// into typed Command structs. ParseResult parseConductorCommands(llvm::StringRef text); /// Apply a sequence of parsed Conductor commands to a tagged module. 
diff --git a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/ApplyMoves.cpp b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/ApplyMoves.cpp index 54603135a..aa77b586b 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/ApplyMoves.cpp +++ b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/ApplyMoves.cpp @@ -70,8 +70,6 @@ ParseResult parseConductorCommands(llvm::StringRef text) { if (!trimmed.starts_with("// CONDUCTOR:")) continue; StringRef raw = trimmed.drop_front(strlen("// CONDUCTOR:")).trim(); - if (raw == "done") - break; if (raw.empty()) continue; diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-after.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-after.mlir index 5ce3375e5..3c80176ff 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-after.mlir +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-after.mlir @@ -1,7 +1,7 @@ // RUN: waveasm-conductor --print-debug-locs-inline %s | FileCheck %s // CONDUCTOR: move v_add_u32_0 after v_lshlrev_b32_0 -// CONDUCTOR: done + // Original order: v_add_u32_0, v_add_u32_1, v_lshlrev_b32_0. // After move: v_add_u32_1, v_lshlrev_b32_0, v_add_u32_0. 
diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-pinned.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-pinned.mlir index 60c050e3d..d62fafb5d 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-pinned.mlir +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-pinned.mlir @@ -1,7 +1,7 @@ // RUN: not waveasm-conductor %s 2>&1 | FileCheck %s // CONDUCTOR: move s_endpgm_0 before v_add_u32_0 -// CONDUCTOR: done + // CHECK: conductor: command 0: cannot move pinned op 's_endpgm_0' diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-unknown-tag.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-unknown-tag.mlir index ef40ecbb5..9412e59f8 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-unknown-tag.mlir +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-unknown-tag.mlir @@ -1,7 +1,7 @@ // RUN: not waveasm-conductor %s 2>&1 | FileCheck %s // CONDUCTOR: move nonexistent_tag_42 after v_add_u32_0 -// CONDUCTOR: done + // CHECK: conductor: command 0: unknown tag 'nonexistent_tag_42' diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-swap.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-swap.mlir index 479981190..a114a0fd9 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-swap.mlir +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-swap.mlir @@ -1,7 +1,7 @@ // RUN: waveasm-conductor --print-debug-locs-inline %s | FileCheck %s // CONDUCTOR: swap v_add_u32_0 v_lshlrev_b32_0 -// CONDUCTOR: done + // Original order: v_add_u32_0, v_add_u32_1, v_lshlrev_b32_0. // After swap(0, 2): v_lshlrev_b32_0, v_add_u32_1, v_add_u32_0. 
diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves.mlir index 8d83f071b..b1eb8de02 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves.mlir +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves.mlir @@ -1,7 +1,7 @@ // RUN: waveasm-conductor --print-debug-locs-inline %s | FileCheck %s // CONDUCTOR: move v_add_u32_1 before v_add_u32_0 -// CONDUCTOR: done + // Original order: v_add_u32_0, v_add_u32_1, v_lshlrev_b32_0. // After move: v_add_u32_1, v_add_u32_0, v_lshlrev_b32_0. From 978c126777dae8747028941f6ea820556d8485c1 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 00:10:51 +0100 Subject: [PATCH 12/49] verify in conductor Signed-off-by: Ivan Butygin --- .../test/Transforms/apply-moves-after.mlir | 13 +++++-------- .../apply-moves-error-dominance.mlir | 19 +++++++++++++++++++ .../test/Transforms/apply-moves-swap.mlir | 10 ++++------ .../wave_asm/test/Transforms/apply-moves.mlir | 12 ++++-------- .../waveasm-conductor/waveasm-conductor.cpp | 7 +++++++ 5 files changed, 39 insertions(+), 22 deletions(-) create mode 100644 wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-dominance.mlir diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-after.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-after.mlir index 3c80176ff..555dec184 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-after.mlir +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-after.mlir @@ -1,22 +1,19 @@ // RUN: waveasm-conductor --print-debug-locs-inline %s | FileCheck %s -// CONDUCTOR: move v_add_u32_0 after v_lshlrev_b32_0 +// CONDUCTOR: move v_add_u32_0 after v_add_u32_1 +// Two independent adds — reordering is safe. -// Original order: v_add_u32_0, v_add_u32_1, v_lshlrev_b32_0. -// After move: v_add_u32_1, v_lshlrev_b32_0, v_add_u32_0. 
- -// CHECK: sym_name = "test_move_after" +// CHECK-LABEL: waveasm.program @test_move_after // CHECK: waveasm.v_add_u32{{.*}}loc("v_add_u32_1") -// CHECK: waveasm.v_lshlrev_b32{{.*}}loc("v_lshlrev_b32_0") // CHECK: waveasm.v_add_u32{{.*}}loc("v_add_u32_0") waveasm.program @test_move_after target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> %c4 = waveasm.constant 4 : !waveasm.imm<4> + %c1 = waveasm.constant 1 : !waveasm.imm<1> %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg - %a1 = waveasm.v_add_u32 %a0, %c4 : !waveasm.vreg, !waveasm.imm<4> -> !waveasm.vreg - %s0 = waveasm.v_lshlrev_b32 %c4, %a1 : !waveasm.imm<4>, !waveasm.vreg -> !waveasm.vreg + %a1 = waveasm.v_add_u32 %v0, %c1 : !waveasm.pvreg<0>, !waveasm.imm<1> -> !waveasm.vreg waveasm.s_endpgm } diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-dominance.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-dominance.mlir new file mode 100644 index 000000000..b3d01551d --- /dev/null +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-dominance.mlir @@ -0,0 +1,19 @@ +// RUN: not waveasm-conductor %s 2>&1 | FileCheck %s + +// CONDUCTOR: move v_add_u32_0 after v_add_u32_1 + +// v_add_u32_1 uses the result of v_add_u32_0, so moving _0 after _1 +// breaks dominance. 
+ +// CHECK: does not dominate this use +// CHECK: conductor: verification failed after applying moves + +waveasm.program @test_dominance target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %c4 = waveasm.constant 4 : !waveasm.imm<4> + + %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg + %a1 = waveasm.v_add_u32 %a0, %c4 : !waveasm.vreg, !waveasm.imm<4> -> !waveasm.vreg + + waveasm.s_endpgm +} diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-swap.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-swap.mlir index a114a0fd9..f56ef1936 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-swap.mlir +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-swap.mlir @@ -2,11 +2,9 @@ // CONDUCTOR: swap v_add_u32_0 v_lshlrev_b32_0 +// Three independent ops (all read from %v0/%c4) — swap is safe. -// Original order: v_add_u32_0, v_add_u32_1, v_lshlrev_b32_0. -// After swap(0, 2): v_lshlrev_b32_0, v_add_u32_1, v_add_u32_0. 
- -// CHECK: sym_name = "test_swap" +// CHECK-LABEL: waveasm.program @test_swap // CHECK: waveasm.v_lshlrev_b32{{.*}}loc("v_lshlrev_b32_0") // CHECK: waveasm.v_add_u32{{.*}}loc("v_add_u32_1") // CHECK: waveasm.v_add_u32{{.*}}loc("v_add_u32_0") @@ -15,8 +13,8 @@ waveasm.program @test_swap target = #waveasm.target<#waveasm.gfx942, 5> abi = #w %c4 = waveasm.constant 4 : !waveasm.imm<4> %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg - %a1 = waveasm.v_add_u32 %a0, %c4 : !waveasm.vreg, !waveasm.imm<4> -> !waveasm.vreg - %s0 = waveasm.v_lshlrev_b32 %c4, %a1 : !waveasm.imm<4>, !waveasm.vreg -> !waveasm.vreg + %a1 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg + %s0 = waveasm.v_lshlrev_b32 %c4, %v0 : !waveasm.imm<4>, !waveasm.pvreg<0> -> !waveasm.vreg waveasm.s_endpgm } diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves.mlir index b1eb8de02..a5818db1f 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves.mlir +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves.mlir @@ -2,22 +2,18 @@ // CONDUCTOR: move v_add_u32_1 before v_add_u32_0 +// Two independent adds from the same inputs — reordering is safe. -// Original order: v_add_u32_0, v_add_u32_1, v_lshlrev_b32_0. -// After move: v_add_u32_1, v_add_u32_0, v_lshlrev_b32_0. -// Note: output is in generic form because moves may create use-before-def. 
- -// CHECK: sym_name = "test_move_before" +// CHECK-LABEL: waveasm.program @test_move_before // CHECK: waveasm.v_add_u32{{.*}}loc("v_add_u32_1") // CHECK: waveasm.v_add_u32{{.*}}loc("v_add_u32_0") -// CHECK: waveasm.v_lshlrev_b32{{.*}}loc("v_lshlrev_b32_0") waveasm.program @test_move_before target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> %c4 = waveasm.constant 4 : !waveasm.imm<4> + %c1 = waveasm.constant 1 : !waveasm.imm<1> %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg - %a1 = waveasm.v_add_u32 %a0, %c4 : !waveasm.vreg, !waveasm.imm<4> -> !waveasm.vreg - %s0 = waveasm.v_lshlrev_b32 %c4, %a1 : !waveasm.imm<4>, !waveasm.vreg -> !waveasm.vreg + %a1 = waveasm.v_add_u32 %v0, %c1 : !waveasm.pvreg<0>, !waveasm.imm<1> -> !waveasm.vreg waveasm.s_endpgm } diff --git a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp index 9dc7d28b9..1c6411ed3 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp +++ b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/DialectRegistry.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OperationSupport.h" +#include "mlir/IR/Verifier.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/PassManager.h" #include "llvm/Support/CommandLine.h" @@ -109,6 +110,12 @@ int main(int argc, char **argv) { return 1; } + // Verify the module after moves (catches broken dominance, etc.). + if (failed(mlir::verify(*module))) { + llvm::errs() << "conductor: verification failed after applying moves\n"; + return 1; + } + // Print the result. 
std::error_code ec; llvm::raw_fd_ostream outputStream(outputFilename, ec); From a11743bd9b3edad9af6c27246464efcf6ddda301 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 00:45:08 +0100 Subject: [PATCH 13/49] conductor Signed-off-by: Ivan Butygin --- conductor/__init__.py | 28 ++++ conductor/conductor.py | 207 ++++++++++++++++++++++++ conductor/extract_ir.py | 343 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 578 insertions(+) create mode 100644 conductor/__init__.py create mode 100644 conductor/conductor.py create mode 100644 conductor/extract_ir.py diff --git a/conductor/__init__.py b/conductor/__init__.py new file mode 100644 index 000000000..e7fad7bb1 --- /dev/null +++ b/conductor/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 The Wave Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Conductor: LLM-guided instruction scheduling for WaveASM.""" + +from conductor.conductor import Conductor, find_waveasm_conductor +from conductor.extract_ir import ( + find_waveasm_translate, + run_waveasm_translate, + run_pre_scheduling_pipeline, + run_full_pipeline, + count_asm_metrics, + capture_kernel_mlir, +) + +__all__ = [ + "Conductor", + "find_waveasm_conductor", + "find_waveasm_translate", + "run_waveasm_translate", + "run_pre_scheduling_pipeline", + "run_full_pipeline", + "count_asm_metrics", + "capture_kernel_mlir", +] diff --git a/conductor/conductor.py b/conductor/conductor.py new file mode 100644 index 000000000..0b920b0ad --- /dev/null +++ b/conductor/conductor.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python3 +""" +Conductor: Python driver for WaveASM instruction scheduling experiments. + +Ties together extract_ir → tag → apply moves → post-scheduling pipeline → metrics. 
+ +Usage: + # Baseline (no moves): + python -m conductor.conductor --metrics + + # Apply specific moves: + python -m conductor.conductor \ + --moves "swap buffer_load_dwordx4_0 v_mfma_f32_16x16x16_f16_0" --metrics + + # Show tagged IR only: + python -m conductor.conductor --tag-only + + # Read moves from a file: + python -m conductor.conductor --moves-file moves.txt --metrics +""" + +import argparse +import os +import subprocess +import sys +import tempfile +from pathlib import Path + +# Add wave_lang to path. +wave_root = Path(__file__).parent.parent +sys.path.insert(0, str(wave_root)) + +from conductor.extract_ir import ( + find_waveasm_translate, + run_waveasm_translate, + count_asm_metrics, + capture_kernel_mlir, + run_pre_scheduling_pipeline, +) + + +def find_waveasm_conductor() -> str: + """Find the waveasm-conductor binary.""" + env_path = os.environ.get("WAVEASM_CONDUCTOR") + if env_path and Path(env_path).exists(): + return env_path + + candidates = [ + wave_root / "wave_lang" / "kernel" / "wave" / "asm" / "wave_asm" / "build" / "tools" / "waveasm-conductor" / "waveasm-conductor", + wave_root / "wave_lang" / "kernel" / "wave" / "asm" / "wave_asm" / "build" / "bin" / "waveasm-conductor", + ] + for p in candidates: + if p.exists(): + return str(p) + + raise FileNotFoundError( + "waveasm-conductor not found. Set WAVEASM_CONDUCTOR env var or build it." + ) + + +class Conductor: + """Encapsulates the full round-trip for instruction scheduling experiments.""" + + def __init__(self, waveasm_ir: str, workgroup_size: tuple, target: str = "gfx942"): + self.waveasm_ir = waveasm_ir + self.workgroup_size = workgroup_size + self.target = target + self._baseline_cache = None + + def tag(self) -> str: + """Run tag-instructions on the IR. 
Return tagged IR text.""" + flags = [ + "--waveasm-tag-instructions", + "--print-debug-locs-inline", + ] + stdout, stderr, rc = run_waveasm_translate( + self.waveasm_ir, self.workgroup_size, flags + ) + if rc != 0: + raise RuntimeError(f"tag-instructions failed:\n{stderr}") + return stdout + + def apply_moves(self, tagged_ir: str, commands: list) -> str: + """Apply CONDUCTOR move commands to tagged IR. Return reordered IR.""" + conductor = find_waveasm_conductor() + + # Prepend CONDUCTOR commands to the tagged IR. + header = "\n".join(f"// CONDUCTOR: {cmd}" for cmd in commands) + full_input = header + "\n\n" + tagged_ir + + with tempfile.NamedTemporaryFile( + mode="w", suffix=".mlir", delete=False + ) as f: + f.write(full_input) + input_path = f.name + + try: + cmd = [conductor, "--print-debug-locs-inline", input_path] + result = subprocess.run(cmd, capture_output=True, text=True, timeout=60) + if result.returncode != 0: + raise RuntimeError( + f"waveasm-conductor failed (rc={result.returncode}):\n{result.stderr}" + ) + return result.stdout + finally: + os.unlink(input_path) + + def compile_to_asm(self, ir: str) -> str: + """Run post-scheduling pipeline on already-scheduled WaveASM IR.""" + flags = [ + "--waveasm-linear-scan", + "--max-vgprs=512", + "--max-agprs=512", + "--waveasm-insert-waitcnt", + "--waveasm-hazard-mitigation", + "--emit-assembly", + ] + stdout, stderr, rc = run_waveasm_translate(ir, self.workgroup_size, flags) + if rc != 0: + raise RuntimeError(f"compile_to_asm failed:\n{stderr}") + return stdout + + def get_metrics(self, asm: str) -> dict: + """Extract metrics from assembly text.""" + return count_asm_metrics(asm) + + def evaluate(self, commands: list) -> dict: + """ + Full round-trip: tag → apply_moves → compile_to_asm → get_metrics. + + This is the main entry point for a search algorithm. 
+ """ + tagged = self.tag() + reordered = self.apply_moves(tagged, commands) + asm = self.compile_to_asm(reordered) + return self.get_metrics(asm) + + def baseline(self) -> dict: + """Evaluate with no moves (identity schedule). Caches result.""" + if self._baseline_cache is None: + asm = self.compile_to_asm(self.waveasm_ir) + self._baseline_cache = self.get_metrics(asm) + return self._baseline_cache + + +def main(): + parser = argparse.ArgumentParser( + description="Conductor: WaveASM instruction scheduling driver." + ) + parser.add_argument( + "--moves", nargs="*", default=None, + help="Move commands (e.g. 'swap A B', 'move X after Y').", + ) + parser.add_argument( + "--moves-file", type=str, default=None, + help="Read move commands from a file (one per line).", + ) + parser.add_argument( + "--metrics", action="store_true", + help="Print assembly metrics after scheduling.", + ) + parser.add_argument( + "--tag-only", action="store_true", + help="Only show tagged IR, then exit.", + ) + args = parser.parse_args() + + # Collect commands from both sources. 
+ commands = [] + if args.moves: + commands.extend(args.moves) + if args.moves_file: + commands.extend( + line.strip() + for line in Path(args.moves_file).read_text().splitlines() + if line.strip() and not line.strip().startswith("#") + ) + + print("Capturing kernel MLIR...", file=sys.stderr) + mlir_text, wg_size = capture_kernel_mlir() + print(f" workgroup_size: {wg_size}", file=sys.stderr) + + print("Running pre-scheduling pipeline...", file=sys.stderr) + waveasm_ir = run_pre_scheduling_pipeline(mlir_text, wg_size) + print(f" WaveASM IR: {len(waveasm_ir)} chars", file=sys.stderr) + + conductor = Conductor(waveasm_ir, wg_size) + + if args.tag_only: + print(conductor.tag()) + return + + if commands: + print(f"Applying {len(commands)} move(s)...", file=sys.stderr) + metrics = conductor.evaluate(commands) + else: + print("No moves specified, running baseline...", file=sys.stderr) + metrics = conductor.baseline() + + if args.metrics: + print("\n=== Metrics ===") + for k, v in metrics.items(): + print(f" {k}: {v}") + + +if __name__ == "__main__": + main() diff --git a/conductor/extract_ir.py b/conductor/extract_ir.py new file mode 100644 index 000000000..be144f765 --- /dev/null +++ b/conductor/extract_ir.py @@ -0,0 +1,343 @@ +#!/usr/bin/env python3 +""" +Extract WaveASM MLIR IR at the Conductor scheduling stage. + +Defines a GEMM kernel inline, runs the Wave compilation pipeline to produce +input MLIR, then runs waveasm-translate with only the pre-scheduling passes +(CSE, peephole, memory-offset-opt). The output is the WaveASM IR that the +Conductor would see — post-optimization, pre-register-allocation. 
+ +Usage: + python -m conductor.extract_ir -o /tmp/conductor_ir.mlir + python -m conductor.extract_ir --metrics + python -m conductor.extract_ir --dump-input-mlir + +Requires: + - waveasm-translate binary (auto-detected from build/) + - WAVE_CACHE_ON=0 recommended +""" + +import argparse +import os +import subprocess +import sys +import tempfile +from pathlib import Path + +# Add wave_lang to path. +wave_root = Path(__file__).parent.parent +sys.path.insert(0, str(wave_root)) + + +def find_waveasm_translate() -> str: + """Find the waveasm-translate binary.""" + env_path = os.environ.get("WAVEASM_TRANSLATE") + if env_path and Path(env_path).exists(): + return env_path + + candidates = [ + wave_root / "wave_lang" / "kernel" / "wave" / "asm" / "wave_asm" / "build" / "tools" / "waveasm-translate" / "waveasm-translate", + wave_root / "wave_lang" / "kernel" / "wave" / "asm" / "wave_asm" / "build" / "bin" / "waveasm-translate", + ] + for p in candidates: + if p.exists(): + return str(p) + + raise FileNotFoundError( + "waveasm-translate not found. Set WAVEASM_TRANSLATE env var or build it." + ) + + +def get_target() -> str: + """Get the target architecture.""" + return os.environ.get("WAVE_DEFAULT_ARCH", "gfx942") + + +def capture_kernel_mlir() -> tuple: + """ + Capture MLIR from a multi-wave GEMM kernel. + + Returns (mlir_text, workgroup_size). 
+ """ + import wave_lang.kernel.lang as tkl + import wave_lang.kernel.wave as tkw + from wave_lang.kernel.lang.global_symbols import ( + GLOBAL_ADDRESS_SPACE, + SHARED_ADDRESS_SPACE, + ) + from wave_lang.kernel.wave.compile import WaveCompileOptions + from wave_lang.kernel.wave.utils.run_utils import set_default_run_config + from wave_lang.kernel._support.indexing import IndexingContext + from wave_lang.kernel.wave.compile import _trace_launchable_and_get_kernel_signature + from wave_lang.support.ir_imports import Context, Module, func_d + from wave_lang.kernel.wave.asm.mlir_analysis import ( + walk_ops_recursively, + should_skip_function, + ) + from wave_lang.kernel.wave.utils.compile_utils import canonicalize_module + + M = tkl.sym.M + N = tkl.sym.N + K = tkl.sym.K + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K = tkl.sym.BLOCK_K + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0 + + # 4-wave config: 2x2 waves, 32x32 wave tiles. + block_m, block_n, wave_m, wave_n = 64, 64, 32, 32 + wave_size = 64 + mma_type = tkw.MMAType.F32_16x16x16_F16 + + constraints = [ + tkw.WorkgroupConstraint(M, BLOCK_M, 0), + tkw.WorkgroupConstraint(N, BLOCK_N, 1), + tkw.TilingConstraint(K, BLOCK_K), + tkw.WaveConstraint(M, wave_m), + tkw.WaveConstraint(N, wave_n), + tkw.HardwareConstraint(threads_per_wave=wave_size, mma_type=mma_type), + ] + + @tkw.wave(constraints) + def gemm_kernel( + a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32], + ): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + + @tkw.iterate(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: + a_reg = tkw.read(a) + b_reg = tkw.read(b) + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + tkw.write(repeat, c) + + m, n, k = 256, 256, 256 + block_k = 16 + + options = WaveCompileOptions( + subs={ + M: m, N: n, K: k, + BLOCK_M: block_m, BLOCK_N: 
block_n, BLOCK_K: block_k, + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, + }, + canonicalize=True, + backend="asm", + wave_runtime=True, + compile_to_mlir=False, + use_global_to_shared=False, + ) + options = set_default_run_config(options) + + # Capture MLIR via the same path as the e2e tests. + with IndexingContext() as idxc: + idxc.set_subs(options.subs) + gemm_kernel.initialize_wave_constraints() + gemm_kernel.initialize_symbolic_constraints() + gemm_kernel.initialize_workgroup_constraints() + + result = _trace_launchable_and_get_kernel_signature(gemm_kernel, options) + mb = result[0] + + if options.canonicalize: + canonicalize_module(mb.module_op) + + full_mlir = mb.module_op.get_asm(enable_debug_info=False) + + launch_info = options.kernel_launch_info + blocks = launch_info.blocks if launch_info.blocks else [64, 1, 1] + + # Extract func.func from stream wrapper. + with Context() as ctx: + ctx.allow_unregistered_dialects = True + module = Module.parse(full_mlir) + + for fn in walk_ops_recursively(module.operation): + if not isinstance(fn, func_d.FuncOp): + continue + if should_skip_function(fn): + continue + func_text = fn.get_asm(print_generic_op_form=True) + mlir_text = "module {\n" + func_text + "\n}\n" + return mlir_text, tuple(blocks) + + raise ValueError("No kernel function found in MLIR") + + +def run_waveasm_translate(mlir_text: str, workgroup_size: tuple, extra_flags: list = None) -> tuple: + """ + Run waveasm-translate with given flags. + + Returns (stdout, stderr, returncode). 
+ """ + translate = find_waveasm_translate() + target = get_target() + + with tempfile.NamedTemporaryFile(mode="w", suffix=".mlir", delete=False) as f: + f.write(mlir_text) + input_path = f.name + + try: + cmd = [ + translate, + f"--target={target}", + f"--workgroup-size-x={workgroup_size[0]}", + f"--workgroup-size-y={workgroup_size[1]}", + f"--workgroup-size-z={workgroup_size[2]}", + ] + if extra_flags: + cmd.extend(extra_flags) + cmd.append(input_path) + + result = subprocess.run(cmd, capture_output=True, text=True, timeout=60) + return result.stdout, result.stderr, result.returncode + finally: + os.unlink(input_path) + + +def run_pre_scheduling_pipeline(mlir_text: str, workgroup_size: tuple) -> str: + """ + Run waveasm-translate with only pre-scheduling passes. + + Produces WaveASM IR after: + TranslateFromMLIR -> ScopedCSE -> Peephole -> MemoryOffsetOpt -> Canonicalizer -> ScopedCSE + + But before: + LinearScan, InsertWaitcnt, HazardMitigation, EmitAssembly + """ + flags = [ + "--mlir-cse", + "--waveasm-scoped-cse", + "--waveasm-peephole", + "--waveasm-memory-offset-opt", + # Stop here — no regalloc, no waitcnt, no hazard, no assembly. 
+ ] + stdout, stderr, rc = run_waveasm_translate(mlir_text, workgroup_size, flags) + if rc != 0: + print(f"waveasm-translate (pre-scheduling) failed:\n{stderr}", file=sys.stderr) + sys.exit(1) + return stdout + + +def run_full_pipeline(mlir_text: str, workgroup_size: tuple) -> str: + """Run the full pipeline and return assembly text.""" + flags = [ + "--mlir-cse", + "--waveasm-scoped-cse", + "--waveasm-peephole", + "--waveasm-memory-offset-opt", + "--waveasm-linear-scan", + "--max-vgprs=512", + "--max-agprs=512", + "--waveasm-insert-waitcnt", + "--waveasm-hazard-mitigation", + "--emit-assembly", + ] + stdout, stderr, rc = run_waveasm_translate(mlir_text, workgroup_size, flags) + if rc != 0: + print(f"Full pipeline failed:\n{stderr}", file=sys.stderr) + return "" + return stdout + + +def count_asm_metrics(asm_text: str) -> dict: + """Extract basic metrics from assembly text.""" + lines = asm_text.split("\n") + metrics = { + "total_instructions": 0, + "s_waitcnt": 0, + "s_nop": 0, + "mfma": 0, + "buffer_load": 0, + "ds_read": 0, + "ds_write": 0, + } + for line in lines: + stripped = line.strip() + if not stripped or stripped.startswith(("//", ";", ".")): + continue + if stripped.endswith(":"): + continue + metrics["total_instructions"] += 1 + if "s_waitcnt" in stripped: + metrics["s_waitcnt"] += 1 + if "s_nop" in stripped: + metrics["s_nop"] += 1 + if "mfma" in stripped: + metrics["mfma"] += 1 + if "buffer_load" in stripped: + metrics["buffer_load"] += 1 + if "ds_read" in stripped: + metrics["ds_read"] += 1 + if "ds_write" in stripped: + metrics["ds_write"] += 1 + + # Extract register counts from kernel descriptor. + for line in lines: + if ".amdhsa_next_free_vgpr" in line: + metrics["peak_vgpr"] = int(line.split()[-1]) + if ".amdhsa_next_free_sgpr" in line: + metrics["peak_sgpr"] = int(line.split()[-1]) + + return metrics + + +def main(): + parser = argparse.ArgumentParser( + description="Extract WaveASM IR at the Conductor scheduling stage." 
+ ) + parser.add_argument( + "--dump-input-mlir", action="store_true", + help="Also dump the input MLIR (before waveasm-translate).", + ) + parser.add_argument( + "--output", "-o", type=str, default=None, + help="Write WaveASM IR to file instead of stdout.", + ) + parser.add_argument( + "--metrics", action="store_true", + help="Also run full pipeline and print baseline metrics.", + ) + args = parser.parse_args() + + print("Capturing kernel MLIR...", file=sys.stderr) + mlir_text, wg_size = capture_kernel_mlir() + print(f" workgroup_size: {wg_size}", file=sys.stderr) + print(f" input MLIR: {len(mlir_text)} chars", file=sys.stderr) + + if args.dump_input_mlir: + print("=== Input MLIR (before waveasm-translate) ===") + print(mlir_text) + print("=== End Input MLIR ===\n") + + print("Running pre-scheduling pipeline...", file=sys.stderr) + waveasm_ir = run_pre_scheduling_pipeline(mlir_text, wg_size) + print(f" WaveASM IR: {len(waveasm_ir)} chars", file=sys.stderr) + + op_count = sum(1 for line in waveasm_ir.split("\n") + if "waveasm." 
in line and not line.strip().startswith("//")) + print(f" WaveASM ops: {op_count}", file=sys.stderr) + + if args.output: + Path(args.output).write_text(waveasm_ir) + print(f" Written to: {args.output}", file=sys.stderr) + else: + print(waveasm_ir) + + if args.metrics: + print("\nRunning full pipeline for baseline metrics...", file=sys.stderr) + asm_text = run_full_pipeline(mlir_text, wg_size) + if asm_text: + metrics = count_asm_metrics(asm_text) + print("\n=== Baseline Metrics ===") + for k, v in metrics.items(): + print(f" {k}: {v}") + + +if __name__ == "__main__": + main() From 39d38b50f973053e9ea308845a3c5053a83fc704 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 00:54:22 +0100 Subject: [PATCH 14/49] conductor LLM integration Signed-off-by: Ivan Butygin --- conductor/README.md | 118 +++++++++++++++++ conductor/__init__.py | 4 + conductor/conductor.py | 85 ++++++++++-- conductor/extract_ir.py | 51 ++++++-- conductor/llm.py | 282 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 520 insertions(+), 20 deletions(-) create mode 100644 conductor/README.md create mode 100644 conductor/llm.py diff --git a/conductor/README.md b/conductor/README.md new file mode 100644 index 000000000..d66b4b5c8 --- /dev/null +++ b/conductor/README.md @@ -0,0 +1,118 @@ +# Conductor + +LLM-guided instruction scheduling for WaveASM. Takes pre-scheduling WaveASM IR, +lets an LLM reorder instructions via move/swap commands, then compiles and +measures the result. See +[CONDUCTOR_DESIGN.md](../wave_lang/kernel/wave/asm/wave_asm/include/waveasm/CONDUCTOR_DESIGN.md) +for the full design rationale. 
+ +## Pipeline + +``` +Python frontend (GEMM kernel) + → Wave compiler → MLIR + → waveasm-translate (pre-scheduling: CSE, peephole, mem-offset-opt) + → Tag instructions (attach loc("tag_name") to each op) + → LLM loop: present IR + metrics → get move commands → apply → compile → measure + → Final assembly +``` + +## Modules + +| File | Purpose | +|------|---------| +| `extract_ir.py` | Captures MLIR from a Wave kernel, runs pre-scheduling pipeline, computes baseline metrics. | +| `conductor.py` | `Conductor` class: tag, apply_moves, compile_to_asm, evaluate, baseline. CLI entry point. | +| `llm.py` | OpenRouter client, prompt formatting, command parsing, iterative scheduling loop. | + +## Prerequisites + +- `waveasm-translate` and `waveasm-conductor` binaries (build from `wave_lang/kernel/wave/asm/wave_asm/`). +- Python `requests` package. +- `OPENROUTER_API_KEY` env var for LLM mode. + +## Usage + +All commands should be run from the wave project root with `WAVE_CACHE_ON=0`. + +### Baseline metrics (no LLM) + +```bash +python -m conductor.conductor --metrics +``` + +### Show tagged IR + +```bash +python -m conductor.conductor --tag-only +``` + +### Apply manual moves + +```bash +python -m conductor.conductor \ + --moves "swap buffer_load_dwordx4_0 v_mfma_f32_16x16x16_f16_0" \ + --metrics +``` + +### Read moves from a file + +```bash +python -m conductor.conductor --moves-file moves.txt --metrics +``` + +### LLM-guided scheduling + +```bash +OPENROUTER_API_KEY=sk-... 
python -m conductor.conductor \ + --llm --max-rounds 5 --model google/gemini-2.5-flash-preview --metrics +``` + +### Extract pre-scheduling IR only + +```bash +python -m conductor.extract_ir -o /tmp/conductor_ir.mlir +python -m conductor.extract_ir --metrics +``` + +## Move commands + +| Command | Example | +|---------|---------| +| `move TAG before TAG` | `move v_add_u32_1 before v_add_u32_0` | +| `move TAG after TAG` | `move buffer_load_dwordx4_0 after ds_read_b128_2` | +| `swap TAG TAG` | `swap v_mfma_f32_16x16x16_f16_0 ds_read_b128_1` | +| `done` | LLM is satisfied with current schedule. | + +## Environment variables + +| Variable | Default | Purpose | +|----------|---------|---------| +| `OPENROUTER_API_KEY` | (none) | Required for `--llm` mode. | +| `WAVE_CACHE_ON` | `1` | Set to `0` to disable caching during development. | +| `WAVE_DEFAULT_ARCH` | `gfx942` | Target GPU architecture. | +| `WAVEASM_TRANSLATE` | (auto-detect) | Override path to `waveasm-translate` binary. | +| `WAVEASM_CONDUCTOR` | (auto-detect) | Override path to `waveasm-conductor` binary. | + +## Programmatic usage + +```python +from conductor import Conductor +from conductor.extract_ir import capture_kernel_mlir, run_pre_scheduling_pipeline +from conductor.llm import run_scheduling_loop + +mlir_text, wg_size = capture_kernel_mlir() +waveasm_ir = run_pre_scheduling_pipeline(mlir_text, wg_size) + +c = Conductor(waveasm_ir, wg_size) + +# Baseline. +print(c.baseline()) + +# Manual moves. +print(c.evaluate(["swap buffer_load_dwordx4_0 v_mfma_f32_16x16x16_f16_0"])) + +# LLM loop. 
+result = run_scheduling_loop(c, max_rounds=5, model="google/gemini-2.5-flash-preview") +print(result["metrics"], result["commands"]) +``` diff --git a/conductor/__init__.py b/conductor/__init__.py index e7fad7bb1..5d61417b9 100644 --- a/conductor/__init__.py +++ b/conductor/__init__.py @@ -15,6 +15,7 @@ count_asm_metrics, capture_kernel_mlir, ) +from conductor.llm import run_scheduling_loop, parse_commands, format_prompt __all__ = [ "Conductor", @@ -25,4 +26,7 @@ "run_full_pipeline", "count_asm_metrics", "capture_kernel_mlir", + "run_scheduling_loop", + "parse_commands", + "format_prompt", ] diff --git a/conductor/conductor.py b/conductor/conductor.py index 0b920b0ad..953245980 100644 --- a/conductor/conductor.py +++ b/conductor/conductor.py @@ -31,7 +31,6 @@ sys.path.insert(0, str(wave_root)) from conductor.extract_ir import ( - find_waveasm_translate, run_waveasm_translate, count_asm_metrics, capture_kernel_mlir, @@ -46,8 +45,25 @@ def find_waveasm_conductor() -> str: return env_path candidates = [ - wave_root / "wave_lang" / "kernel" / "wave" / "asm" / "wave_asm" / "build" / "tools" / "waveasm-conductor" / "waveasm-conductor", - wave_root / "wave_lang" / "kernel" / "wave" / "asm" / "wave_asm" / "build" / "bin" / "waveasm-conductor", + wave_root + / "wave_lang" + / "kernel" + / "wave" + / "asm" + / "wave_asm" + / "build" + / "tools" + / "waveasm-conductor" + / "waveasm-conductor", + wave_root + / "wave_lang" + / "kernel" + / "wave" + / "asm" + / "wave_asm" + / "build" + / "bin" + / "waveasm-conductor", ] for p in candidates: if p.exists(): @@ -88,9 +104,7 @@ def apply_moves(self, tagged_ir: str, commands: list) -> str: header = "\n".join(f"// CONDUCTOR: {cmd}" for cmd in commands) full_input = header + "\n\n" + tagged_ir - with tempfile.NamedTemporaryFile( - mode="w", suffix=".mlir", delete=False - ) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".mlir", delete=False) as f: f.write(full_input) input_path = f.name @@ -148,21 +162,50 @@ def main(): 
description="Conductor: WaveASM instruction scheduling driver." ) parser.add_argument( - "--moves", nargs="*", default=None, + "--moves", + nargs="*", + default=None, help="Move commands (e.g. 'swap A B', 'move X after Y').", ) parser.add_argument( - "--moves-file", type=str, default=None, + "--moves-file", + type=str, + default=None, help="Read move commands from a file (one per line).", ) parser.add_argument( - "--metrics", action="store_true", + "--metrics", + action="store_true", help="Print assembly metrics after scheduling.", ) parser.add_argument( - "--tag-only", action="store_true", + "--tag-only", + action="store_true", help="Only show tagged IR, then exit.", ) + parser.add_argument( + "--llm", + action="store_true", + help="Run the LLM-guided scheduling loop.", + ) + parser.add_argument( + "--model", + type=str, + default=None, + help="OpenRouter model ID for --llm mode.", + ) + parser.add_argument( + "--max-rounds", + type=int, + default=5, + help="Maximum LLM scheduling rounds (default: 5).", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.7, + help="LLM sampling temperature (default: 0.7).", + ) args = parser.parse_args() # Collect commands from both sources. 
@@ -190,6 +233,28 @@ def main(): print(conductor.tag()) return + if args.llm: + from conductor.llm import run_scheduling_loop, DEFAULT_MODEL + + model = args.model or DEFAULT_MODEL + print(f"Running LLM scheduling loop (model={model})...", file=sys.stderr) + result = run_scheduling_loop( + conductor, + max_rounds=args.max_rounds, + model=model, + temperature=args.temperature, + ) + print("\n=== LLM Scheduling Result ===") + print(f" rounds: {result['rounds']}") + print(f" commands: {result['commands']}") + print(" baseline:") + for k, v in result["baseline_metrics"].items(): + print(f" {k}: {v}") + print(" best:") + for k, v in result["metrics"].items(): + print(f" {k}: {v}") + return + if commands: print(f"Applying {len(commands)} move(s)...", file=sys.stderr) metrics = conductor.evaluate(commands) diff --git a/conductor/extract_ir.py b/conductor/extract_ir.py index be144f765..59f9d60a1 100644 --- a/conductor/extract_ir.py +++ b/conductor/extract_ir.py @@ -36,8 +36,25 @@ def find_waveasm_translate() -> str: return env_path candidates = [ - wave_root / "wave_lang" / "kernel" / "wave" / "asm" / "wave_asm" / "build" / "tools" / "waveasm-translate" / "waveasm-translate", - wave_root / "wave_lang" / "kernel" / "wave" / "asm" / "wave_asm" / "build" / "bin" / "waveasm-translate", + wave_root + / "wave_lang" + / "kernel" + / "wave" + / "asm" + / "wave_asm" + / "build" + / "tools" + / "waveasm-translate" + / "waveasm-translate", + wave_root + / "wave_lang" + / "kernel" + / "wave" + / "asm" + / "wave_asm" + / "build" + / "bin" + / "waveasm-translate", ] for p in candidates: if p.exists(): @@ -121,8 +138,12 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: options = WaveCompileOptions( subs={ - M: m, N: n, K: k, - BLOCK_M: block_m, BLOCK_N: block_n, BLOCK_K: block_k, + M: m, + N: n, + K: k, + BLOCK_M: block_m, + BLOCK_N: block_n, + BLOCK_K: block_k, ADDRESS_SPACE: SHARED_ADDRESS_SPACE, ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, }, @@ -169,7 +190,9 
@@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: raise ValueError("No kernel function found in MLIR") -def run_waveasm_translate(mlir_text: str, workgroup_size: tuple, extra_flags: list = None) -> tuple: +def run_waveasm_translate( + mlir_text: str, workgroup_size: tuple, extra_flags: list = None +) -> tuple: """ Run waveasm-translate with given flags. @@ -292,15 +315,20 @@ def main(): description="Extract WaveASM IR at the Conductor scheduling stage." ) parser.add_argument( - "--dump-input-mlir", action="store_true", + "--dump-input-mlir", + action="store_true", help="Also dump the input MLIR (before waveasm-translate).", ) parser.add_argument( - "--output", "-o", type=str, default=None, + "--output", + "-o", + type=str, + default=None, help="Write WaveASM IR to file instead of stdout.", ) parser.add_argument( - "--metrics", action="store_true", + "--metrics", + action="store_true", help="Also run full pipeline and print baseline metrics.", ) args = parser.parse_args() @@ -319,8 +347,11 @@ def main(): waveasm_ir = run_pre_scheduling_pipeline(mlir_text, wg_size) print(f" WaveASM IR: {len(waveasm_ir)} chars", file=sys.stderr) - op_count = sum(1 for line in waveasm_ir.split("\n") - if "waveasm." in line and not line.strip().startswith("//")) + op_count = sum( + 1 + for line in waveasm_ir.split("\n") + if "waveasm." 
in line and not line.strip().startswith("//") + ) print(f" WaveASM ops: {op_count}", file=sys.stderr) if args.output: diff --git a/conductor/llm.py b/conductor/llm.py new file mode 100644 index 000000000..caf7a1d67 --- /dev/null +++ b/conductor/llm.py @@ -0,0 +1,282 @@ +"""OpenRouter LLM client and iterative scheduling loop for Conductor.""" + +import json +import os +import re +import sys +import time +from typing import Any + +import requests + +API_KEY: str = os.environ.get("OPENROUTER_API_KEY", "") +BASE_URL: str = "https://openrouter.ai/api/v1" +DEFAULT_MODEL: str = "google/gemini-2.5-flash-preview" + +_REQUEST_TIMEOUT = 120 +_MAX_RETRIES = 3 +_RETRY_BACKOFF = 2.0 + +# Valid move command pattern: move/swap with tag operands. +_MOVE_RE = re.compile( + r"^(move\s+\S+\s+(?:before|after)\s+\S+|swap\s+\S+\s+\S+|done)$", re.IGNORECASE +) + +SYSTEM_PROMPT = """\ +You are an expert GPU instruction scheduler for AMD CDNA/RDNA architectures. + +You will receive WaveASM MLIR IR with tagged instructions (loc("tag_name")). +Your job is to reorder instructions to: +1. Hide memory latency by interleaving loads with independent compute. +2. Reduce register pressure so the linear-scan allocator succeeds. +3. Minimize the number of s_waitcnt and s_nop instructions inserted by later passes. + +Key latencies: +- Global loads (buffer_load): ~100 cycles. +- LDS loads (ds_read): ~20 cycles. +- MFMA (16x16): ~32 cycles, (32x32): ~64 cycles. + +Rules: +- Issue move commands, one per line. +- Commands: "move TAG_A after TAG_B", "move TAG_A before TAG_B", "swap TAG_A TAG_B". +- Say "done" (alone on a line) when satisfied. +- Do NOT output anything else — no explanations, no markdown, just commands. +- Moves that break SSA dominance will be rejected; you will see the error. +- Pinned ops (s_endpgm, s_barrier, condition) cannot be moved. + +Strategy tips: +- Move global loads earlier to start memory fetches sooner. 
+- Interleave LDS reads between MFMAs to hide LDS latency behind MFMA execution. +- Keep address computations close to their consumers. +- Avoid clustering same-type ops (all loads together, all MFMAs together).\ +""" + + +def _chat( + messages: list[dict[str, Any]], + model: str, + temperature: float = 0.7, + max_tokens: int = 2048, +) -> str: + """Send a chat completion request to OpenRouter. Returns content text.""" + if not API_KEY: + raise RuntimeError( + "OPENROUTER_API_KEY not set. Export it before running the LLM loop." + ) + + payload = { + "model": model, + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, + "stream": True, + "stream_options": {"include_usage": True}, + } + + for attempt in range(_MAX_RETRIES): + try: + resp = requests.post( + f"{BASE_URL}/chat/completions", + headers={"Authorization": f"Bearer {API_KEY}"}, + json=payload, + stream=True, + timeout=_REQUEST_TIMEOUT, + ) + if resp.status_code >= 500: + if attempt < _MAX_RETRIES - 1: + wait = _RETRY_BACKOFF * (attempt + 1) + print( + f" [retry] server {resp.status_code}, waiting {wait:.0f}s...", + file=sys.stderr, + ) + time.sleep(wait) + continue + if resp.status_code >= 400: + raise RuntimeError( + f"OpenRouter API error {resp.status_code}: {resp.text}" + ) + + # Stream and accumulate content. 
+ chunks: list[str] = [] + usage = None + for line in resp.iter_lines(): + if not line or not line.startswith(b"data: "): + continue + data = line[6:] + if data == b"[DONE]": + break + chunk = json.loads(data) + if "usage" in chunk: + usage = chunk["usage"] + choices = chunk.get("choices") + if not choices: + continue + delta = choices[0].get("delta", {}) + token = delta.get("content", "") + if token: + chunks.append(token) + + content = "".join(chunks) + if usage: + pt = usage.get("prompt_tokens", "?") + ct = usage.get("completion_tokens", "?") + print(f" [tokens] prompt={pt} completion={ct}", file=sys.stderr) + return content + + except ( + requests.exceptions.ConnectionError, + requests.exceptions.ReadTimeout, + requests.exceptions.ChunkedEncodingError, + ): + if attempt < _MAX_RETRIES - 1: + wait = _RETRY_BACKOFF * (attempt + 1) + print( + f" [retry] connection error, waiting {wait:.0f}s...", + file=sys.stderr, + ) + time.sleep(wait) + else: + raise + + raise RuntimeError("Unreachable") + + +def parse_commands(text: str) -> list[str]: + """Parse move commands from LLM response. Returns list of command strings.""" + commands = [] + for line in text.strip().splitlines(): + line = line.strip() + if not line: + continue + # Strip markdown code fences if the model wraps output. 
+ if line.startswith("```"): + continue + if _MOVE_RE.match(line): + commands.append(line) + return commands + + +def format_prompt( + tagged_ir: str, + metrics: dict, + round_num: int, + error: str | None = None, + target: str = "gfx942", +) -> str: + """Format the user prompt for a scheduling round.""" + parts = [ + f"=== WaveASM Scheduling Round {round_num} ===", + f"TARGET: {target} (wave64, 512 vgpr, 106 sgpr, 512 agpr)", + "LATENCY: vmem=100, lds=20, mfma_16x16=32, mfma_32x32=64", + "", + "--- IR (tagged) ---", + tagged_ir.strip(), + "", + "--- Metrics ---", + ] + for k, v in metrics.items(): + parts.append(f" {k}: {v}") + + if error: + parts.extend(["", "--- Error from previous round ---", error]) + + parts.extend( + [ + "", + "GOAL: Minimize register pressure and hide memory latency.", + "Respond with move commands, one per line.", + ] + ) + return "\n".join(parts) + + +def run_scheduling_loop( + conductor, + max_rounds: int = 5, + model: str = DEFAULT_MODEL, + temperature: float = 0.7, + verbose: bool = True, +) -> dict: + """ + Run the iterative LLM scheduling loop. + + Returns dict with keys: metrics, commands, rounds, baseline_metrics. 
+ """ + + def log(msg: str) -> None: + if verbose: + print(msg, file=sys.stderr) + + log("Computing baseline metrics...") + baseline = conductor.baseline() + log(f" baseline: {baseline}") + + tagged_ir = conductor.tag() + log(f" tagged IR: {len(tagged_ir)} chars") + + best_metrics = dict(baseline) + best_commands: list[str] = [] + messages: list[dict[str, Any]] = [{"role": "system", "content": SYSTEM_PROMPT}] + error: str | None = None + + for round_num in range(1, max_rounds + 1): + log(f"\n--- Round {round_num}/{max_rounds} ---") + + prompt = format_prompt( + tagged_ir, best_metrics, round_num, error=error, target=conductor.target + ) + messages.append({"role": "user", "content": prompt}) + + log(" Querying LLM...") + response = _chat(messages, model=model, temperature=temperature) + messages.append({"role": "assistant", "content": response}) + log(f" Response:\n{response}") + + commands = parse_commands(response) + if not commands: + log(" No valid commands parsed, stopping.") + break + + if len(commands) == 1 and commands[0].lower() == "done": + log(" LLM says done.") + break + + error = None + try: + metrics = conductor.evaluate(commands) + log(f" metrics: {metrics}") + + if _is_better(metrics, best_metrics): + log(" Improvement found!") + best_metrics = metrics + best_commands = commands + else: + log(" No improvement, reverting.") + error = ( + f"Round {round_num} regressed metrics. " + f"Previous best: {best_metrics}, this round: {metrics}. " + "Moves reverted." 
+ ) + except RuntimeError as e: + error_msg = str(e) + log(f" Error: {error_msg}") + error = error_msg + + return { + "metrics": best_metrics, + "commands": best_commands, + "rounds": round_num, + "baseline_metrics": baseline, + } + + +def _is_better(new: dict, old: dict) -> bool: + """Compare metrics: lower VGPRs > fewer waitcnts > fewer nops > fewer instructions.""" + for key in ("peak_vgpr", "s_waitcnt", "s_nop", "total_instructions"): + nv = new.get(key, 0) + ov = old.get(key, 0) + if nv < ov: + return True + if nv > ov: + return False + return False From 565d2f0f65744d789ad286a8e062abb8796ac350 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 00:58:23 +0100 Subject: [PATCH 15/49] model and remove done Signed-off-by: Ivan Butygin --- conductor/README.md | 1 - conductor/llm.py | 9 ++------- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/conductor/README.md b/conductor/README.md index d66b4b5c8..b9b7338da 100644 --- a/conductor/README.md +++ b/conductor/README.md @@ -82,7 +82,6 @@ python -m conductor.extract_ir --metrics | `move TAG before TAG` | `move v_add_u32_1 before v_add_u32_0` | | `move TAG after TAG` | `move buffer_load_dwordx4_0 after ds_read_b128_2` | | `swap TAG TAG` | `swap v_mfma_f32_16x16x16_f16_0 ds_read_b128_1` | -| `done` | LLM is satisfied with current schedule. | ## Environment variables diff --git a/conductor/llm.py b/conductor/llm.py index caf7a1d67..800020c22 100644 --- a/conductor/llm.py +++ b/conductor/llm.py @@ -11,7 +11,7 @@ API_KEY: str = os.environ.get("OPENROUTER_API_KEY", "") BASE_URL: str = "https://openrouter.ai/api/v1" -DEFAULT_MODEL: str = "google/gemini-2.5-flash-preview" +DEFAULT_MODEL: str = "deepseek/deepseek-v3.2" _REQUEST_TIMEOUT = 120 _MAX_RETRIES = 3 @@ -19,7 +19,7 @@ # Valid move command pattern: move/swap with tag operands. 
_MOVE_RE = re.compile( - r"^(move\s+\S+\s+(?:before|after)\s+\S+|swap\s+\S+\s+\S+|done)$", re.IGNORECASE + r"^(move\s+\S+\s+(?:before|after)\s+\S+|swap\s+\S+\s+\S+)$", re.IGNORECASE ) SYSTEM_PROMPT = """\ @@ -39,7 +39,6 @@ Rules: - Issue move commands, one per line. - Commands: "move TAG_A after TAG_B", "move TAG_A before TAG_B", "swap TAG_A TAG_B". -- Say "done" (alone on a line) when satisfied. - Do NOT output anything else — no explanations, no markdown, just commands. - Moves that break SSA dominance will be rejected; you will see the error. - Pinned ops (s_endpgm, s_barrier, condition) cannot be moved. @@ -237,10 +236,6 @@ def log(msg: str) -> None: log(" No valid commands parsed, stopping.") break - if len(commands) == 1 and commands[0].lower() == "done": - log(" LLM says done.") - break - error = None try: metrics = conductor.evaluate(commands) From d4ef635672c795929df74d130a2267b6c7eca9bf Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 01:07:28 +0100 Subject: [PATCH 16/49] logging --- conductor/conductor.py | 7 +++ conductor/llm.py | 103 ++++++++++++++++++++++++++++------------- 2 files changed, 79 insertions(+), 31 deletions(-) diff --git a/conductor/conductor.py b/conductor/conductor.py index 953245980..e2233e374 100644 --- a/conductor/conductor.py +++ b/conductor/conductor.py @@ -206,6 +206,12 @@ def main(): default=0.7, help="LLM sampling temperature (default: 0.7).", ) + parser.add_argument( + "--reasoning-effort", + type=str, + default="high", + help="Reasoning effort for models that support it (default: high).", + ) args = parser.parse_args() # Collect commands from both sources. 
@@ -243,6 +249,7 @@ def main(): max_rounds=args.max_rounds, model=model, temperature=args.temperature, + reasoning_effort=args.reasoning_effort, ) print("\n=== LLM Scheduling Result ===") print(f" rounds: {result['rounds']}") diff --git a/conductor/llm.py b/conductor/llm.py index 800020c22..3c68c6e8f 100644 --- a/conductor/llm.py +++ b/conductor/llm.py @@ -5,6 +5,7 @@ import re import sys import time +from collections.abc import Callable from typing import Any import requests @@ -17,6 +18,16 @@ _MAX_RETRIES = 3 _RETRY_BACKOFF = 2.0 + +def _default_log(msg: str) -> None: + """Default logger: print to stderr without trailing newline.""" + print(msg, file=sys.stderr, end="", flush=True) + + +def _noop_log(_msg: str) -> None: + pass + + # Valid move command pattern: move/swap with tag operands. _MOVE_RE = re.compile( r"^(move\s+\S+\s+(?:before|after)\s+\S+|swap\s+\S+\s+\S+)$", re.IGNORECASE @@ -56,6 +67,8 @@ def _chat( model: str, temperature: float = 0.7, max_tokens: int = 2048, + reasoning_effort: str | None = None, + log: Callable[[str], None] = _default_log, ) -> str: """Send a chat completion request to OpenRouter. Returns content text.""" if not API_KEY: @@ -63,7 +76,7 @@ def _chat( "OPENROUTER_API_KEY not set. Export it before running the LLM loop." 
) - payload = { + payload: dict[str, Any] = { "model": model, "messages": messages, "temperature": temperature, @@ -71,6 +84,8 @@ def _chat( "stream": True, "stream_options": {"include_usage": True}, } + if reasoning_effort is not None: + payload["reasoning"] = {"enabled": True, "effort": reasoning_effort} for attempt in range(_MAX_RETRIES): try: @@ -84,9 +99,8 @@ def _chat( if resp.status_code >= 500: if attempt < _MAX_RETRIES - 1: wait = _RETRY_BACKOFF * (attempt + 1) - print( - f" [retry] server {resp.status_code}, waiting {wait:.0f}s...", - file=sys.stderr, + log( + f"\n [retry] server {resp.status_code}, waiting {wait:.0f}s...\n" ) time.sleep(wait) continue @@ -95,9 +109,12 @@ def _chat( f"OpenRouter API error {resp.status_code}: {resp.text}" ) - # Stream and accumulate content. - chunks: list[str] = [] + # Stream and accumulate content + reasoning. + content_chunks: list[str] = [] + reasoning_chunks: list[str] = [] usage = None + in_reasoning = False + for line in resp.iter_lines(): if not line or not line.startswith(b"data: "): continue @@ -111,15 +128,40 @@ def _chat( if not choices: continue delta = choices[0].get("delta", {}) + + # Reasoning tokens (two OpenRouter formats). + for detail in delta.get("reasoning_details", []): + text = detail.get("text", "") + if text: + if not in_reasoning: + log("\n [thinking] ") + in_reasoning = True + log(text) + reasoning_chunks.append(text) + rc = delta.get("reasoning_content", "") + if rc: + if not in_reasoning: + log("\n [thinking] ") + in_reasoning = True + log(rc) + reasoning_chunks.append(rc) + + # Content tokens. 
token = delta.get("content", "") if token: - chunks.append(token) + if in_reasoning: + log("\n [/thinking]\n") + in_reasoning = False + content_chunks.append(token) + + if in_reasoning: + log("\n [/thinking]\n") - content = "".join(chunks) + content = "".join(content_chunks) if usage: pt = usage.get("prompt_tokens", "?") ct = usage.get("completion_tokens", "?") - print(f" [tokens] prompt={pt} completion={ct}", file=sys.stderr) + log(f"\n [tokens] prompt={pt} completion={ct}\n") return content except ( @@ -129,10 +171,7 @@ def _chat( ): if attempt < _MAX_RETRIES - 1: wait = _RETRY_BACKOFF * (attempt + 1) - print( - f" [retry] connection error, waiting {wait:.0f}s...", - file=sys.stderr, - ) + log(f"\n [retry] connection error, waiting {wait:.0f}s...\n") time.sleep(wait) else: raise @@ -194,24 +233,20 @@ def run_scheduling_loop( max_rounds: int = 5, model: str = DEFAULT_MODEL, temperature: float = 0.7, - verbose: bool = True, + reasoning_effort: str | None = "high", + log: Callable[[str], None] = _default_log, ) -> dict: """ Run the iterative LLM scheduling loop. Returns dict with keys: metrics, commands, rounds, baseline_metrics. 
""" - - def log(msg: str) -> None: - if verbose: - print(msg, file=sys.stderr) - - log("Computing baseline metrics...") + log("Computing baseline metrics...\n") baseline = conductor.baseline() - log(f" baseline: {baseline}") + log(f" baseline: {baseline}\n") tagged_ir = conductor.tag() - log(f" tagged IR: {len(tagged_ir)} chars") + log(f" tagged IR: {len(tagged_ir)} chars\n") best_metrics = dict(baseline) best_commands: list[str] = [] @@ -219,34 +254,40 @@ def log(msg: str) -> None: error: str | None = None for round_num in range(1, max_rounds + 1): - log(f"\n--- Round {round_num}/{max_rounds} ---") + log(f"\n--- Round {round_num}/{max_rounds} ---\n") prompt = format_prompt( tagged_ir, best_metrics, round_num, error=error, target=conductor.target ) messages.append({"role": "user", "content": prompt}) - log(" Querying LLM...") - response = _chat(messages, model=model, temperature=temperature) + log(" Querying LLM...\n") + response = _chat( + messages, + model=model, + temperature=temperature, + reasoning_effort=reasoning_effort, + log=log, + ) messages.append({"role": "assistant", "content": response}) - log(f" Response:\n{response}") + log(f" Response:\n{response}\n") commands = parse_commands(response) if not commands: - log(" No valid commands parsed, stopping.") + log(" No valid commands parsed, stopping.\n") break error = None try: metrics = conductor.evaluate(commands) - log(f" metrics: {metrics}") + log(f" metrics: {metrics}\n") if _is_better(metrics, best_metrics): - log(" Improvement found!") + log(" Improvement found!\n") best_metrics = metrics best_commands = commands else: - log(" No improvement, reverting.") + log(" No improvement, reverting.\n") error = ( f"Round {round_num} regressed metrics. " f"Previous best: {best_metrics}, this round: {metrics}. 
" @@ -254,7 +295,7 @@ def log(msg: str) -> None: ) except RuntimeError as e: error_msg = str(e) - log(f" Error: {error_msg}") + log(f" Error: {error_msg}\n") error = error_msg return { From 7c5fd97d6a5da1b979771a1692db675db638f01c Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 01:22:16 +0100 Subject: [PATCH 17/49] prompt --- conductor/llm.py | 41 +++++++++++++++++++---------------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/conductor/llm.py b/conductor/llm.py index 3c68c6e8f..8697f3dd2 100644 --- a/conductor/llm.py +++ b/conductor/llm.py @@ -37,28 +37,22 @@ def _noop_log(_msg: str) -> None: You are an expert GPU instruction scheduler for AMD CDNA/RDNA architectures. You will receive WaveASM MLIR IR with tagged instructions (loc("tag_name")). -Your job is to reorder instructions to: -1. Hide memory latency by interleaving loads with independent compute. -2. Reduce register pressure so the linear-scan allocator succeeds. -3. Minimize the number of s_waitcnt and s_nop instructions inserted by later passes. - -Key latencies: -- Global loads (buffer_load): ~100 cycles. -- LDS loads (ds_read): ~20 cycles. -- MFMA (16x16): ~32 cycles, (32x32): ~64 cycles. - -Rules: -- Issue move commands, one per line. -- Commands: "move TAG_A after TAG_B", "move TAG_A before TAG_B", "swap TAG_A TAG_B". -- Do NOT output anything else — no explanations, no markdown, just commands. -- Moves that break SSA dominance will be rejected; you will see the error. +Your job is to reorder instructions to hide memory latency and reduce register pressure. + +Key latencies: global loads ~100 cycles, LDS loads ~20 cycles, MFMA 16x16 ~32 cycles. + +Commands (one per line, nothing else): + move TAG_A after TAG_B + move TAG_A before TAG_B + swap TAG_A TAG_B + +Constraints: +- Moves that break SSA dominance will be rejected. - Pinned ops (s_endpgm, s_barrier, condition) cannot be moved. -Strategy tips: -- Move global loads earlier to start memory fetches sooner. 
-- Interleave LDS reads between MFMAs to hide LDS latency behind MFMA execution. -- Keep address computations close to their consumers. -- Avoid clustering same-type ops (all loads together, all MFMAs together).\ +IMPORTANT: Work incrementally. Issue 1-3 moves per round, observe the metrics, \ +then adjust. Do not try to solve everything at once. Each round you will see \ +updated metrics so you can evaluate what worked.\ """ @@ -85,7 +79,10 @@ def _chat( "stream_options": {"include_usage": True}, } if reasoning_effort is not None: - payload["reasoning"] = {"enabled": True, "effort": reasoning_effort} + payload["reasoning"] = { + "enabled": True, + "effort": reasoning_effort, + } for attempt in range(_MAX_RETRIES): try: @@ -233,7 +230,7 @@ def run_scheduling_loop( max_rounds: int = 5, model: str = DEFAULT_MODEL, temperature: float = 0.7, - reasoning_effort: str | None = "high", + reasoning_effort: str | None = "medium", log: Callable[[str], None] = _default_log, ) -> dict: """ From a420f18d663e3779bd54d81a6d5b367aa20088d3 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 01:31:02 +0100 Subject: [PATCH 18/49] tool use Signed-off-by: Ivan Butygin --- conductor/__init__.py | 4 +- conductor/llm.py | 225 ++++++++++++++++++++++++++---------------- 2 files changed, 140 insertions(+), 89 deletions(-) diff --git a/conductor/__init__.py b/conductor/__init__.py index 5d61417b9..03f2e0f26 100644 --- a/conductor/__init__.py +++ b/conductor/__init__.py @@ -15,7 +15,7 @@ count_asm_metrics, capture_kernel_mlir, ) -from conductor.llm import run_scheduling_loop, parse_commands, format_prompt +from conductor.llm import run_scheduling_loop __all__ = [ "Conductor", @@ -27,6 +27,4 @@ "count_asm_metrics", "capture_kernel_mlir", "run_scheduling_loop", - "parse_commands", - "format_prompt", ] diff --git a/conductor/llm.py b/conductor/llm.py index 8697f3dd2..cc5c9ff8d 100644 --- a/conductor/llm.py +++ b/conductor/llm.py @@ -2,7 +2,6 @@ import json import os -import re 
import sys import time from collections.abc import Callable @@ -18,6 +17,8 @@ _MAX_RETRIES = 3 _RETRY_BACKOFF = 2.0 +Message = dict[str, Any] + def _default_log(msg: str) -> None: """Default logger: print to stderr without trailing newline.""" @@ -28,11 +29,6 @@ def _noop_log(_msg: str) -> None: pass -# Valid move command pattern: move/swap with tag operands. -_MOVE_RE = re.compile( - r"^(move\s+\S+\s+(?:before|after)\s+\S+|swap\s+\S+\s+\S+)$", re.IGNORECASE -) - SYSTEM_PROMPT = """\ You are an expert GPU instruction scheduler for AMD CDNA/RDNA architectures. @@ -41,30 +37,61 @@ def _noop_log(_msg: str) -> None: Key latencies: global loads ~100 cycles, LDS loads ~20 cycles, MFMA 16x16 ~32 cycles. -Commands (one per line, nothing else): - move TAG_A after TAG_B - move TAG_A before TAG_B - swap TAG_A TAG_B +You have an `evaluate_moves` tool. Call it with a list of move command strings. +Each command is one of: + "move TAG_A after TAG_B" + "move TAG_A before TAG_B" + "swap TAG_A TAG_B" + +The tool will apply the moves, compile, and return metrics. Constraints: - Moves that break SSA dominance will be rejected. - Pinned ops (s_endpgm, s_barrier, condition) cannot be moved. -IMPORTANT: Work incrementally. Issue 1-3 moves per round, observe the metrics, \ -then adjust. Do not try to solve everything at once. Each round you will see \ -updated metrics so you can evaluate what worked.\ +Work incrementally: try 1-3 moves per tool call, read the resulting metrics, \ +then decide your next moves. You can call the tool multiple times.\ """ +TOOLS = [ + { + "type": "function", + "function": { + "name": "evaluate_moves", + "description": ( + "Apply a list of move/swap commands to the tagged IR, " + "compile through the post-scheduling pipeline, and return " + "assembly metrics. Commands are applied in order." 
+ ), + "parameters": { + "type": "object", + "properties": { + "moves": { + "type": "array", + "items": {"type": "string"}, + "description": ( + "List of move commands, e.g. " + '["move tag_A after tag_B", "swap tag_C tag_D"].' + ), + } + }, + "required": ["moves"], + }, + }, + } +] + def _chat( - messages: list[dict[str, Any]], + messages: list[Message], model: str, temperature: float = 0.7, max_tokens: int = 2048, reasoning_effort: str | None = None, + tools: list[dict] | None = None, log: Callable[[str], None] = _default_log, -) -> str: - """Send a chat completion request to OpenRouter. Returns content text.""" +) -> Message: + """Send a streaming chat completion request. Returns the full response message.""" if not API_KEY: raise RuntimeError( "OPENROUTER_API_KEY not set. Export it before running the LLM loop." @@ -79,10 +106,9 @@ def _chat( "stream_options": {"include_usage": True}, } if reasoning_effort is not None: - payload["reasoning"] = { - "enabled": True, - "effort": reasoning_effort, - } + payload["reasoning"] = {"enabled": True, "effort": reasoning_effort} + if tools: + payload["tools"] = tools for attempt in range(_MAX_RETRIES): try: @@ -106,9 +132,10 @@ def _chat( f"OpenRouter API error {resp.status_code}: {resp.text}" ) - # Stream and accumulate content + reasoning. + # Stream and accumulate content, reasoning, and tool calls. content_chunks: list[str] = [] reasoning_chunks: list[str] = [] + tool_calls_by_index: dict[int, dict[str, Any]] = {} usage = None in_reasoning = False @@ -151,15 +178,35 @@ def _chat( in_reasoning = False content_chunks.append(token) + # Tool call deltas. 
+ for tc_delta in delta.get("tool_calls", []): + idx = tc_delta["index"] + if idx not in tool_calls_by_index: + tool_calls_by_index[idx] = { + "id": tc_delta.get("id", ""), + "type": "function", + "function": {"name": "", "arguments": ""}, + } + tc = tool_calls_by_index[idx] + func_delta = tc_delta.get("function", {}) + if func_delta.get("name"): + tc["function"]["name"] += func_delta["name"] + if func_delta.get("arguments"): + tc["function"]["arguments"] += func_delta["arguments"] + if in_reasoning: log("\n [/thinking]\n") - content = "".join(content_chunks) + result: Message = {"role": "assistant", "content": "".join(content_chunks)} + if tool_calls_by_index: + result["tool_calls"] = [ + tool_calls_by_index[i] for i in sorted(tool_calls_by_index) + ] if usage: pt = usage.get("prompt_tokens", "?") ct = usage.get("completion_tokens", "?") log(f"\n [tokens] prompt={pt} completion={ct}\n") - return content + return result except ( requests.exceptions.ConnectionError, @@ -176,50 +223,28 @@ def _chat( raise RuntimeError("Unreachable") -def parse_commands(text: str) -> list[str]: - """Parse move commands from LLM response. Returns list of command strings.""" - commands = [] - for line in text.strip().splitlines(): - line = line.strip() - if not line: - continue - # Strip markdown code fences if the model wraps output. 
- if line.startswith("```"): - continue - if _MOVE_RE.match(line): - commands.append(line) - return commands - - -def format_prompt( +def format_initial_prompt( tagged_ir: str, - metrics: dict, - round_num: int, - error: str | None = None, + baseline_metrics: dict, target: str = "gfx942", ) -> str: - """Format the user prompt for a scheduling round.""" + """Format the initial user message with IR and baseline metrics.""" parts = [ - f"=== WaveASM Scheduling Round {round_num} ===", f"TARGET: {target} (wave64, 512 vgpr, 106 sgpr, 512 agpr)", "LATENCY: vmem=100, lds=20, mfma_16x16=32, mfma_32x32=64", "", - "--- IR (tagged) ---", + "--- Tagged IR ---", tagged_ir.strip(), "", - "--- Metrics ---", + "--- Baseline Metrics ---", ] - for k, v in metrics.items(): + for k, v in baseline_metrics.items(): parts.append(f" {k}: {v}") - - if error: - parts.extend(["", "--- Error from previous round ---", error]) - parts.extend( [ "", "GOAL: Minimize register pressure and hide memory latency.", - "Respond with move commands, one per line.", + "Use the evaluate_moves tool to try reorderings.", ] ) return "\n".join(parts) @@ -227,14 +252,18 @@ def format_prompt( def run_scheduling_loop( conductor, - max_rounds: int = 5, + max_rounds: int = 10, model: str = DEFAULT_MODEL, temperature: float = 0.7, reasoning_effort: str | None = "medium", log: Callable[[str], None] = _default_log, ) -> dict: """ - Run the iterative LLM scheduling loop. + Run the iterative LLM scheduling loop with tool use. + + The LLM reasons in natural language and calls evaluate_moves to test + scheduling ideas. Conversation history (including reasoning) is preserved + across rounds. Returns dict with keys: metrics, commands, rounds, baseline_metrics. 
""" @@ -247,53 +276,77 @@ def run_scheduling_loop( best_metrics = dict(baseline) best_commands: list[str] = [] - messages: list[dict[str, Any]] = [{"role": "system", "content": SYSTEM_PROMPT}] - error: str | None = None + + messages: list[Message] = [ + {"role": "system", "content": SYSTEM_PROMPT}, + { + "role": "user", + "content": format_initial_prompt( + tagged_ir, baseline, target=conductor.target + ), + }, + ] for round_num in range(1, max_rounds + 1): log(f"\n--- Round {round_num}/{max_rounds} ---\n") - prompt = format_prompt( - tagged_ir, best_metrics, round_num, error=error, target=conductor.target - ) - messages.append({"role": "user", "content": prompt}) - - log(" Querying LLM...\n") response = _chat( messages, model=model, temperature=temperature, reasoning_effort=reasoning_effort, + tools=TOOLS, log=log, ) - messages.append({"role": "assistant", "content": response}) - log(f" Response:\n{response}\n") + messages.append(response) - commands = parse_commands(response) - if not commands: - log(" No valid commands parsed, stopping.\n") - break + content = response.get("content", "") + if content: + log(f" [model] {content}\n") - error = None - try: - metrics = conductor.evaluate(commands) - log(f" metrics: {metrics}\n") + tool_calls = response.get("tool_calls") + if not tool_calls: + log(" No tool call, model is done.\n") + break - if _is_better(metrics, best_metrics): - log(" Improvement found!\n") - best_metrics = metrics - best_commands = commands + # Process each tool call. + for tc in tool_calls: + name = tc["function"]["name"] + try: + args = json.loads(tc["function"]["arguments"]) + except json.JSONDecodeError: + tool_result = {"error": "Malformed JSON arguments."} + log(f" [tool] {name}: malformed args\n") else: - log(" No improvement, reverting.\n") - error = ( - f"Round {round_num} regressed metrics. " - f"Previous best: {best_metrics}, this round: {metrics}. " - "Moves reverted." 
- ) - except RuntimeError as e: - error_msg = str(e) - log(f" Error: {error_msg}\n") - error = error_msg + if name == "evaluate_moves": + moves = args.get("moves", []) + log(f" [tool] evaluate_moves({moves})\n") + try: + metrics = conductor.evaluate(moves) + log(f" [result] {metrics}\n") + if _is_better(metrics, best_metrics): + log(" Improvement!\n") + best_metrics = metrics + best_commands = moves + tool_result = { + "metrics": metrics, + "improved": _is_better(metrics, best_metrics) + or metrics == best_metrics, + } + except RuntimeError as e: + tool_result = {"error": str(e)} + log(f" [error] {e}\n") + else: + tool_result = {"error": f"Unknown tool: {name}"} + log(f" [tool] unknown: {name}\n") + + messages.append( + { + "role": "tool", + "tool_call_id": tc["id"], + "content": json.dumps(tool_result), + } + ) return { "metrics": best_metrics, From 8fbe77e646f6d97a66b60bde4dc67cdc1b7346ba Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 01:37:19 +0100 Subject: [PATCH 19/49] usage Signed-off-by: Ivan Butygin --- conductor/__init__.py | 4 +- conductor/conductor.py | 6 ++ conductor/llm.py | 228 ++++++++++++++++++++++++++++++----------- 3 files changed, 177 insertions(+), 61 deletions(-) diff --git a/conductor/__init__.py b/conductor/__init__.py index 03f2e0f26..48d619d8b 100644 --- a/conductor/__init__.py +++ b/conductor/__init__.py @@ -15,7 +15,7 @@ count_asm_metrics, capture_kernel_mlir, ) -from conductor.llm import run_scheduling_loop +from conductor.llm import run_scheduling_loop, Stats, Counters __all__ = [ "Conductor", @@ -27,4 +27,6 @@ "count_asm_metrics", "capture_kernel_mlir", "run_scheduling_loop", + "Stats", + "Counters", ] diff --git a/conductor/conductor.py b/conductor/conductor.py index e2233e374..84380bb96 100644 --- a/conductor/conductor.py +++ b/conductor/conductor.py @@ -260,6 +260,12 @@ def main(): print(" best:") for k, v in result["metrics"].items(): print(f" {k}: {v}") + usage = result.get("usage") + if usage: + print( + f" 
tokens: {usage.tokens} (in={usage.input_tokens} out={usage.output_tokens})" + ) + print(f" cost: ${usage.cost:.4f}") return if commands: diff --git a/conductor/llm.py b/conductor/llm.py index cc5c9ff8d..2e17a4236 100644 --- a/conductor/llm.py +++ b/conductor/llm.py @@ -3,8 +3,12 @@ import json import os import sys +import threading import time +from collections import defaultdict from collections.abc import Callable +from dataclasses import dataclass +from types import TracebackType from typing import Any import requests @@ -20,6 +24,100 @@ Message = dict[str, Any] +# --- Monotonic API usage counters (thread-safe, per-model). --- + + +@dataclass +class Counters: + """API usage snapshot.""" + + tokens: int = 0 + input_tokens: int = 0 + output_tokens: int = 0 + cost: float = 0.0 + + +_counters_lock = threading.Lock() +_counters: defaultdict[str, Counters] = defaultdict(Counters) + + +def _record_usage(model: str, usage: dict[str, Any]) -> None: + """Accumulate token and cost from an API response.""" + tokens = int(usage.get("total_tokens", 0)) + input_tokens = int(usage.get("prompt_tokens", 0)) + output_tokens = int(usage.get("completion_tokens", 0)) + cost = usage.get("cost") + with _counters_lock: + c = _counters[model] + c.tokens += tokens + c.input_tokens += input_tokens + c.output_tokens += output_tokens + if cost is not None: + c.cost += float(cost) + + +class Stats: + """Context manager that captures API usage over a scope. + + Snapshots the monotonic per-model counters on entry. The ``counters`` + property returns an aggregate delta; ``per_model`` returns per-model deltas. 
+ """ + + def __init__(self) -> None: + self._start: dict[str, Counters] = {} + + def __enter__(self) -> "Stats": + with _counters_lock: + self._start = { + m: Counters(c.tokens, c.input_tokens, c.output_tokens, c.cost) + for m, c in _counters.items() + } + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + pass + + @property + def counters(self) -> Counters: + """Aggregate delta across all models since entering the context.""" + with _counters_lock: + tok = sum(c.tokens for c in _counters.values()) - sum( + c.tokens for c in self._start.values() + ) + inp = sum(c.input_tokens for c in _counters.values()) - sum( + c.input_tokens for c in self._start.values() + ) + out = sum(c.output_tokens for c in _counters.values()) - sum( + c.output_tokens for c in self._start.values() + ) + cst = sum(c.cost for c in _counters.values()) - sum( + c.cost for c in self._start.values() + ) + return Counters(tokens=tok, input_tokens=inp, output_tokens=out, cost=cst) + + @property + def per_model(self) -> defaultdict[str, Counters]: + """Per-model deltas since entering the context.""" + with _counters_lock: + result: defaultdict[str, Counters] = defaultdict(Counters) + for model, current in _counters.items(): + start = self._start.get(model, Counters()) + delta = Counters( + tokens=current.tokens - start.tokens, + input_tokens=current.input_tokens - start.input_tokens, + output_tokens=current.output_tokens - start.output_tokens, + cost=current.cost - start.cost, + ) + if delta.tokens or delta.cost: + result[model] = delta + return result + + def _default_log(msg: str) -> None: """Default logger: print to stderr without trailing newline.""" print(msg, file=sys.stderr, end="", flush=True) @@ -203,6 +301,7 @@ def _chat( tool_calls_by_index[i] for i in sorted(tool_calls_by_index) ] if usage: + _record_usage(model, usage) pt = usage.get("prompt_tokens", "?") ct = 
usage.get("completion_tokens", "?") log(f"\n [tokens] prompt={pt} completion={ct}\n") @@ -265,7 +364,7 @@ def run_scheduling_loop( scheduling ideas. Conversation history (including reasoning) is preserved across rounds. - Returns dict with keys: metrics, commands, rounds, baseline_metrics. + Returns dict with keys: metrics, commands, rounds, baseline_metrics, usage. """ log("Computing baseline metrics...\n") baseline = conductor.baseline() @@ -287,72 +386,81 @@ def run_scheduling_loop( }, ] - for round_num in range(1, max_rounds + 1): - log(f"\n--- Round {round_num}/{max_rounds} ---\n") - - response = _chat( - messages, - model=model, - temperature=temperature, - reasoning_effort=reasoning_effort, - tools=TOOLS, - log=log, - ) - messages.append(response) - - content = response.get("content", "") - if content: - log(f" [model] {content}\n") - - tool_calls = response.get("tool_calls") - if not tool_calls: - log(" No tool call, model is done.\n") - break - - # Process each tool call. - for tc in tool_calls: - name = tc["function"]["name"] - try: - args = json.loads(tc["function"]["arguments"]) - except json.JSONDecodeError: - tool_result = {"error": "Malformed JSON arguments."} - log(f" [tool] {name}: malformed args\n") - else: - if name == "evaluate_moves": - moves = args.get("moves", []) - log(f" [tool] evaluate_moves({moves})\n") - try: - metrics = conductor.evaluate(moves) - log(f" [result] {metrics}\n") - if _is_better(metrics, best_metrics): - log(" Improvement!\n") - best_metrics = metrics - best_commands = moves - tool_result = { - "metrics": metrics, - "improved": _is_better(metrics, best_metrics) - or metrics == best_metrics, - } - except RuntimeError as e: - tool_result = {"error": str(e)} - log(f" [error] {e}\n") - else: - tool_result = {"error": f"Unknown tool: {name}"} - log(f" [tool] unknown: {name}\n") - - messages.append( - { - "role": "tool", - "tool_call_id": tc["id"], - "content": json.dumps(tool_result), - } + with Stats() as stats: + for 
round_num in range(1, max_rounds + 1): + log(f"\n--- Round {round_num}/{max_rounds} ---\n") + + response = _chat( + messages, + model=model, + temperature=temperature, + reasoning_effort=reasoning_effort, + tools=TOOLS, + log=log, ) + messages.append(response) + + content = response.get("content", "") + if content: + log(f" [model] {content}\n") + + tool_calls = response.get("tool_calls") + if not tool_calls: + log(" No tool call, model is done.\n") + break + + # Process each tool call. + for tc in tool_calls: + name = tc["function"]["name"] + try: + args = json.loads(tc["function"]["arguments"]) + except json.JSONDecodeError: + tool_result = {"error": "Malformed JSON arguments."} + log(f" [tool] {name}: malformed args\n") + else: + if name == "evaluate_moves": + moves = args.get("moves", []) + log(f" [tool] evaluate_moves({moves})\n") + try: + metrics = conductor.evaluate(moves) + log(f" [result] {metrics}\n") + if _is_better(metrics, best_metrics): + log(" Improvement!\n") + best_metrics = metrics + best_commands = moves + tool_result = { + "metrics": metrics, + "improved": _is_better(metrics, best_metrics) + or metrics == best_metrics, + } + except RuntimeError as e: + tool_result = {"error": str(e)} + log(f" [error] {e}\n") + else: + tool_result = {"error": f"Unknown tool: {name}"} + log(f" [tool] unknown: {name}\n") + + messages.append( + { + "role": "tool", + "tool_call_id": tc["id"], + "content": json.dumps(tool_result), + } + ) + + usage = stats.counters + log( + f"\n=== Usage ===\n" + f" tokens: {usage.tokens} (in={usage.input_tokens} out={usage.output_tokens})\n" + f" cost: ${usage.cost:.4f}\n" + ) return { "metrics": best_metrics, "commands": best_commands, "rounds": round_num, "baseline_metrics": baseline, + "usage": usage, } From dd4e3b27bb0f99b7b3380030a305bbaf5590b7db Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 01:39:35 +0100 Subject: [PATCH 20/49] refac Signed-off-by: Ivan Butygin --- conductor/llm.py | 234 
+++++++++++++++++++++++++++-------------------- 1 file changed, 135 insertions(+), 99 deletions(-) diff --git a/conductor/llm.py b/conductor/llm.py index 2e17a4236..a73476d8d 100644 --- a/conductor/llm.py +++ b/conductor/llm.py @@ -180,7 +180,102 @@ def _noop_log(_msg: str) -> None: ] -def _chat( +_TRANSIENT_ERRORS = ( + requests.exceptions.ChunkedEncodingError, + requests.exceptions.ConnectionError, + requests.exceptions.ReadTimeout, + requests.exceptions.ConnectTimeout, +) + + +def _stream_request( + payload: dict[str, Any], + on_token: Callable[[str], None] | None, + on_thinking: Callable[[str], None] | None, +) -> Message: + """Execute a streaming chat request and assemble the response message.""" + resp = requests.post( + f"{BASE_URL}/chat/completions", + headers={"Authorization": f"Bearer {API_KEY}"}, + json=payload, + stream=True, + timeout=_REQUEST_TIMEOUT, + ) + if resp.status_code >= 400: + raise requests.HTTPError( + f"{resp.status_code} {resp.reason}: {resp.text}", + response=resp, + ) + + content_chunks: list[str] = [] + reasoning_chunks: list[str] = [] + tool_calls_by_index: dict[int, dict[str, Any]] = {} + usage: dict[str, Any] | None = None + + for line in resp.iter_lines(): + if not line or not line.startswith(b"data: "): + continue + data = line[6:] + if data == b"[DONE]": + break + chunk = json.loads(data) + if "usage" in chunk: + usage = chunk["usage"] + choices = chunk.get("choices") + if not choices: + continue + delta = choices[0].get("delta", {}) + + # Reasoning tokens (two OpenRouter formats). + reasoning_texts: list[str] = [] + for detail in delta.get("reasoning_details", []): + text = detail.get("text", "") + if text: + reasoning_texts.append(text) + rc = delta.get("reasoning_content", "") + if rc: + reasoning_texts.append(rc) + for text in reasoning_texts: + if on_thinking is not None: + on_thinking(text) + reasoning_chunks.append(text) + + # Content tokens. 
+ token = delta.get("content", "") + if token: + if on_token is not None: + on_token(token) + content_chunks.append(token) + + # Tool call deltas. + for tc_delta in delta.get("tool_calls", []): + idx = tc_delta["index"] + if idx not in tool_calls_by_index: + tool_calls_by_index[idx] = { + "id": tc_delta.get("id", ""), + "type": "function", + "function": {"name": "", "arguments": ""}, + } + tc = tool_calls_by_index[idx] + func_delta = tc_delta.get("function", {}) + if func_delta.get("name"): + tc["function"]["name"] += func_delta["name"] + if func_delta.get("arguments"): + tc["function"]["arguments"] += func_delta["arguments"] + + result: Message = {"role": "assistant", "content": "".join(content_chunks)} + if reasoning_chunks: + result["reasoning"] = "".join(reasoning_chunks) + if tool_calls_by_index: + result["tool_calls"] = [ + tool_calls_by_index[i] for i in sorted(tool_calls_by_index) + ] + if usage is not None: + result["usage"] = usage + return result + + +def chat( messages: list[Message], model: str, temperature: float = 0.7, @@ -189,7 +284,11 @@ def _chat( tools: list[dict] | None = None, log: Callable[[str], None] = _default_log, ) -> Message: - """Send a streaming chat completion request. Returns the full response message.""" + """Send a chat completion request. Returns the full response message dict. + + Handles payload construction, retries on transient errors, usage + recording, and streaming log output. + """ if not API_KEY: raise RuntimeError( "OPENROUTER_API_KEY not set. Export it before running the LLM loop." @@ -206,18 +305,40 @@ def _chat( if reasoning_effort is not None: payload["reasoning"] = {"enabled": True, "effort": reasoning_effort} if tools: - payload["tools"] = tools + payload["tools"] = [dict(t) for t in tools] + + # Streaming callbacks that manage [thinking]/[/thinking] delimiters. 
+ in_reasoning = False + + def on_thinking(text: str) -> None: + nonlocal in_reasoning + if not in_reasoning: + log("\n [thinking] ") + in_reasoning = True + log(text) + + def on_token(text: str) -> None: + nonlocal in_reasoning + if in_reasoning: + log("\n [/thinking]\n") + in_reasoning = False for attempt in range(_MAX_RETRIES): try: - resp = requests.post( - f"{BASE_URL}/chat/completions", - headers={"Authorization": f"Bearer {API_KEY}"}, - json=payload, - stream=True, - timeout=_REQUEST_TIMEOUT, - ) - if resp.status_code >= 500: + result = _stream_request(payload, on_token, on_thinking) + if in_reasoning: + log("\n [/thinking]\n") + + if "usage" in result: + _record_usage(model, result["usage"]) + pt = result["usage"].get("prompt_tokens", "?") + ct = result["usage"].get("completion_tokens", "?") + log(f"\n [tokens] prompt={pt} completion={ct}\n") + return result + + except requests.HTTPError as exc: + resp = exc.response + if resp is not None and resp.status_code >= 500: if attempt < _MAX_RETRIES - 1: wait = _RETRY_BACKOFF * (attempt + 1) log( @@ -225,93 +346,8 @@ def _chat( ) time.sleep(wait) continue - if resp.status_code >= 400: - raise RuntimeError( - f"OpenRouter API error {resp.status_code}: {resp.text}" - ) - - # Stream and accumulate content, reasoning, and tool calls. - content_chunks: list[str] = [] - reasoning_chunks: list[str] = [] - tool_calls_by_index: dict[int, dict[str, Any]] = {} - usage = None - in_reasoning = False - - for line in resp.iter_lines(): - if not line or not line.startswith(b"data: "): - continue - data = line[6:] - if data == b"[DONE]": - break - chunk = json.loads(data) - if "usage" in chunk: - usage = chunk["usage"] - choices = chunk.get("choices") - if not choices: - continue - delta = choices[0].get("delta", {}) - - # Reasoning tokens (two OpenRouter formats). 
- for detail in delta.get("reasoning_details", []): - text = detail.get("text", "") - if text: - if not in_reasoning: - log("\n [thinking] ") - in_reasoning = True - log(text) - reasoning_chunks.append(text) - rc = delta.get("reasoning_content", "") - if rc: - if not in_reasoning: - log("\n [thinking] ") - in_reasoning = True - log(rc) - reasoning_chunks.append(rc) - - # Content tokens. - token = delta.get("content", "") - if token: - if in_reasoning: - log("\n [/thinking]\n") - in_reasoning = False - content_chunks.append(token) - - # Tool call deltas. - for tc_delta in delta.get("tool_calls", []): - idx = tc_delta["index"] - if idx not in tool_calls_by_index: - tool_calls_by_index[idx] = { - "id": tc_delta.get("id", ""), - "type": "function", - "function": {"name": "", "arguments": ""}, - } - tc = tool_calls_by_index[idx] - func_delta = tc_delta.get("function", {}) - if func_delta.get("name"): - tc["function"]["name"] += func_delta["name"] - if func_delta.get("arguments"): - tc["function"]["arguments"] += func_delta["arguments"] - - if in_reasoning: - log("\n [/thinking]\n") - - result: Message = {"role": "assistant", "content": "".join(content_chunks)} - if tool_calls_by_index: - result["tool_calls"] = [ - tool_calls_by_index[i] for i in sorted(tool_calls_by_index) - ] - if usage: - _record_usage(model, usage) - pt = usage.get("prompt_tokens", "?") - ct = usage.get("completion_tokens", "?") - log(f"\n [tokens] prompt={pt} completion={ct}\n") - return result - - except ( - requests.exceptions.ConnectionError, - requests.exceptions.ReadTimeout, - requests.exceptions.ChunkedEncodingError, - ): + raise + except _TRANSIENT_ERRORS: if attempt < _MAX_RETRIES - 1: wait = _RETRY_BACKOFF * (attempt + 1) log(f"\n [retry] connection error, waiting {wait:.0f}s...\n") @@ -390,7 +426,7 @@ def run_scheduling_loop( for round_num in range(1, max_rounds + 1): log(f"\n--- Round {round_num}/{max_rounds} ---\n") - response = _chat( + response = chat( messages, model=model, 
temperature=temperature, From a2ff94b2f6e15fedcd638409bed4509a2d4c3fa5 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 01:41:55 +0100 Subject: [PATCH 21/49] _with_retry Signed-off-by: Ivan Butygin --- conductor/llm.py | 72 +++++++++++++++++++++++++++--------------------- 1 file changed, 40 insertions(+), 32 deletions(-) diff --git a/conductor/llm.py b/conductor/llm.py index a73476d8d..b91b1cf02 100644 --- a/conductor/llm.py +++ b/conductor/llm.py @@ -188,6 +188,37 @@ def _noop_log(_msg: str) -> None: ) +def _with_retry( + func: Callable[..., Any], + log: Callable[[str], None], + *args: Any, + **kwargs: Any, +) -> Any: + """Call *func* with retries on transient network errors and 5xx.""" + for attempt in range(_MAX_RETRIES): + try: + return func(*args, **kwargs) + except requests.HTTPError as exc: + resp = exc.response + if resp is not None and resp.status_code >= 500: + if attempt < _MAX_RETRIES - 1: + wait = _RETRY_BACKOFF * (attempt + 1) + log( + f"\n [retry] server {resp.status_code}, waiting {wait:.0f}s...\n" + ) + time.sleep(wait) + continue + raise + except _TRANSIENT_ERRORS: + if attempt < _MAX_RETRIES - 1: + wait = _RETRY_BACKOFF * (attempt + 1) + log(f"\n [retry] connection error, waiting {wait:.0f}s...\n") + time.sleep(wait) + else: + raise + raise RuntimeError("Unreachable") + + def _stream_request( payload: dict[str, Any], on_token: Callable[[str], None] | None, @@ -323,39 +354,16 @@ def on_token(text: str) -> None: log("\n [/thinking]\n") in_reasoning = False - for attempt in range(_MAX_RETRIES): - try: - result = _stream_request(payload, on_token, on_thinking) - if in_reasoning: - log("\n [/thinking]\n") - - if "usage" in result: - _record_usage(model, result["usage"]) - pt = result["usage"].get("prompt_tokens", "?") - ct = result["usage"].get("completion_tokens", "?") - log(f"\n [tokens] prompt={pt} completion={ct}\n") - return result + result: Message = _with_retry(_stream_request, log, payload, on_token, on_thinking) + if 
in_reasoning: + log("\n [/thinking]\n") - except requests.HTTPError as exc: - resp = exc.response - if resp is not None and resp.status_code >= 500: - if attempt < _MAX_RETRIES - 1: - wait = _RETRY_BACKOFF * (attempt + 1) - log( - f"\n [retry] server {resp.status_code}, waiting {wait:.0f}s...\n" - ) - time.sleep(wait) - continue - raise - except _TRANSIENT_ERRORS: - if attempt < _MAX_RETRIES - 1: - wait = _RETRY_BACKOFF * (attempt + 1) - log(f"\n [retry] connection error, waiting {wait:.0f}s...\n") - time.sleep(wait) - else: - raise - - raise RuntimeError("Unreachable") + if "usage" in result: + _record_usage(model, result["usage"]) + pt = result["usage"].get("prompt_tokens", "?") + ct = result["usage"].get("completion_tokens", "?") + log(f"\n [tokens] prompt={pt} completion={ct}\n") + return result def format_initial_prompt( From 3efee7d5b25c0e89fd0494ed52bd7cc3712920d5 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 01:47:27 +0100 Subject: [PATCH 22/49] done tool Signed-off-by: Ivan Butygin --- conductor/llm.py | 54 +++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 49 insertions(+), 5 deletions(-) diff --git a/conductor/llm.py index b91b1cf02..bb1ce6cbe 100644 --- a/conductor/llm.py +++ b/conductor/llm.py @@ -148,7 +148,10 @@ def _noop_log(_msg: str) -> None: - Pinned ops (s_endpgm, s_barrier, condition) cannot be moved. Work incrementally: try 1-3 moves per tool call, read the resulting metrics, \ -then decide your next moves. You can call the tool multiple times.\ +then decide your next moves. You can call the tool multiple times. + +When you are satisfied with the schedule or have no more ideas, call the `done` \ +tool to finish.\ """ TOOLS = [ @@ -176,7 +179,27 @@ def _noop_log(_msg: str) -> None: "required": ["moves"], }, }, - } + }, + { + "type": "function", + "function": { + "name": "done", + "description": ( + "Call this when you are finished scheduling. " + "Provide a short summary of what you tried." 
+ ), + "parameters": { + "type": "object", + "properties": { + "summary": { + "type": "string", + "description": "Brief summary of scheduling attempts.", + } + }, + "required": ["summary"], + }, + }, + }, ] @@ -450,10 +473,23 @@ def run_scheduling_loop( tool_calls = response.get("tool_calls") if not tool_calls: - log(" No tool call, model is done.\n") - break + # Model should always call a tool (evaluate_moves or done). + # If it didn't, nudge it to use the proper tool interface. + log(" [retry] No tool call, nudging model...\n") + messages.append( + { + "role": "user", + "content": ( + "You must use the evaluate_moves tool to test " + "scheduling ideas, or call done when finished. " + "Do not write tool calls as text." + ), + } + ) + continue # Process each tool call. + done = False for tc in tool_calls: name = tc["function"]["name"] try: @@ -462,7 +498,12 @@ def run_scheduling_loop( tool_result = {"error": "Malformed JSON arguments."} log(f" [tool] {name}: malformed args\n") else: - if name == "evaluate_moves": + if name == "done": + summary = args.get("summary", "") + log(f" [done] {summary}\n") + tool_result = {"status": "ok"} + done = True + elif name == "evaluate_moves": moves = args.get("moves", []) log(f" [tool] evaluate_moves({moves})\n") try: @@ -492,6 +533,9 @@ def run_scheduling_loop( } ) + if done: + break + usage = stats.counters log( f"\n=== Usage ===\n" From 7759d6b038f193d3fb9c8f594e82675f5f0b626f Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 01:51:31 +0100 Subject: [PATCH 23/49] abstract tools Signed-off-by: Ivan Butygin --- conductor/__init__.py | 4 ++ conductor/llm.py | 159 +++++++++++++++++++----------------------- conductor/tools.py | 87 +++++++++++++++++++++++ 3 files changed, 163 insertions(+), 87 deletions(-) create mode 100644 conductor/tools.py diff --git a/conductor/__init__.py b/conductor/__init__.py index 48d619d8b..7206ca95a 100644 --- a/conductor/__init__.py +++ b/conductor/__init__.py @@ -16,6 +16,7 @@ 
capture_kernel_mlir, ) from conductor.llm import run_scheduling_loop, Stats, Counters +from conductor.tools import Param, ToolDef, ToolRegistry __all__ = [ "Conductor", @@ -29,4 +30,7 @@ "run_scheduling_loop", "Stats", "Counters", + "Param", + "ToolDef", + "ToolRegistry", ] diff --git a/conductor/llm.py b/conductor/llm.py index bb1ce6cbe..0c2f2f5bb 100644 --- a/conductor/llm.py +++ b/conductor/llm.py @@ -13,6 +13,8 @@ import requests +from conductor.tools import Param, ToolRegistry + API_KEY: str = os.environ.get("OPENROUTER_API_KEY", "") BASE_URL: str = "https://openrouter.ai/api/v1" DEFAULT_MODEL: str = "deepseek/deepseek-v3.2" @@ -154,54 +156,6 @@ def _noop_log(_msg: str) -> None: tool to finish.\ """ -TOOLS = [ - { - "type": "function", - "function": { - "name": "evaluate_moves", - "description": ( - "Apply a list of move/swap commands to the tagged IR, " - "compile through the post-scheduling pipeline, and return " - "assembly metrics. Commands are applied in order." - ), - "parameters": { - "type": "object", - "properties": { - "moves": { - "type": "array", - "items": {"type": "string"}, - "description": ( - "List of move commands, e.g. " - '["move tag_A after tag_B", "swap tag_C tag_D"].' - ), - } - }, - "required": ["moves"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "done", - "description": ( - "Call this when you are finished scheduling. " - "Provide a short summary of what you tried." - ), - "parameters": { - "type": "object", - "properties": { - "summary": { - "type": "string", - "description": "Brief summary of scheduling attempts.", - } - }, - "required": ["summary"], - }, - }, - }, -] - _TRANSIENT_ERRORS = ( requests.exceptions.ChunkedEncodingError, @@ -442,6 +396,69 @@ def run_scheduling_loop( best_metrics = dict(baseline) best_commands: list[str] = [] + finished = False + + # Build tool registry with closures over loop state. 
+ registry = ToolRegistry() + + def _evaluate_moves(moves: list[str]) -> str: + nonlocal best_metrics, best_commands + log(f" [tool] evaluate_moves({moves})\n") + try: + metrics = conductor.evaluate(moves) + except RuntimeError as e: + log(f" [error] {e}\n") + return json.dumps({"error": str(e)}) + log(f" [result] {metrics}\n") + if _is_better(metrics, best_metrics): + log(" Improvement!\n") + best_metrics = metrics + best_commands = moves + return json.dumps( + { + "metrics": metrics, + "improved": _is_better(metrics, best_metrics) + or metrics == best_metrics, + } + ) + + def _done(summary: str) -> str: + nonlocal finished + log(f" [done] {summary}\n") + finished = True + return json.dumps({"status": "ok"}) + + registry.add( + name="evaluate_moves", + description=( + "Apply a list of move/swap commands to the tagged IR, " + "compile through the post-scheduling pipeline, and return " + "assembly metrics. Commands are applied in order." + ), + params=[ + Param( + name="moves", + description=( + "List of move commands, e.g. " + '["move tag_A after tag_B", "swap tag_C tag_D"].' + ), + type="array", + items={"type": "string"}, + ), + ], + func=_evaluate_moves, + ) + registry.add( + name="done", + description=( + "Call this when you are finished scheduling. " + "Provide a short summary of what you tried." + ), + params=[ + Param(name="summary", description="Brief summary of scheduling attempts."), + ], + func=_done, + ) messages: list[Message] = [ {"role": "system", "content": SYSTEM_PROMPT}, @@ -462,7 +479,7 @@ def run_scheduling_loop( model=model, temperature=temperature, reasoning_effort=reasoning_effort, - tools=TOOLS, + tools=registry.definitions(), log=log, ) messages.append(response) @@ -488,52 +505,20 @@ def run_scheduling_loop( ) continue - # Process each tool call. 
- done = False for tc in tool_calls: - name = tc["function"]["name"] - try: - args = json.loads(tc["function"]["arguments"]) - except json.JSONDecodeError: - tool_result = {"error": "Malformed JSON arguments."} - log(f" [tool] {name}: malformed args\n") - else: - if name == "done": - summary = args.get("summary", "") - log(f" [done] {summary}\n") - tool_result = {"status": "ok"} - done = True - elif name == "evaluate_moves": - moves = args.get("moves", []) - log(f" [tool] evaluate_moves({moves})\n") - try: - metrics = conductor.evaluate(moves) - log(f" [result] {metrics}\n") - if _is_better(metrics, best_metrics): - log(" Improvement!\n") - best_metrics = metrics - best_commands = moves - tool_result = { - "metrics": metrics, - "improved": _is_better(metrics, best_metrics) - or metrics == best_metrics, - } - except RuntimeError as e: - tool_result = {"error": str(e)} - log(f" [error] {e}\n") - else: - tool_result = {"error": f"Unknown tool: {name}"} - log(f" [tool] unknown: {name}\n") - + result = registry.execute( + tc["function"]["name"], + tc["function"]["arguments"], + ) messages.append( { "role": "tool", "tool_call_id": tc["id"], - "content": json.dumps(tool_result), + "content": result, } ) - if done: + if finished: break usage = stats.counters diff --git a/conductor/tools.py b/conductor/tools.py new file mode 100644 index 000000000..b0f87b8c0 --- /dev/null +++ b/conductor/tools.py @@ -0,0 +1,87 @@ +"""Tool registry for LLM function calling.""" + +import json +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class Param: + """Single tool parameter specification.""" + + name: str + description: str + type: str = "string" + required: bool = True + items: dict[str, str] | None = None + + +@dataclass +class ToolDef: + """Tool definition with schema and handler.""" + + name: str + description: str + params: list[Param] = field(default_factory=list) + func: Callable[..., str] | None = None + + 
def to_api(self) -> dict[str, Any]: + """Convert to OpenAI function-calling format.""" + properties: dict[str, Any] = {} + for p in self.params: + prop: dict[str, Any] = {"type": p.type, "description": p.description} + if p.items is not None: + prop["items"] = p.items + properties[p.name] = prop + required = [p.name for p in self.params if p.required] + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": { + "type": "object", + "properties": properties, + "required": required, + }, + }, + } + + +class ToolRegistry: + """Central registry for tool definitions and dispatch.""" + + def __init__(self) -> None: + self._tools: dict[str, ToolDef] = {} + + def add( + self, + name: str, + description: str, + params: list[Param], + func: Callable[..., str], + ) -> None: + """Register a tool with its definition and handler.""" + self._tools[name] = ToolDef( + name=name, + description=description, + params=params, + func=func, + ) + + def definitions(self) -> list[dict[str, Any]]: + """Return all tool definitions in OpenAI API format.""" + return [tool.to_api() for tool in self._tools.values()] + + def execute(self, name: str, arguments: str) -> str: + """Execute a tool by name with JSON-encoded arguments.""" + name = name.strip() + tool = self._tools.get(name) + if tool is None or tool.func is None: + return json.dumps({"error": f"Unknown tool: {name}"}) + try: + kwargs: dict[str, Any] = json.loads(arguments) + return tool.func(**kwargs) + except Exception as e: + return json.dumps({"error": str(e)}) From f49d4aa36e741b510858b0f9ab7c7417cfb60de7 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 02:06:42 +0100 Subject: [PATCH 24/49] print the updated IR Signed-off-by: Ivan Butygin --- conductor/conductor.py | 7 ++++++- conductor/llm.py | 10 ++++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/conductor/conductor.py b/conductor/conductor.py index 84380bb96..219098522 100644 
--- a/conductor/conductor.py +++ b/conductor/conductor.py @@ -144,10 +144,15 @@ def evaluate(self, commands: list) -> dict: This is the main entry point for a search algorithm. """ + _, metrics = self.evaluate_with_ir(commands) + return metrics + + def evaluate_with_ir(self, commands: list) -> tuple[str, dict]: + """Like evaluate, but also returns the reordered tagged IR.""" tagged = self.tag() reordered = self.apply_moves(tagged, commands) asm = self.compile_to_asm(reordered) - return self.get_metrics(asm) + return reordered, self.get_metrics(asm) def baseline(self) -> dict: """Evaluate with no moves (identity schedule). Caches result.""" diff --git a/conductor/llm.py b/conductor/llm.py index 0c2f2f5bb..ca90e60ff 100644 --- a/conductor/llm.py +++ b/conductor/llm.py @@ -405,20 +405,22 @@ def _evaluate_moves(moves: list[str]) -> str: nonlocal best_metrics, best_commands log(f" [tool] evaluate_moves({moves})\n") try: - metrics = conductor.evaluate(moves) + reordered_ir, metrics = conductor.evaluate_with_ir(moves) except RuntimeError as e: log(f" [error] {e}\n") return json.dumps({"error": str(e)}) log(f" [result] {metrics}\n") - if _is_better(metrics, best_metrics): + log(f" --- Updated IR ---\n{reordered_ir.strip()}\n --- End IR ---\n") + improved = _is_better(metrics, best_metrics) + if improved: log(" Improvement!\n") best_metrics = metrics best_commands = moves return json.dumps( { "metrics": metrics, - "improved": _is_better(metrics, best_metrics) - or metrics == best_metrics, + "improved": improved or metrics == best_metrics, + "updated_ir": reordered_ir.strip(), } ) From 6c6fa57b0305298793a605cc2bd6fafef8f5b798 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 02:07:50 +0100 Subject: [PATCH 25/49] print ir Signed-off-by: Ivan Butygin --- conductor/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conductor/llm.py b/conductor/llm.py index ca90e60ff..d3888ad33 100644 --- a/conductor/llm.py +++ b/conductor/llm.py @@ -392,7 
+392,7 @@ def run_scheduling_loop( log(f" baseline: {baseline}\n") tagged_ir = conductor.tag() - log(f" tagged IR: {len(tagged_ir)} chars\n") + log(f" --- Tagged IR ---\n{tagged_ir.strip()}\n --- End IR ---\n") best_metrics = dict(baseline) best_commands: list[str] = [] From 6871a7f51c4d69be6709e12d66b27f2dd034f570 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 02:11:02 +0100 Subject: [PATCH 26/49] prompt tweak Signed-off-by: Ivan Butygin --- conductor/llm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/conductor/llm.py b/conductor/llm.py index d3888ad33..49856ccd6 100644 --- a/conductor/llm.py +++ b/conductor/llm.py @@ -146,7 +146,8 @@ def _noop_log(_msg: str) -> None: The tool will apply the moves, compile, and return metrics. Constraints: -- Moves that break SSA dominance will be rejected. +- All moves must stay within the same basic block. Never move across blocks. +- Moves must preserve SSA dominance: a value must be defined before all its uses. - Pinned ops (s_endpgm, s_barrier, condition) cannot be moved. 
Work incrementally: try 1-3 moves per tool call, read the resulting metrics, \ From 716966eaecf5cf57f138287c56f09bdf44735e2d Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 09:53:58 +0100 Subject: [PATCH 27/49] error logging Signed-off-by: Ivan Butygin --- conductor/llm.py | 10 +++++++++- .../tools/waveasm-conductor/waveasm-conductor.cpp | 6 ++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/conductor/llm.py b/conductor/llm.py index 49856ccd6..09365022b 100644 --- a/conductor/llm.py +++ b/conductor/llm.py @@ -406,10 +406,18 @@ def _evaluate_moves(moves: list[str]) -> str: nonlocal best_metrics, best_commands log(f" [tool] evaluate_moves({moves})\n") try: - reordered_ir, metrics = conductor.evaluate_with_ir(moves) + tagged = conductor.tag() + reordered_ir = conductor.apply_moves(tagged, moves) except RuntimeError as e: log(f" [error] {e}\n") return json.dumps({"error": str(e)}) + try: + asm = conductor.compile_to_asm(reordered_ir) + metrics = conductor.get_metrics(asm) + except RuntimeError as e: + log(f" [error] {e}\n") + log(f" --- Faulty IR ---\n{reordered_ir.strip()}\n --- End IR ---\n") + return json.dumps({"error": str(e)}) log(f" [result] {metrics}\n") log(f" --- Updated IR ---\n{reordered_ir.strip()}\n --- End IR ---\n") improved = _is_better(metrics, best_metrics) diff --git a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp index 1c6411ed3..2cb07f511 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp +++ b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp @@ -107,12 +107,18 @@ int main(int argc, char **argv) { if (!result.success) { llvm::errs() << "conductor: command " << result.failedCommand << ": " << result.error << "\n"; + llvm::errs() << "--- IR after partial moves ---\n"; + module->print(llvm::errs()); + llvm::errs() << "--- 
end IR ---\n"; return 1; } // Verify the module after moves (catches broken dominance, etc.). if (failed(mlir::verify(*module))) { llvm::errs() << "conductor: verification failed after applying moves\n"; + llvm::errs() << "--- IR at verification failure ---\n"; + module->print(llvm::errs()); + llvm::errs() << "--- end IR ---\n"; return 1; } From 7c2babd6772579241206b640465e7bce9dab5e52 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 10:09:50 +0100 Subject: [PATCH 28/49] scheduling --- conductor/extract_ir.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/conductor/extract_ir.py b/conductor/extract_ir.py index 59f9d60a1..975f9d3fa 100644 --- a/conductor/extract_ir.py +++ b/conductor/extract_ir.py @@ -83,6 +83,8 @@ def capture_kernel_mlir() -> tuple: SHARED_ADDRESS_SPACE, ) from wave_lang.kernel.wave.compile import WaveCompileOptions + from wave_lang.kernel.wave.scheduling.schedule_enums import SchedulingType + from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params from wave_lang.kernel.wave.utils.run_utils import set_default_run_config from wave_lang.kernel._support.indexing import IndexingContext from wave_lang.kernel.wave.compile import _trace_launchable_and_get_kernel_signature @@ -136,18 +138,23 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: m, n, k = 256, 256, 256 block_k = 16 + subs = { + M: m, + N: n, + K: k, + BLOCK_M: block_m, + BLOCK_N: block_n, + BLOCK_K: block_k, + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, + } + subs.update(get_default_scheduling_params()) + options = WaveCompileOptions( - subs={ - M: m, - N: n, - K: k, - BLOCK_M: block_m, - BLOCK_N: block_n, - BLOCK_K: block_k, - ADDRESS_SPACE: SHARED_ADDRESS_SPACE, - ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, - }, + subs=subs, canonicalize=True, + schedule=SchedulingType.PREFETCH, + use_scheduling_barriers=True, backend="asm", 
wave_runtime=True, compile_to_mlir=False, From b38b793d977229cb7e2ffb3601250948eab6a5e6 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 10:14:43 +0100 Subject: [PATCH 29/49] mxfp kernel --- conductor/__init__.py | 2 ++ conductor/conductor.py | 16 ++++++++++-- conductor/extract_ir.py | 56 +++++++++++++++++++++++++++++++++++++++-- 3 files changed, 70 insertions(+), 4 deletions(-) diff --git a/conductor/__init__.py b/conductor/__init__.py index 7206ca95a..5510d001f 100644 --- a/conductor/__init__.py +++ b/conductor/__init__.py @@ -14,6 +14,7 @@ run_full_pipeline, count_asm_metrics, capture_kernel_mlir, + capture_mxfp4_kernel_mlir, ) from conductor.llm import run_scheduling_loop, Stats, Counters from conductor.tools import Param, ToolDef, ToolRegistry @@ -27,6 +28,7 @@ "run_full_pipeline", "count_asm_metrics", "capture_kernel_mlir", + "capture_mxfp4_kernel_mlir", "run_scheduling_loop", "Stats", "Counters", diff --git a/conductor/conductor.py b/conductor/conductor.py index 219098522..842d32a33 100644 --- a/conductor/conductor.py +++ b/conductor/conductor.py @@ -217,6 +217,13 @@ def main(): default="high", help="Reasoning effort for models that support it (default: high).", ) + parser.add_argument( + "--kernel", + type=str, + default="gemm", + choices=["gemm", "mxfp4"], + help="Kernel to capture (default: gemm).", + ) args = parser.parse_args() # Collect commands from both sources. 
@@ -230,8 +237,13 @@ def main(): if line.strip() and not line.strip().startswith("#") ) - print("Capturing kernel MLIR...", file=sys.stderr) - mlir_text, wg_size = capture_kernel_mlir() + from conductor.extract_ir import capture_mxfp4_kernel_mlir + + capture_fn = ( + capture_mxfp4_kernel_mlir if args.kernel == "mxfp4" else capture_kernel_mlir + ) + print(f"Capturing {args.kernel} kernel MLIR...", file=sys.stderr) + mlir_text, wg_size = capture_fn() print(f" workgroup_size: {wg_size}", file=sys.stderr) print("Running pre-scheduling pipeline...", file=sys.stderr) diff --git a/conductor/extract_ir.py b/conductor/extract_ir.py index 975f9d3fa..98aa9d49d 100644 --- a/conductor/extract_ir.py +++ b/conductor/extract_ir.py @@ -70,6 +70,48 @@ def get_target() -> str: return os.environ.get("WAVE_DEFAULT_ARCH", "gfx942") +def capture_mxfp4_kernel_mlir() -> tuple: + """Capture MLIR from a double-buffered 4-wave MXFP4 GEMM kernel. + + Returns (mlir_text, workgroup_size). + """ + from wave_lang.kernel.wave.templates import get_tagged_mxfp4_gemm_preshuffle_b + from wave_lang.kernel.wave.schedules import get_mxfp4_asymmetric_schedule + from wave_lang.kernel.wave.utils.run_utils import set_default_run_config + + # Same config as test_dbuf_4wave_mxfp4_gemm_cpp_backend. + gemm, options = get_tagged_mxfp4_gemm_preshuffle_b( + shape=(1024, 1024, 8192), + block_shape=(256, 256, 256), + wave_shape=(1, 4), + ) + schedule = get_mxfp4_asymmetric_schedule() + + options.backend = "asm" + options.wave_runtime = True + options.compile_to_mlir = False + options = set_default_run_config(options) + + # Reuse the test helper that handles opsel + canonicalize. 
+ sys.path.insert( + 0, + str( + wave_root + / "wave_lang" + / "kernel" + / "wave" + / "asm" + / "wave_asm" + / "test" + / "e2e" + ), + ) + from waveasm_e2e import capture_wave_kernel_info + + info = capture_wave_kernel_info(options, gemm, schedule=schedule) + return info.mlir_text, info.workgroup_size + + def capture_kernel_mlir() -> tuple: """ Capture MLIR from a multi-wave GEMM kernel. @@ -338,10 +380,20 @@ def main(): action="store_true", help="Also run full pipeline and print baseline metrics.", ) + parser.add_argument( + "--kernel", + type=str, + default="gemm", + choices=["gemm", "mxfp4"], + help="Kernel to capture (default: gemm).", + ) args = parser.parse_args() - print("Capturing kernel MLIR...", file=sys.stderr) - mlir_text, wg_size = capture_kernel_mlir() + capture_fn = ( + capture_mxfp4_kernel_mlir if args.kernel == "mxfp4" else capture_kernel_mlir + ) + print(f"Capturing {args.kernel} kernel MLIR...", file=sys.stderr) + mlir_text, wg_size = capture_fn() print(f" workgroup_size: {wg_size}", file=sys.stderr) print(f" input MLIR: {len(mlir_text)} chars", file=sys.stderr) From fa2f3c6c3c017e4501e7bba627a5b216f35d0bf3 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 10:17:02 +0100 Subject: [PATCH 30/49] fix rocm check Signed-off-by: Ivan Butygin --- .../wave/asm/wave_asm/test/e2e/waveasm_e2e.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/e2e/waveasm_e2e.py b/wave_lang/kernel/wave/asm/wave_asm/test/e2e/waveasm_e2e.py index 035bf403b..92071487e 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/test/e2e/waveasm_e2e.py +++ b/wave_lang/kernel/wave/asm/wave_asm/test/e2e/waveasm_e2e.py @@ -125,13 +125,6 @@ def get_waveasm_translate_path() -> Path: def get_amdclang_path() -> str: """Get path to amdclang++ for assembly compilation.""" - # Check ROCM_PATH first - rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm") - amdclang = os.path.join(rocm_path, "bin", "amdclang++") - 
- if os.path.exists(amdclang): - return amdclang - # Try to find it in PATH try: result = subprocess.run(["which", "amdclang++"], capture_output=True, text=True) @@ -140,6 +133,13 @@ def get_amdclang_path() -> str: except Exception: pass + # Do not check ROCM_PATH first, it breaks with the Rock. + rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm") + amdclang = os.path.join(rocm_path, "bin", "amdclang++") + + if os.path.exists(amdclang): + return amdclang + raise FileNotFoundError( "amdclang++ not found. Ensure ROCm is installed and in PATH." ) From ca3da4abb6562c104e5b7871a1463368dad7827e Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 11:02:03 +0100 Subject: [PATCH 31/49] instructions Signed-off-by: Ivan Butygin --- conductor/SCHEDULING_GUIDE.md | 182 ++++++++++++++++++++++++++++++++++ conductor/llm.py | 25 +++-- 2 files changed, 201 insertions(+), 6 deletions(-) create mode 100644 conductor/SCHEDULING_GUIDE.md diff --git a/conductor/SCHEDULING_GUIDE.md b/conductor/SCHEDULING_GUIDE.md new file mode 100644 index 000000000..38c3f9d03 --- /dev/null +++ b/conductor/SCHEDULING_GUIDE.md @@ -0,0 +1,182 @@ +# CDNA4 Instruction Scheduling Guide + +Reference for the Conductor LLM scheduling loop. Based on the AMD Instinct +CDNA4 ISA Reference Guide (gfx950). + +## Architecture Overview + +- Wave size: 64 threads. +- Register files: 256 arch VGPRs (V0-V255) + 256 AccVGPRs/AGPRs (A0-A255) = 512 total per wave. +- SGPRs: 16-102 per wave (VCC occupies SGPR 106-107). +- LDS: 160 KiB per CU, 64 banks x 32-bit, 1280-byte allocation granularity. +- Allocation granularity: VGPRs in groups of 8, SGPRs in groups of 16. +- Issue model: one instruction per cycle per wavefront (latency hiding via interleaving wavefronts). 
+ +## Instruction Latencies + +| Instruction class | Latency (cycles) | Counter | +|---|---|---| +| Global load (buffer_load, global_load) | ~100 | vmcnt | +| LDS read (ds_read) | ~20 | lgkmcnt | +| LDS write (ds_write) | ~20 | lgkmcnt | +| MFMA F16/BF16 16x16x16 | 16 | — | +| MFMA F16/BF16 32x32x8 | 32 | — | +| MFMA F32 16x16x4 | 32 | — | +| MFMA F32 32x32x2 | 64 | — | +| MFMA F8/F4 16x16x128 | 16 (FP4/FP6) or 32 (FP8) | — | +| MFMA F8/F4 32x32x64 | 32 (FP4/FP6) or 64 (FP8) | — | +| scaled_mfma (MXFP4) | Same as F8/F4 above | — | +| MFMA F64 16x16x4 | 64 | — | +| VALU (non-transcendental) | 1 | — | +| VALU transcendental (exp, log, rcp, rsq, sqrt, sin, cos) | 2 | — | +| SALU | 1 | — | +| Scalar memory read | variable | lgkmcnt | + +## Waitcnt Counters + +| Counter | Bits | Max outstanding | Tracked operations | +|---|---|---|---| +| vmcnt | 6 | 63 | Global/buffer loads and stores | +| lgkmcnt | 4 | 15 | LDS ops, scalar memory, sendmsg | + +- VMEM reads/writes return in issue order. +- Scalar memory reads can return **out of order** — only `lgkmcnt(0)` is safe for SMEM. +- FLAT instructions increment both vmcnt and lgkmcnt — only `s_waitcnt 0` is safe after FLAT. +- `s_endpgm` implicitly executes `s_waitcnt 0`. + +## MFMA Dependency Rules + +### Accumulator chains (SrcC = previous vDst, same opcode, same register range) + +**Zero software NOPs required.** The hardware provides 2 implicit wait cycles. +This is the intended use pattern for matrix accumulation loops — chain MFMAs +back-to-back on the same accumulator with no stall. 
+ +### Cross-dependencies (reading MFMA output as SrcA/SrcB or in VALU) + +These require the full output latency to elapse: + +| Scenario | Wait (NOPs) | +|---|---| +| XDL write → SrcA/B of any MFMA | 5 / 8 / 12 / 20 | +| XDL write → VALU read/write (RAW+WAW) | 5 / 8 / 12 / 20 | +| XDL write → VMEM/LDS/FLAT overlap | 5 / 8 / 12 / 20 | +| SGEMM write → SrcA/B of any MFMA | 4 / 6 / 10 / 18 | +| DGEMM 16x16x4 write → SrcA/B or VALU | 19 | + +The multiple values correspond to different output register overlap depths. + +### Cross-type forwarding + +| Scenario | Wait (NOPs) | +|---|---| +| SGEMM write → XDL SrcC (overlapping) | **0** (XDL reads SrcC 2x faster) | +| XDL write → SGEMM SrcC (overlapping) | 3 | +| v_cmpx (writes EXEC) → any V_MFMA | **4** (no exec forwarding to matrix core) | + +## Software Hazards (Table 11) + +The hardware does **not** detect these. The compiler must insert s_nop or +independent instructions. + +| First instruction | Second instruction | NOPs required | +|---|---|---| +| VALU writes SGPR | VMEM reads that SGPR | **5** | +| VALU sets VCC or EXEC | VALU reads EXECZ/VCCZ as data | **5** | +| VALU writes EXEC | VALU DPP op | **5** | +| VALU writes SGPR/VCC | v_readlane/v_writelane lane select | **4** | +| VALU writes VCC | v_div_fmas | **4** | +| v_cmpx writes EXEC | v_readlane/v_readfirstlane/v_writelane | **4** | +| VALU writes SGPR/VCC | VALU reads SGPR as constant | **2** | +| v_cmpx writes EXEC | VALU reads EXEC as constant | **2** | +| S_SETREG MODE.vskip | Any vector op | **2** | +| Trans op (exp, log, rcp, ...) | Non-trans VALU consuming result | **1** | +| VALU writes VGPR | VALU DPP reads that VGPR | **1** | +| SALU writes M0 | LDS add-TID / buffer_store_LDS | **1** | +| Mixed VCC alias access | VALU reads VCC as constant | **1** | +| FLAT/BUFFER_STORE x3/x4 | Write VGPRs holding writedata | **1** | + +Note: `s_nop N` inserts N+1 idle cycles. So `s_nop 4` = 5 NOPs. 
+ +## Occupancy and Register Pressure + +Waves per SIMD as a function of VGPR usage (512-slot pool, 8-register granularity): + +| VGPRs used | Waves/SIMD | Waves/CU (4 SIMDs) | +|---|---|---| +| ≤64 | 8 | 32 | +| 65-72 | 7 | 28 | +| 73-80 | 6 | 24 | +| 81-96 | 5 | 20 | +| 97-128 | 4 | 16 | +| 129-168 | 3 | 12 | +| 169-256 | 2 | 8 | + +AGPRs use an independent pool with the same formula. Final occupancy is +`min(vgpr_waves, agpr_waves, sgpr_waves, lds_waves)`. + +**Key breakpoints:** Dropping below 128, 96, 80, or 64 VGPRs each adds one +wave per SIMD. For MFMA-dominant kernels, 2-4 waves/SIMD is often optimal. + +## Scheduling Strategy + +### 1. Issue global loads early + +Global loads have ~100 cycle latency. Move them as far above their consumers +as possible. Every MFMA (16-64 cycles) or LDS op placed between the load and +its `s_waitcnt vmcnt` hides latency for free. + +### 2. Interleave LDS reads with MFMA + +LDS reads take ~20 cycles. A single MFMA 16x16x16 (16 cycles) nearly covers +one LDS read. Interleaving `ds_read → mfma → ds_read → mfma` hides LDS +latency much better than `ds_read, ds_read, ..., mfma, mfma, ...`. + +### 3. Fill MFMA pipeline bubbles + +Between two MFMAs with an accumulator dependency (B.SrcC = A.vDst), the +hardware stalls until A completes. Fill this gap with: +- Independent LDS reads/writes (for the next iteration). +- Independent VALU (address computation, data formatting). +- Independent global loads (prefetch for next tile). + +### 4. Minimize live ranges + +Moving a producer closer to its consumer (or deferring a def until just +before use) reduces the number of simultaneously live VGPRs, potentially +lowering peak register pressure and enabling higher occupancy. + +### 5. 
Double-buffering pattern + +The ideal software-pipelined loop body: +``` +barrier +ds_read current tile from LDS +global_load next tile (prefetch) +mfma on current tile (hides global_load latency) +s_waitcnt vmcnt(0) +barrier +ds_write next tile to LDS +``` + +### 6. Metric priorities (lexicographic) + +1. **peak_vgpr** — lower = higher occupancy. +2. **s_waitcnt** — fewer waits = better latency hiding. +3. **s_nop** — fewer NOPs = fewer pipeline hazards. +4. **total_instructions** — fewer = less instruction cache pressure. + +## LDS Details + +- 160 KiB per CU, 64 banks, each 32-bit wide. +- Best case latency (no bank conflicts): 2 cycles. +- Worst case (64-way bank conflict): 64 cycles. +- A wavefront (64 threads) is dispatched over 4 sub-cycles (16 threads each). +- DS_READ2/DS_WRITE2 (64-bit extended) double per-op bandwidth. + +## Barrier Semantics + +- `s_barrier` synchronizes all wavefronts in a workgroup. +- It does NOT implicitly wait for memory — issue `s_waitcnt` before `s_barrier` + if protecting memory operations. +- Treat barriers as hard scheduling boundaries. Never move ops across barriers. diff --git a/conductor/llm.py b/conductor/llm.py index 09365022b..8f4f40f85 100644 --- a/conductor/llm.py +++ b/conductor/llm.py @@ -130,12 +130,25 @@ def _noop_log(_msg: str) -> None: SYSTEM_PROMPT = """\ -You are an expert GPU instruction scheduler for AMD CDNA/RDNA architectures. +You are an expert GPU instruction scheduler for AMD CDNA architectures. You will receive WaveASM MLIR IR with tagged instructions (loc("tag_name")). Your job is to reorder instructions to hide memory latency and reduce register pressure. -Key latencies: global loads ~100 cycles, LDS loads ~20 cycles, MFMA 16x16 ~32 cycles. 
+Latencies (cycles): + global_load/buffer_load: ~100 LDS (ds_read/ds_write): ~20 + MFMA F16 16x16x16: 16 MFMA F16 32x32x8: 32 + MFMA F8/F4 16x16x128: 16-32 MFMA F8/F4 32x32x64: 32-64 + scaled_mfma (MXFP4): same as above + VALU: 1 (transcendentals: 2) SALU: 1 + +Scheduling strategy: +- Issue global loads as early as possible, defer s_waitcnt to just before use. +- Interleave LDS reads with MFMA: ~20 cycles of MFMA hides one LDS read. +- Fill cycles between dependent MFMAs with independent loads or VALU. +- MFMA accumulator chains (same opcode, SrcC=prev vDst) need 0 extra NOPs. +- Fewer s_waitcnt and s_nop in the final assembly = better schedule. +- Lower peak VGPRs = higher occupancy (key breakpoints: 128, 96, 80, 64). You have an `evaluate_moves` tool. Call it with a list of move command strings. Each command is one of: @@ -143,7 +156,7 @@ def _noop_log(_msg: str) -> None: "move TAG_A before TAG_B" "swap TAG_A TAG_B" -The tool will apply the moves, compile, and return metrics. +The tool will apply the moves, compile, and return metrics + updated IR. Constraints: - All moves must stay within the same basic block. Never move across blocks. 
@@ -351,8 +364,7 @@ def format_initial_prompt( ) -> str: """Format the initial user message with IR and baseline metrics.""" parts = [ - f"TARGET: {target} (wave64, 512 vgpr, 106 sgpr, 512 agpr)", - "LATENCY: vmem=100, lds=20, mfma_16x16=32, mfma_32x32=64", + f"TARGET: {target} (wave64, 256 vgpr + 256 agpr, 102 sgpr)", "", "--- Tagged IR ---", tagged_ir.strip(), @@ -364,7 +376,8 @@ def format_initial_prompt( parts.extend( [ "", - "GOAL: Minimize register pressure and hide memory latency.", + "GOAL: Reduce s_waitcnt and s_nop count (better latency hiding),", + "then reduce peak VGPRs (higher occupancy).", "Use the evaluate_moves tool to try reorderings.", ] ) From c8024341e22f749c35080feec026acf939881bb4 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 11:20:16 +0100 Subject: [PATCH 32/49] show diff Signed-off-by: Ivan Butygin --- conductor/llm.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/conductor/llm.py b/conductor/llm.py index 8f4f40f85..21fc4d339 100644 --- a/conductor/llm.py +++ b/conductor/llm.py @@ -1,5 +1,6 @@ """OpenRouter LLM client and iterative scheduling loop for Conductor.""" +import difflib import json import os import sys @@ -432,7 +433,9 @@ def _evaluate_moves(moves: list[str]) -> str: log(f" --- Faulty IR ---\n{reordered_ir.strip()}\n --- End IR ---\n") return json.dumps({"error": str(e)}) log(f" [result] {metrics}\n") - log(f" --- Updated IR ---\n{reordered_ir.strip()}\n --- End IR ---\n") + ir_diff = _context_diff(tagged, reordered_ir) + if ir_diff: + log(f" --- IR diff ---\n{ir_diff} --- End diff ---\n") improved = _is_better(metrics, best_metrics) if improved: log(" Improvement!\n") @@ -561,6 +564,14 @@ def _done(summary: str) -> str: } +def _context_diff(before: str, after: str, n: int = 10) -> str: + """Return a unified diff of changed lines with ±n lines of context.""" + a = before.splitlines(keepends=True) + b = after.splitlines(keepends=True) + diff = difflib.unified_diff(a, b, 
fromfile="before", tofile="after", n=n)
+    return "".join(diff)
+
+
 def _is_better(new: dict, old: dict) -> bool:
     """Compare metrics: lower VGPRs > fewer waitcnts > fewer nops > fewer instructions."""
     for key in ("peak_vgpr", "s_waitcnt", "s_nop", "total_instructions"):

From 5994e92384150b3c4de8d37d13dbb9bf18707713 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 22 Feb 2026 11:24:01 +0100
Subject: [PATCH 33/49] move doc

Signed-off-by: Ivan Butygin
---
 .../wave_asm/include/waveasm => conductor}/CONDUCTOR_DESIGN.md | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename {wave_lang/kernel/wave/asm/wave_asm/include/waveasm => conductor}/CONDUCTOR_DESIGN.md (100%)

diff --git a/wave_lang/kernel/wave/asm/wave_asm/include/waveasm/CONDUCTOR_DESIGN.md b/conductor/CONDUCTOR_DESIGN.md
similarity index 100%
rename from wave_lang/kernel/wave/asm/wave_asm/include/waveasm/CONDUCTOR_DESIGN.md
rename to conductor/CONDUCTOR_DESIGN.md

From 5f96075c92ce96f96fd3cc7e2208bf84f839c25f Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 22 Feb 2026 11:57:07 +0100
Subject: [PATCH 34/49] better kernel

Signed-off-by: Ivan Butygin
---
 conductor/extract_ir.py | 145 +++++++++------------------------------
 1 file changed, 30 insertions(+), 115 deletions(-)

diff --git a/conductor/extract_ir.py b/conductor/extract_ir.py
index 98aa9d49d..53c073f30 100644
--- a/conductor/extract_ir.py
+++ b/conductor/extract_ir.py
@@ -113,130 +113,45 @@ def capture_mxfp4_kernel_mlir() -> tuple:
 
 
 def capture_kernel_mlir() -> tuple:
-    """
-    Capture MLIR from a multi-wave GEMM kernel.
+    """Capture MLIR from a manually-scheduled 8-wave GEMM kernel.
+
+    Uses get_tagged_gemm + get_two_pp_cluster_schedule for a 2-stage
+    pipelined prefetch with cluster reordering and wave staggering.
 
     Returns (mlir_text, workgroup_size).
""" - import wave_lang.kernel.lang as tkl - import wave_lang.kernel.wave as tkw - from wave_lang.kernel.lang.global_symbols import ( - GLOBAL_ADDRESS_SPACE, - SHARED_ADDRESS_SPACE, - ) - from wave_lang.kernel.wave.compile import WaveCompileOptions - from wave_lang.kernel.wave.scheduling.schedule_enums import SchedulingType - from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params + from wave_lang.kernel.wave.schedules import get_two_pp_cluster_schedule + from wave_lang.kernel.wave.schedules.gemm_two_pp_cluster import get_tagged_gemm from wave_lang.kernel.wave.utils.run_utils import set_default_run_config - from wave_lang.kernel._support.indexing import IndexingContext - from wave_lang.kernel.wave.compile import _trace_launchable_and_get_kernel_signature - from wave_lang.support.ir_imports import Context, Module, func_d - from wave_lang.kernel.wave.asm.mlir_analysis import ( - walk_ops_recursively, - should_skip_function, - ) - from wave_lang.kernel.wave.utils.compile_utils import canonicalize_module - - M = tkl.sym.M - N = tkl.sym.N - K = tkl.sym.K - BLOCK_M = tkl.sym.BLOCK_M - BLOCK_N = tkl.sym.BLOCK_N - BLOCK_K = tkl.sym.BLOCK_K - ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE - ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0 - - # 4-wave config: 2x2 waves, 32x32 wave tiles. 
- block_m, block_n, wave_m, wave_n = 64, 64, 32, 32 - wave_size = 64 - mma_type = tkw.MMAType.F32_16x16x16_F16 - - constraints = [ - tkw.WorkgroupConstraint(M, BLOCK_M, 0), - tkw.WorkgroupConstraint(N, BLOCK_N, 1), - tkw.TilingConstraint(K, BLOCK_K), - tkw.WaveConstraint(M, wave_m), - tkw.WaveConstraint(N, wave_n), - tkw.HardwareConstraint(threads_per_wave=wave_size, mma_type=mma_type), - ] - @tkw.wave(constraints) - def gemm_kernel( - a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], - b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], - c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32], - ): - c_reg = tkl.Register[M, N, tkl.f32](0.0) - - @tkw.iterate(K, init_args=[c_reg]) - def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: - a_reg = tkw.read(a) - b_reg = tkw.read(b) - acc = tkw.mma(a_reg, b_reg, acc) - return acc - - tkw.write(repeat, c) - - m, n, k = 256, 256, 256 - block_k = 16 - - subs = { - M: m, - N: n, - K: k, - BLOCK_M: block_m, - BLOCK_N: block_n, - BLOCK_K: block_k, - ADDRESS_SPACE: SHARED_ADDRESS_SPACE, - ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, - } - subs.update(get_default_scheduling_params()) - - options = WaveCompileOptions( - subs=subs, - canonicalize=True, - schedule=SchedulingType.PREFETCH, - use_scheduling_barriers=True, - backend="asm", - wave_runtime=True, - compile_to_mlir=False, - use_global_to_shared=False, + gemm, options = get_tagged_gemm( + shape=(4096, 4096, 4096), + block_shape=(128, 256, 64), ) - options = set_default_run_config(options) - - # Capture MLIR via the same path as the e2e tests. 
-    with IndexingContext() as idxc:
-        idxc.set_subs(options.subs)
-        gemm_kernel.initialize_wave_constraints()
-        gemm_kernel.initialize_symbolic_constraints()
-        gemm_kernel.initialize_workgroup_constraints()
+    schedule = get_two_pp_cluster_schedule()
-        result = _trace_launchable_and_get_kernel_signature(gemm_kernel, options)
-        mb = result[0]
-
-    if options.canonicalize:
-        canonicalize_module(mb.module_op)
-
-    full_mlir = mb.module_op.get_asm(enable_debug_info=False)
-
-    launch_info = options.kernel_launch_info
-    blocks = launch_info.blocks if launch_info.blocks else [64, 1, 1]
-
-    # Extract func.func from stream wrapper.
-    with Context() as ctx:
-        ctx.allow_unregistered_dialects = True
-        module = Module.parse(full_mlir)
+    options.backend = "asm"
+    options.wave_runtime = True
+    options.compile_to_mlir = False
+    options = set_default_run_config(options)
-        for fn in walk_ops_recursively(module.operation):
-            if not isinstance(fn, func_d.FuncOp):
-                continue
-            if should_skip_function(fn):
-                continue
-            func_text = fn.get_asm(print_generic_op_form=True)
-            mlir_text = "module {\n" + func_text + "\n}\n"
-            return mlir_text, tuple(blocks)
+    sys.path.insert(
+        0,
+        str(
+            wave_root
+            / "wave_lang"
+            / "kernel"
+            / "wave"
+            / "asm"
+            / "wave_asm"
+            / "test"
+            / "e2e"
+        ),
+    )
+    from waveasm_e2e import capture_wave_kernel_info
-    raise ValueError("No kernel function found in MLIR")
+    info = capture_wave_kernel_info(options, gemm, schedule=schedule)
+    return info.mlir_text, info.workgroup_size
 
 
 def run_waveasm_translate(

From 7c935def7aee812360c1133ea90332c01ff062ca Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 22 Feb 2026 12:12:50 +0100
Subject: [PATCH 35/49] less logs

Signed-off-by: Ivan Butygin
---
 conductor/llm.py                              |  2 +-
 .../waveasm-conductor/waveasm-conductor.cpp   | 21 +++++++++++++------
 2 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/conductor/llm.py b/conductor/llm.py
index 21fc4d339..0d1402a72 100644
--- a/conductor/llm.py
+++ b/conductor/llm.py
@@ -430,7 +430,7 @@ def _evaluate_moves(moves: list[str]) -> str:
             metrics = conductor.get_metrics(asm)
         except RuntimeError as e:
             log(f" [error] {e}\n")
-            log(f" --- Faulty IR ---\n{reordered_ir.strip()}\n --- End IR ---\n")
+            # log(f" --- Faulty IR ---\n{reordered_ir.strip()}\n --- End IR ---\n")
             return json.dumps({"error": str(e)})
         log(f" [result] {metrics}\n")
         ir_diff = _context_diff(tagged, reordered_ir)
diff --git a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp
index 2cb07f511..6abee5bae 100644
--- a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp
+++ b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp
@@ -47,6 +47,11 @@ static llvm::cl::opt<bool> printDebugLocsInline(
     llvm::cl::desc("Print location information inline (pretty form)"),
     llvm::cl::init(false));
 
+static llvm::cl::opt<bool>
+    dumpIROnFailure("dump-ir-on-failure",
+                    llvm::cl::desc("Dump IR to stderr on move or verify failure"),
+                    llvm::cl::init(false));
+
 //===----------------------------------------------------------------------===//
 // Main Function
 //===----------------------------------------------------------------------===//
@@ -107,18 +112,22 @@ int main(int argc, char **argv) {
   if (!result.success) {
     llvm::errs() << "conductor: command " << result.failedCommand << ": "
                  << result.error << "\n";
-    llvm::errs() << "--- IR after partial moves ---\n";
-    module->print(llvm::errs());
-    llvm::errs() << "--- end IR ---\n";
+    if (dumpIROnFailure) {
+      llvm::errs() << "--- IR after partial moves ---\n";
+      module->print(llvm::errs());
+      llvm::errs() << "--- end IR ---\n";
+    }
     return 1;
   }
 
   // Verify the module after moves (catches broken dominance, etc.).
   if (failed(mlir::verify(*module))) {
     llvm::errs() << "conductor: verification failed after applying moves\n";
-    llvm::errs() << "--- IR at verification failure ---\n";
-    module->print(llvm::errs());
-    llvm::errs() << "--- end IR ---\n";
+    if (dumpIROnFailure) {
+      llvm::errs() << "--- IR at verification failure ---\n";
+      module->print(llvm::errs());
+      llvm::errs() << "--- end IR ---\n";
+    }
     return 1;
   }

From f0febd7d32cac315404dbc6e21cde736b9b196bf Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 22 Feb 2026 12:28:42 +0100
Subject: [PATCH 36/49] clang-format

Signed-off-by: Ivan Butygin
---
 .../tools/waveasm-conductor/waveasm-conductor.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp
index 6abee5bae..44edabca3 100644
--- a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp
+++ b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp
@@ -47,10 +47,10 @@ static llvm::cl::opt<bool> printDebugLocsInline(
     llvm::cl::desc("Print location information inline (pretty form)"),
     llvm::cl::init(false));
 
-static llvm::cl::opt<bool>
-    dumpIROnFailure("dump-ir-on-failure",
-                    llvm::cl::desc("Dump IR to stderr on move or verify failure"),
-                    llvm::cl::init(false));
+static llvm::cl::opt<bool> dumpIROnFailure(
+    "dump-ir-on-failure",
+    llvm::cl::desc("Dump IR to stderr on move or verify failure"),
+    llvm::cl::init(false));
 
 //===----------------------------------------------------------------------===//
 // Main Function

From 8885ad0a05dff91d4167cf01e559b468ba1cd38f Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 22 Feb 2026 12:28:55 +0100
Subject: [PATCH 37/49] prompt and less code

Signed-off-by: Ivan Butygin
---
 conductor/llm.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/conductor/llm.py b/conductor/llm.py
index 0d1402a72..d653f0da2 100644
--- a/conductor/llm.py
+++ b/conductor/llm.py
@@ -162,6 +162,11 @@ def _noop_log(_msg: str) -> None:
 Constraints:
 - All moves must stay within the same basic block. Never move across blocks.
 - Moves must preserve SSA dominance: a value must be defined before all its uses.
+  Before proposing any move, trace the SSA def-use chains for the instruction
+  you want to move. Check: (1) every operand (%val) it consumes is still
+  defined above it after the move, and (2) every result it produces is still
+  above all its users after the move. If either check fails, the move is
+  illegal — pick a different one.
 - Pinned ops (s_endpgm, s_barrier, condition) cannot be moved.
 
 Work incrementally: try 1-3 moves per tool call, read the resulting metrics, \
@@ -433,9 +438,9 @@ def _evaluate_moves(moves: list[str]) -> str:
             # log(f" --- Faulty IR ---\n{reordered_ir.strip()}\n --- End IR ---\n")
             return json.dumps({"error": str(e)})
         log(f" [result] {metrics}\n")
-        ir_diff = _context_diff(tagged, reordered_ir)
-        if ir_diff:
-            log(f" --- IR diff ---\n{ir_diff} --- End diff ---\n")
+        # ir_diff = _context_diff(tagged, reordered_ir)
+        # if ir_diff:
+        #     log(f" --- IR diff ---\n{ir_diff} --- End diff ---\n")
         improved = _is_better(metrics, best_metrics)
         if improved:
             log(" Improvement!\n")

From a49ab7aa0ca68cbe92ef5fe5fb85563ef9e24b2b Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 22 Feb 2026 13:06:55 +0100
Subject: [PATCH 38/49] actually change default efforts

Signed-off-by: Ivan Butygin
---
 conductor/conductor.py | 4 ++--
 conductor/llm.py       | 8 +++++++-
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/conductor/conductor.py b/conductor/conductor.py
index 842d32a33..8d1abc606 100644
--- a/conductor/conductor.py
+++ b/conductor/conductor.py
@@ -214,8 +214,8 @@ def main():
     parser.add_argument(
         "--reasoning-effort",
         type=str,
-        default="high",
-        help="Reasoning effort for models that support it (default: high).",
+        default="low",
+        help="Reasoning effort for models that support it (default: low).",
     )
     parser.add_argument(
         "--kernel",
diff --git a/conductor/llm.py b/conductor/llm.py
index d653f0da2..e0b10ea04 100644
--- a/conductor/llm.py
+++ b/conductor/llm.py
@@ -172,6 +172,12 @@ def _noop_log(_msg: str) -> None:
 Work incrementally: try 1-3 moves per tool call, read the resulting metrics, \
 then decide your next moves. You can call the tool multiple times.
 
+DO NOT try to build too long of a plan at once. Try 1-3 moves and see what changes.
+
+Keep your reasoning brief. For each thinking step, write at most one short \
+sentence. Do not re-analyze the entire IR — focus only on the specific move \
+you are considering and its immediate SSA neighbors.
+
 When you are satisfied with the schedule or have no more ideas, call the `done` \
 tool to finish.\
 """
@@ -395,7 +401,7 @@ def run_scheduling_loop(
     max_rounds: int = 10,
     model: str = DEFAULT_MODEL,
     temperature: float = 0.7,
-    reasoning_effort: str | None = "medium",
+    reasoning_effort: str | None = "high",
     log: Callable[[str], None] = _default_log,
 ) -> dict:
     """

From df7e2f043725132be78867cf63b323d2c4afb24d Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 22 Feb 2026 13:17:55 +0100
Subject: [PATCH 39/49] proper dominance report

Signed-off-by: Ivan Butygin
---
 .../wave_asm/lib/Transforms/ApplyMoves.cpp    | 51 +++++++++++++++++++
 .../apply-moves-error-dominance.mlir          |  3 +-
 2 files changed, 52 insertions(+), 2 deletions(-)

diff --git a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/ApplyMoves.cpp b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/ApplyMoves.cpp
index aa77b586b..513b70cb8 100644
--- a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/ApplyMoves.cpp
+++ b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/ApplyMoves.cpp
@@ -10,6 +10,7 @@
 
 #include "waveasm/Transforms/ApplyMoves.h"
 
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/Location.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/StringRef.h"
@@ -52,6 +53,44 @@ std::string validateSameBlock(Operation *a, StringRef tagA, Operation *b,
   return "";
 }
 
+/// Get a human-readable name for an operation (tag or mnemonic).
+std::string opName(Operation *op) {
+  if (auto nameLoc = dyn_cast<NameLoc>(op->getLoc()))
+    return nameLoc.getName().str();
+  return op->getName().getStringRef().str();
+}
+
+/// Check SSA dominance for a single op after it has been moved.
+/// Verifies: (1) all operands are defined above the op,
+///           (2) all results are defined before their users.
+std::string checkDominance(Operation *op) {
+  DominanceInfo domInfo(op->getParentOfType());
+
+  // Check that every operand still dominates this op.
+  for (Value operand : op->getOperands()) {
+    if (!domInfo.properlyDominates(operand, op)) {
+      std::string defName = "block-arg";
+      if (auto *defOp = operand.getDefiningOp())
+        defName = opName(defOp);
+      return "moving '" + opName(op) + "' breaks dominance: operand from '" +
+             defName + "' no longer defined before use";
+    }
+  }
+
+  // Check that every user of this op's results is still dominated.
+  for (Value result : op->getResults()) {
+    for (Operation *user : result.getUsers()) {
+      if (!domInfo.properlyDominates(op, user)) {
+        return "moving '" + opName(op) +
+               "' breaks dominance: result used by '" + opName(user) +
+               "' which now appears before the definition";
+      }
+    }
+  }
+
+  return "";
+}
+
 } // namespace
 
 namespace waveasm {
@@ -167,6 +206,9 @@ MoveResult applyMoves(ModuleOp module, llvm::ArrayRef commands) {
         return fail(err);
       op1->moveBefore(op2);
      LDBG() << "move " << move->tag << " before " << move->refTag;
+      err = checkDominance(op1);
+      if (!err.empty())
+        return fail(err);
 
     } else if (auto *move = std::get_if(&cmd)) {
       std::string err =
@@ -175,6 +217,9 @@ MoveResult applyMoves(ModuleOp module, llvm::ArrayRef commands) {
         return fail(err);
      op1->moveAfter(op2);
       LDBG() << "move " << move->tag << " after " << move->refTag;
+      err = checkDominance(op1);
+      if (!err.empty())
+        return fail(err);
 
     } else if (auto *swap = std::get_if(&cmd)) {
       std::string err =
@@ -196,6 +241,12 @@ MoveResult applyMoves(ModuleOp module, llvm::ArrayRef commands) {
         op2->moveBefore(op2->getBlock(), op2->getBlock()->end());
       }
       LDBG() << "swap " << swap->tag1 << " " << swap->tag2;
+      err = checkDominance(op1);
+      if (!err.empty())
+        return fail(err);
+      err = checkDominance(op2);
+      if (!err.empty())
+        return fail(err);
     }
   }
diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-dominance.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-dominance.mlir
index b3d01551d..7eaa9a751 100644
--- a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-dominance.mlir
+++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-dominance.mlir
@@ -5,8 +5,7 @@
 // v_add_u32_1 uses the result of v_add_u32_0, so moving _0 after _1
 // breaks dominance.
-// CHECK: does not dominate this use
-// CHECK: conductor: verification failed after applying moves
+// CHECK: moving 'v_add_u32_0' breaks dominance: result used by 'v_add_u32_1' which now appears before the definition
 
 waveasm.program @test_dominance target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
   %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0>

From 5db516029c83fbd87418ffb4b0313544e557ff75 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 22 Feb 2026 13:42:05 +0100
Subject: [PATCH 40/49] stateful

Signed-off-by: Ivan Butygin
---
 conductor/llm.py | 23 ++++++++++-------------
 1 file changed, 10 insertions(+), 13 deletions(-)

diff --git a/conductor/llm.py b/conductor/llm.py
index e0b10ea04..99715c47d 100644
--- a/conductor/llm.py
+++ b/conductor/llm.py
@@ -417,22 +417,21 @@ def run_scheduling_loop(
     baseline = conductor.baseline()
     log(f" baseline: {baseline}\n")
 
-    tagged_ir = conductor.tag()
-    log(f" --- Tagged IR ---\n{tagged_ir.strip()}\n --- End IR ---\n")
+    current_ir = conductor.tag()
+    log(f" --- Tagged IR ---\n{current_ir.strip()}\n --- End IR ---\n")
 
     best_metrics = dict(baseline)
-    best_commands: list[str] = []
+    all_commands: list[str] = []
     finished = False
 
     # Build tool registry with closures over loop state.
     registry = ToolRegistry()
 
     def _evaluate_moves(moves: list[str]) -> str:
-        nonlocal best_metrics, best_commands
+        nonlocal best_metrics, all_commands, current_ir
         log(f" [tool] evaluate_moves({moves})\n")
         try:
-            tagged = conductor.tag()
-            reordered_ir = conductor.apply_moves(tagged, moves)
+            reordered_ir = conductor.apply_moves(current_ir, moves)
         except RuntimeError as e:
             log(f" [error] {e}\n")
             return json.dumps({"error": str(e)})
@@ -441,17 +440,15 @@ def _evaluate_moves(moves: list[str]) -> str:
             metrics = conductor.get_metrics(asm)
         except RuntimeError as e:
             log(f" [error] {e}\n")
-            # log(f" --- Faulty IR ---\n{reordered_ir.strip()}\n --- End IR ---\n")
             return json.dumps({"error": str(e)})
         log(f" [result] {metrics}\n")
-        # ir_diff = _context_diff(tagged, reordered_ir)
-        # if ir_diff:
-        #     log(f" --- IR diff ---\n{ir_diff} --- End diff ---\n")
+        # Moves succeeded — update state.
+        current_ir = reordered_ir
+        all_commands.extend(moves)
         improved = _is_better(metrics, best_metrics)
         if improved:
             log(" Improvement!\n")
             best_metrics = metrics
-            best_commands = moves
         return json.dumps(
             {
                 "metrics": metrics,
                 "improved": improved or metrics == best_metrics,
                 "updated_ir": reordered_ir.strip(),
             }
         )
@@ -503,7 +500,7 @@ def _done(summary: str) -> str:
             {
                 "role": "user",
                 "content": format_initial_prompt(
-                    tagged_ir, baseline, target=conductor.target
+                    current_ir, baseline, target=conductor.target
                 ),
             },
         ]
@@ -568,7 +565,7 @@ def _done(summary: str) -> str:
     return {
         "metrics": best_metrics,
-        "commands": best_commands,
+        "commands": all_commands,
         "rounds": round_num,
         "baseline_metrics": baseline,
         "usage": usage,

From 02f0406e6cc27de9f404c84261ac3e6e08992457 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 22 Feb 2026 17:36:30 +0100
Subject: [PATCH 41/49] ssa names

Signed-off-by: Ivan Butygin
---
 conductor/conductor.py                            | 12 ++++++++++--
 .../tools/waveasm-conductor/waveasm-conductor.cpp |  7 +++++++
 .../tools/waveasm-translate/waveasm-translate.cpp |  7 +++++++
 3 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/conductor/conductor.py b/conductor/conductor.py
index 8d1abc606..a72ba9a46 100644
--- a/conductor/conductor.py
+++ b/conductor/conductor.py
@@ -88,6 +88,7 @@ def tag(self) -> str:
         flags = [
             "--waveasm-tag-instructions",
             "--print-debug-locs-inline",
+            "--use-nameloc-as-prefix",
         ]
         stdout, stderr, rc = run_waveasm_translate(
             self.waveasm_ir, self.workgroup_size, flags
         )
@@ -109,7 +110,12 @@ def apply_moves(self, tagged_ir: str, commands: list) -> str:
             input_path = f.name
 
         try:
-            cmd = [conductor, "--print-debug-locs-inline", input_path]
+            cmd = [
+                conductor,
+                "--print-debug-locs-inline",
+                "--use-nameloc-as-prefix",
+                input_path,
+            ]
             result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
             if result.returncode != 0:
                 raise RuntimeError(
@@ -250,7 +256,9 @@ def main():
     waveasm_ir = run_pre_scheduling_pipeline(mlir_text, wg_size)
     print(f" WaveASM IR: {len(waveasm_ir)} chars", file=sys.stderr)
 
-    conductor = Conductor(waveasm_ir, wg_size)
+    from conductor.extract_ir import get_target
+
+    conductor = Conductor(waveasm_ir, wg_size, target=get_target())
 
     if args.tag_only:
         print(conductor.tag())
diff --git a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp
index 44edabca3..9c7da35a8 100644
--- a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp
+++ b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp
@@ -52,6 +52,11 @@ static llvm::cl::opt<bool> dumpIROnFailure(
     llvm::cl::desc("Dump IR to stderr on move or verify failure"),
     llvm::cl::init(false));
 
+static llvm::cl::opt<bool> useNameLocAsPrefix(
+    "use-nameloc-as-prefix",
+    llvm::cl::desc("Print SSA IDs using NameLocs as prefixes"),
+    llvm::cl::init(false));
+
 //===----------------------------------------------------------------------===//
 // Main Function
 //===----------------------------------------------------------------------===//
@@ -144,6 +149,8 @@ int main(int argc, char **argv) {
     flags.enableDebugInfo(/*prettyForm=*/true);
     flags.useLocalScope();
   }
+  if (useNameLocAsPrefix)
+    flags.printNameLocAsPrefix();
   module->print(outputStream, flags);
 
   return 0;
diff --git a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-translate/waveasm-translate.cpp b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-translate/waveasm-translate.cpp
index 0d7b68f3e..3b30845e8 100644
--- a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-translate/waveasm-translate.cpp
+++ b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-translate/waveasm-translate.cpp
@@ -115,6 +115,11 @@ static llvm::cl::opt<bool> printDebugLocsInline(
     llvm::cl::desc("Print location information inline (pretty form)"),
     llvm::cl::init(false));
 
+static llvm::cl::opt<bool> useNameLocAsPrefix(
+    "use-nameloc-as-prefix",
+    llvm::cl::desc("Print SSA IDs using NameLocs as prefixes"),
+    llvm::cl::init(false));
+
 static llvm::cl::opt<bool> runPreTranslationCSE(
     "mlir-cse",
     llvm::cl::desc("Run MLIR CSE before translation (reduces redundant index "
@@ -358,6 +363,8 @@ int main(int argc, char **argv) {
   } else if (printDebugLocs) {
     flags.enableDebugInfo();
   }
+  if (useNameLocAsPrefix)
+    flags.printNameLocAsPrefix();
   module->print(outputStream, flags);
 
   return 0;

From ec9e39d685e57766cddbfb313976c95210f090fe Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 22 Feb 2026 17:42:45 +0100
Subject: [PATCH 42/49] print diff

Signed-off-by: Ivan Butygin
---
 conductor/llm.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/conductor/llm.py b/conductor/llm.py
index 99715c47d..55f6cb282 100644
--- a/conductor/llm.py
+++ b/conductor/llm.py
@@ -417,7 +417,8 @@ def run_scheduling_loop(
     baseline = conductor.baseline()
     log(f" baseline: {baseline}\n")
 
-    current_ir = conductor.tag()
+    initial_ir = conductor.tag()
+    current_ir = initial_ir
     log(f" --- Tagged IR ---\n{current_ir.strip()}\n --- End IR ---\n")
 
     best_metrics = dict(baseline)
@@ -563,6 +564,12 @@ def _done(summary: str) -> str:
         f" cost: ${usage.cost:.4f}\n"
     )
 
+    ir_diff = _context_diff(initial_ir, current_ir)
+    if ir_diff:
+        log(f"\n=== IR Diff (original → final) ===\n{ir_diff}\n")
+    else:
+        log("\n=== No IR changes ===\n")
+
     return {
         "metrics": best_metrics,
         "commands": all_commands,

From 7b0c013eb20200a56c471d016a895d9e358c1343 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 22 Feb 2026 17:53:00 +0100
Subject: [PATCH 43/49] evaluate_moves summary

Signed-off-by: Ivan Butygin
---
 conductor/llm.py | 26 ++++++++++++++++++--------
 1 file changed, 18 insertions(+), 8 deletions(-)

diff --git a/conductor/llm.py b/conductor/llm.py
index 55f6cb282..a7787379c 100644
--- a/conductor/llm.py
+++ b/conductor/llm.py
@@ -428,8 +428,10 @@ def run_scheduling_loop(
     # Build tool registry with closures over loop state.
     registry = ToolRegistry()
 
-    def _evaluate_moves(moves: list[str]) -> str:
+    def _evaluate_moves(moves: list[str], summary: str = "") -> str:
         nonlocal best_metrics, all_commands, current_ir
+        if summary:
+            log(f" [reason] {summary}\n")
         log(f" [tool] evaluate_moves({moves})\n")
         try:
             reordered_ir = conductor.apply_moves(current_ir, moves)
@@ -450,13 +452,14 @@ def _evaluate_moves(moves: list[str], summary: str = "") -> str:
         improved = _is_better(metrics, best_metrics)
         if improved:
             log(" Improvement!\n")
             best_metrics = metrics
-        return json.dumps(
-            {
-                "metrics": metrics,
-                "improved": improved or metrics == best_metrics,
-                "updated_ir": reordered_ir.strip(),
-            }
-        )
+        result = {
+            "metrics": metrics,
+            "improved": improved or metrics == best_metrics,
+            "updated_ir": reordered_ir.strip(),
+        }
+        if summary:
+            result["summary"] = summary
+        return json.dumps(result)
 
     def _done(summary: str) -> str:
         nonlocal finished
@@ -481,6 +484,13 @@ def _done(summary: str) -> str:
                 type="array",
                 items={"type": "string"},
             ),
+            Param(
+                name="summary",
+                description=(
+                    "Brief explanation of why these moves should help "
+                    "(e.g. 'hide LDS latency by interleaving ds_read with mfma')."
+                ),
+            ),
         ],
         func=_evaluate_moves,
     )

From 4088c7cde74234bcfba14df5f1f43d6feeacc1b1 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 22 Feb 2026 18:12:10 +0100
Subject: [PATCH 44/49] move openrouter to providers

Signed-off-by: Ivan Butygin
---
 conductor/__init__.py             |   3 +-
 conductor/conductor.py            |   3 +-
 conductor/llm.py                  | 311 +-----------------------------
 conductor/providers/__init__.py   |   7 +
 conductor/providers/openrouter.py | 309 +++++++++++++++++++++++++++++
 5 files changed, 327 insertions(+), 306 deletions(-)
 create mode 100644 conductor/providers/__init__.py
 create mode 100644 conductor/providers/openrouter.py

diff --git a/conductor/__init__.py b/conductor/__init__.py
index 5510d001f..0d106741a 100644
--- a/conductor/__init__.py
+++ b/conductor/__init__.py
@@ -16,7 +16,8 @@
     capture_kernel_mlir,
     capture_mxfp4_kernel_mlir,
 )
-from conductor.llm import run_scheduling_loop, Stats, Counters
+from conductor.llm import run_scheduling_loop
+from conductor.providers.openrouter import Stats, Counters
 from conductor.tools import Param, ToolDef, ToolRegistry
 
 __all__ = [
diff --git a/conductor/conductor.py b/conductor/conductor.py
index a72ba9a46..35ac73219 100644
--- a/conductor/conductor.py
+++ b/conductor/conductor.py
@@ -265,7 +265,8 @@ def main():
         return
 
     if args.llm:
-        from conductor.llm import run_scheduling_loop, DEFAULT_MODEL
+        from conductor.llm import run_scheduling_loop
+        from conductor.providers.openrouter import DEFAULT_MODEL
 
         model = args.model or DEFAULT_MODEL
         print(f"Running LLM scheduling loop (model={model})...", file=sys.stderr)
diff --git a/conductor/llm.py b/conductor/llm.py
index a7787379c..8d524c488 100644
--- a/conductor/llm.py
+++ b/conductor/llm.py
@@ -1,135 +1,24 @@
-"""OpenRouter LLM client and iterative scheduling loop for Conductor."""
+"""Iterative LLM scheduling loop for Conductor."""
 
 import difflib
 import json
-import os
 import sys
-import threading
-import time
-from collections import defaultdict
 from collections.abc import Callable
-from dataclasses import dataclass
-from types import TracebackType
-from typing import Any
-
-import requests
+from conductor.providers.openrouter import (
+    DEFAULT_MODEL,
+    Message,
+    Stats,
+    chat,
+)
 from conductor.tools import Param, ToolRegistry
 
-API_KEY: str = os.environ.get("OPENROUTER_API_KEY", "")
-BASE_URL: str = "https://openrouter.ai/api/v1"
-DEFAULT_MODEL: str = "deepseek/deepseek-v3.2"
-
-_REQUEST_TIMEOUT = 120
-_MAX_RETRIES = 3
-_RETRY_BACKOFF = 2.0
-
-Message = dict[str, Any]
-
-
-# --- Monotonic API usage counters (thread-safe, per-model). ---
-
-
-@dataclass
-class Counters:
-    """API usage snapshot."""
-
-    tokens: int = 0
-    input_tokens: int = 0
-    output_tokens: int = 0
-    cost: float = 0.0
-
-
-_counters_lock = threading.Lock()
-_counters: defaultdict[str, Counters] = defaultdict(Counters)
-
-
-def _record_usage(model: str, usage: dict[str, Any]) -> None:
-    """Accumulate token and cost from an API response."""
-    tokens = int(usage.get("total_tokens", 0))
-    input_tokens = int(usage.get("prompt_tokens", 0))
-    output_tokens = int(usage.get("completion_tokens", 0))
-    cost = usage.get("cost")
-    with _counters_lock:
-        c = _counters[model]
-        c.tokens += tokens
-        c.input_tokens += input_tokens
-        c.output_tokens += output_tokens
-        if cost is not None:
-            c.cost += float(cost)
-
-
-class Stats:
-    """Context manager that captures API usage over a scope.
-
-    Snapshots the monotonic per-model counters on entry. The ``counters``
-    property returns an aggregate delta; ``per_model`` returns per-model deltas.
-    """
-
-    def __init__(self) -> None:
-        self._start: dict[str, Counters] = {}
-
-    def __enter__(self) -> "Stats":
-        with _counters_lock:
-            self._start = {
-                m: Counters(c.tokens, c.input_tokens, c.output_tokens, c.cost)
-                for m, c in _counters.items()
-            }
-        return self
-
-    def __exit__(
-        self,
-        exc_type: type[BaseException] | None,
-        exc_val: BaseException | None,
-        exc_tb: TracebackType | None,
-    ) -> None:
-        pass
-
-    @property
-    def counters(self) -> Counters:
-        """Aggregate delta across all models since entering the context."""
-        with _counters_lock:
-            tok = sum(c.tokens for c in _counters.values()) - sum(
-                c.tokens for c in self._start.values()
-            )
-            inp = sum(c.input_tokens for c in _counters.values()) - sum(
-                c.input_tokens for c in self._start.values()
-            )
-            out = sum(c.output_tokens for c in _counters.values()) - sum(
-                c.output_tokens for c in self._start.values()
-            )
-            cst = sum(c.cost for c in _counters.values()) - sum(
-                c.cost for c in self._start.values()
-            )
-            return Counters(tokens=tok, input_tokens=inp, output_tokens=out, cost=cst)
-
-    @property
-    def per_model(self) -> defaultdict[str, Counters]:
-        """Per-model deltas since entering the context."""
-        with _counters_lock:
-            result: defaultdict[str, Counters] = defaultdict(Counters)
-            for model, current in _counters.items():
-                start = self._start.get(model, Counters())
-                delta = Counters(
-                    tokens=current.tokens - start.tokens,
-                    input_tokens=current.input_tokens - start.input_tokens,
-                    output_tokens=current.output_tokens - start.output_tokens,
-                    cost=current.cost - start.cost,
-                )
-                if delta.tokens or delta.cost:
-                    result[model] = delta
-            return result
-
 
 def _default_log(msg: str) -> None:
     """Default logger: print to stderr without trailing newline."""
     print(msg, file=sys.stderr, end="", flush=True)
 
 
-def _noop_log(_msg: str) -> None:
-    pass
-
 
 SYSTEM_PROMPT = """\
 You are an expert GPU instruction scheduler for AMD CDNA architectures.
@@ -183,192 +72,6 @@ def _noop_log(_msg: str) -> None:
 """
 
 
-_TRANSIENT_ERRORS = (
-    requests.exceptions.ChunkedEncodingError,
-    requests.exceptions.ConnectionError,
-    requests.exceptions.ReadTimeout,
-    requests.exceptions.ConnectTimeout,
-)
-
-
-def _with_retry(
-    func: Callable[..., Any],
-    log: Callable[[str], None],
-    *args: Any,
-    **kwargs: Any,
-) -> Any:
-    """Call *func* with retries on transient network errors and 5xx."""
-    for attempt in range(_MAX_RETRIES):
-        try:
-            return func(*args, **kwargs)
-        except requests.HTTPError as exc:
-            resp = exc.response
-            if resp is not None and resp.status_code >= 500:
-                if attempt < _MAX_RETRIES - 1:
-                    wait = _RETRY_BACKOFF * (attempt + 1)
-                    log(
-                        f"\n [retry] server {resp.status_code}, waiting {wait:.0f}s...\n"
-                    )
-                    time.sleep(wait)
-                    continue
-            raise
-        except _TRANSIENT_ERRORS:
-            if attempt < _MAX_RETRIES - 1:
-                wait = _RETRY_BACKOFF * (attempt + 1)
-                log(f"\n [retry] connection error, waiting {wait:.0f}s...\n")
-                time.sleep(wait)
-            else:
-                raise
-    raise RuntimeError("Unreachable")
-
-
-def _stream_request(
-    payload: dict[str, Any],
-    on_token: Callable[[str], None] | None,
-    on_thinking: Callable[[str], None] | None,
-) -> Message:
-    """Execute a streaming chat request and assemble the response message."""
-    resp = requests.post(
-        f"{BASE_URL}/chat/completions",
-        headers={"Authorization": f"Bearer {API_KEY}"},
-        json=payload,
-        stream=True,
-        timeout=_REQUEST_TIMEOUT,
-    )
-    if resp.status_code >= 400:
-        raise requests.HTTPError(
-            f"{resp.status_code} {resp.reason}: {resp.text}",
-            response=resp,
-        )
-
-    content_chunks: list[str] = []
-    reasoning_chunks: list[str] = []
-    tool_calls_by_index: dict[int, dict[str, Any]] = {}
-    usage: dict[str, Any] | None = None
-
-    for line in resp.iter_lines():
-        if not line or not line.startswith(b"data: "):
-            continue
-        data = line[6:]
-        if data == b"[DONE]":
-            break
-        chunk = json.loads(data)
-        if "usage" in chunk:
-            usage = chunk["usage"]
-        choices = chunk.get("choices")
-        if not choices:
-            continue
-        delta = choices[0].get("delta", {})
-
-        # Reasoning tokens (two OpenRouter formats).
-        reasoning_texts: list[str] = []
-        for detail in delta.get("reasoning_details", []):
-            text = detail.get("text", "")
-            if text:
-                reasoning_texts.append(text)
-        rc = delta.get("reasoning_content", "")
-        if rc:
-            reasoning_texts.append(rc)
-        for text in reasoning_texts:
-            if on_thinking is not None:
-                on_thinking(text)
-            reasoning_chunks.append(text)
-
-        # Content tokens.
-        token = delta.get("content", "")
-        if token:
-            if on_token is not None:
-                on_token(token)
-            content_chunks.append(token)
-
-        # Tool call deltas.
-        for tc_delta in delta.get("tool_calls", []):
-            idx = tc_delta["index"]
-            if idx not in tool_calls_by_index:
-                tool_calls_by_index[idx] = {
-                    "id": tc_delta.get("id", ""),
-                    "type": "function",
-                    "function": {"name": "", "arguments": ""},
-                }
-            tc = tool_calls_by_index[idx]
-            func_delta = tc_delta.get("function", {})
-            if func_delta.get("name"):
-                tc["function"]["name"] += func_delta["name"]
-            if func_delta.get("arguments"):
-                tc["function"]["arguments"] += func_delta["arguments"]
-
-    result: Message = {"role": "assistant", "content": "".join(content_chunks)}
-    if reasoning_chunks:
-        result["reasoning"] = "".join(reasoning_chunks)
-    if tool_calls_by_index:
-        result["tool_calls"] = [
-            tool_calls_by_index[i] for i in sorted(tool_calls_by_index)
-        ]
-    if usage is not None:
-        result["usage"] = usage
-    return result
-
-
-def chat(
-    messages: list[Message],
-    model: str,
-    temperature: float = 0.7,
-    max_tokens: int = 2048,
-    reasoning_effort: str | None = None,
-    tools: list[dict] | None = None,
-    log: Callable[[str], None] = _default_log,
-) -> Message:
-    """Send a chat completion request. Returns the full response message dict.
-
-    Handles payload construction, retries on transient errors, usage
-    recording, and streaming log output.
-    """
-    if not API_KEY:
-        raise RuntimeError(
-            "OPENROUTER_API_KEY not set. Export it before running the LLM loop."
-        )
-
-    payload: dict[str, Any] = {
-        "model": model,
-        "messages": messages,
-        "temperature": temperature,
-        "max_tokens": max_tokens,
-        "stream": True,
-        "stream_options": {"include_usage": True},
-    }
-    if reasoning_effort is not None:
-        payload["reasoning"] = {"enabled": True, "effort": reasoning_effort}
-    if tools:
-        payload["tools"] = [dict(t) for t in tools]
-
-    # Streaming callbacks that manage [thinking]/[/thinking] delimiters.
-    in_reasoning = False
-
-    def on_thinking(text: str) -> None:
-        nonlocal in_reasoning
-        if not in_reasoning:
-            log("\n [thinking] ")
-            in_reasoning = True
-        log(text)
-
-    def on_token(text: str) -> None:
-        nonlocal in_reasoning
-        if in_reasoning:
-            log("\n [/thinking]\n")
-            in_reasoning = False
-
-    result: Message = _with_retry(_stream_request, log, payload, on_token, on_thinking)
-    if in_reasoning:
-        log("\n [/thinking]\n")
-
-    if "usage" in result:
-        _record_usage(model, result["usage"])
-        pt = result["usage"].get("prompt_tokens", "?")
-        ct = result["usage"].get("completion_tokens", "?")
-        log(f"\n [tokens] prompt={pt} completion={ct}\n")
-    return result
-
-
 def format_initial_prompt(
     tagged_ir: str,
     baseline_metrics: dict,
diff --git a/conductor/providers/__init__.py b/conductor/providers/__init__.py
new file mode 100644
index 000000000..423b576ea
--- /dev/null
+++ b/conductor/providers/__init__.py
@@ -0,0 +1,7 @@
+# Copyright 2025 The Wave Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""LLM provider backends for Conductor."""
diff --git a/conductor/providers/openrouter.py b/conductor/providers/openrouter.py
new file mode 100644
index 000000000..76751d3e8
--- /dev/null
+++ b/conductor/providers/openrouter.py
@@ -0,0 +1,309 @@
+"""OpenRouter API client for Conductor.
+
+Handles streaming chat completions, retries, and usage tracking.
+"""
+
+import json
+import os
+import threading
+import time
+from collections import defaultdict
+from collections.abc import Callable
+from dataclasses import dataclass
+from types import TracebackType
+from typing import Any
+
+import requests
+
+API_KEY: str = os.environ.get("OPENROUTER_API_KEY", "")
+BASE_URL: str = "https://openrouter.ai/api/v1"
+DEFAULT_MODEL: str = "deepseek/deepseek-v3.2"
+
+_REQUEST_TIMEOUT = 120
+_MAX_RETRIES = 3
+_RETRY_BACKOFF = 2.0
+
+Message = dict[str, Any]
+
+
+# --- Monotonic API usage counters (thread-safe, per-model). ---
+
+
+@dataclass
+class Counters:
+    """API usage snapshot."""
+
+    tokens: int = 0
+    input_tokens: int = 0
+    output_tokens: int = 0
+    cost: float = 0.0
+
+
+_counters_lock = threading.Lock()
+_counters: defaultdict[str, Counters] = defaultdict(Counters)
+
+
+def _record_usage(model: str, usage: dict[str, Any]) -> None:
+    """Accumulate token and cost from an API response."""
+    tokens = int(usage.get("total_tokens", 0))
+    input_tokens = int(usage.get("prompt_tokens", 0))
+    output_tokens = int(usage.get("completion_tokens", 0))
+    cost = usage.get("cost")
+    with _counters_lock:
+        c = _counters[model]
+        c.tokens += tokens
+        c.input_tokens += input_tokens
+        c.output_tokens += output_tokens
+        if cost is not None:
+            c.cost += float(cost)
+
+
+class Stats:
+    """Context manager that captures API usage over a scope.
+
+    Snapshots the monotonic per-model counters on entry. The ``counters``
+    property returns an aggregate delta; ``per_model`` returns per-model deltas.
+ """ + + def __init__(self) -> None: + self._start: dict[str, Counters] = {} + + def __enter__(self) -> "Stats": + with _counters_lock: + self._start = { + m: Counters(c.tokens, c.input_tokens, c.output_tokens, c.cost) + for m, c in _counters.items() + } + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + pass + + @property + def counters(self) -> Counters: + """Aggregate delta across all models since entering the context.""" + with _counters_lock: + tok = sum(c.tokens for c in _counters.values()) - sum( + c.tokens for c in self._start.values() + ) + inp = sum(c.input_tokens for c in _counters.values()) - sum( + c.input_tokens for c in self._start.values() + ) + out = sum(c.output_tokens for c in _counters.values()) - sum( + c.output_tokens for c in self._start.values() + ) + cst = sum(c.cost for c in _counters.values()) - sum( + c.cost for c in self._start.values() + ) + return Counters(tokens=tok, input_tokens=inp, output_tokens=out, cost=cst) + + @property + def per_model(self) -> defaultdict[str, Counters]: + """Per-model deltas since entering the context.""" + with _counters_lock: + result: defaultdict[str, Counters] = defaultdict(Counters) + for model, current in _counters.items(): + start = self._start.get(model, Counters()) + delta = Counters( + tokens=current.tokens - start.tokens, + input_tokens=current.input_tokens - start.input_tokens, + output_tokens=current.output_tokens - start.output_tokens, + cost=current.cost - start.cost, + ) + if delta.tokens or delta.cost: + result[model] = delta + return result + + +_TRANSIENT_ERRORS = ( + requests.exceptions.ChunkedEncodingError, + requests.exceptions.ConnectionError, + requests.exceptions.ReadTimeout, + requests.exceptions.ConnectTimeout, +) + + +def _with_retry( + func: Callable[..., Any], + log: Callable[[str], None], + *args: Any, + **kwargs: Any, +) -> Any: + """Call *func* with retries on 
transient network errors and 5xx.""" + for attempt in range(_MAX_RETRIES): + try: + return func(*args, **kwargs) + except requests.HTTPError as exc: + resp = exc.response + if resp is not None and resp.status_code >= 500: + if attempt < _MAX_RETRIES - 1: + wait = _RETRY_BACKOFF * (attempt + 1) + log( + f"\n [retry] server {resp.status_code}, waiting {wait:.0f}s...\n" + ) + time.sleep(wait) + continue + raise + except _TRANSIENT_ERRORS: + if attempt < _MAX_RETRIES - 1: + wait = _RETRY_BACKOFF * (attempt + 1) + log(f"\n [retry] connection error, waiting {wait:.0f}s...\n") + time.sleep(wait) + else: + raise + raise RuntimeError("Unreachable") + + +def _stream_request( + payload: dict[str, Any], + on_token: Callable[[str], None] | None, + on_thinking: Callable[[str], None] | None, +) -> Message: + """Execute a streaming chat request and assemble the response message.""" + resp = requests.post( + f"{BASE_URL}/chat/completions", + headers={"Authorization": f"Bearer {API_KEY}"}, + json=payload, + stream=True, + timeout=_REQUEST_TIMEOUT, + ) + if resp.status_code >= 400: + raise requests.HTTPError( + f"{resp.status_code} {resp.reason}: {resp.text}", + response=resp, + ) + + content_chunks: list[str] = [] + reasoning_chunks: list[str] = [] + tool_calls_by_index: dict[int, dict[str, Any]] = {} + usage: dict[str, Any] | None = None + + for line in resp.iter_lines(): + if not line or not line.startswith(b"data: "): + continue + data = line[6:] + if data == b"[DONE]": + break + chunk = json.loads(data) + if "usage" in chunk: + usage = chunk["usage"] + choices = chunk.get("choices") + if not choices: + continue + delta = choices[0].get("delta", {}) + + # Reasoning tokens (two OpenRouter formats). 
+ reasoning_texts: list[str] = [] + for detail in delta.get("reasoning_details", []): + text = detail.get("text", "") + if text: + reasoning_texts.append(text) + rc = delta.get("reasoning_content", "") + if rc: + reasoning_texts.append(rc) + for text in reasoning_texts: + if on_thinking is not None: + on_thinking(text) + reasoning_chunks.append(text) + + # Content tokens. + token = delta.get("content", "") + if token: + if on_token is not None: + on_token(token) + content_chunks.append(token) + + # Tool call deltas. + for tc_delta in delta.get("tool_calls", []): + idx = tc_delta["index"] + if idx not in tool_calls_by_index: + tool_calls_by_index[idx] = { + "id": tc_delta.get("id", ""), + "type": "function", + "function": {"name": "", "arguments": ""}, + } + tc = tool_calls_by_index[idx] + func_delta = tc_delta.get("function", {}) + if func_delta.get("name"): + tc["function"]["name"] += func_delta["name"] + if func_delta.get("arguments"): + tc["function"]["arguments"] += func_delta["arguments"] + + result: Message = {"role": "assistant", "content": "".join(content_chunks)} + if reasoning_chunks: + result["reasoning"] = "".join(reasoning_chunks) + if tool_calls_by_index: + result["tool_calls"] = [ + tool_calls_by_index[i] for i in sorted(tool_calls_by_index) + ] + if usage is not None: + result["usage"] = usage + return result + + +def chat( + messages: list[Message], + model: str, + temperature: float = 0.7, + max_tokens: int = 2048, + reasoning_effort: str | None = None, + tools: list[dict] | None = None, + log: Callable[[str], None] | None = None, +) -> Message: + """Send a chat completion request. Returns the full response message dict. + + Handles payload construction, retries on transient errors, usage + recording, and streaming log output. + """ + if not API_KEY: + raise RuntimeError( + "OPENROUTER_API_KEY not set. Export it before running the LLM loop." 
+ ) + + if log is None: + log = lambda _: None + + payload: dict[str, Any] = { + "model": model, + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, + "stream": True, + "stream_options": {"include_usage": True}, + } + if reasoning_effort is not None: + payload["reasoning"] = {"enabled": True, "effort": reasoning_effort} + if tools: + payload["tools"] = [dict(t) for t in tools] + + # Streaming callbacks that manage [thinking]/[/thinking] delimiters. + in_reasoning = False + + def on_thinking(text: str) -> None: + nonlocal in_reasoning + if not in_reasoning: + log("\n [thinking] ") + in_reasoning = True + log(text) + + def on_token(text: str) -> None: + nonlocal in_reasoning + if in_reasoning: + log("\n [/thinking]\n") + in_reasoning = False + + result: Message = _with_retry(_stream_request, log, payload, on_token, on_thinking) + if in_reasoning: + log("\n [/thinking]\n") + + if "usage" in result: + _record_usage(model, result["usage"]) + pt = result["usage"].get("prompt_tokens", "?") + ct = result["usage"].get("completion_tokens", "?") + log(f"\n [tokens] prompt={pt} completion={ct}\n") + return result From 60eb6e7b77969775ff398f706be638a4f9718b83 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 20:30:52 +0100 Subject: [PATCH 45/49] cursor agent as provider Signed-off-by: Ivan Butygin --- conductor/conductor.py | 17 ++- conductor/llm.py | 74 ++++++---- conductor/providers/cursor_agent.py | 211 ++++++++++++++++++++++++++++ 3 files changed, 274 insertions(+), 28 deletions(-) create mode 100644 conductor/providers/cursor_agent.py diff --git a/conductor/conductor.py b/conductor/conductor.py index 35ac73219..fd190275e 100644 --- a/conductor/conductor.py +++ b/conductor/conductor.py @@ -223,6 +223,13 @@ def main(): default="low", help="Reasoning effort for models that support it (default: low).", ) + parser.add_argument( + "--provider", + type=str, + default="openrouter", + choices=["openrouter", "cursor"], + help="LLM 
provider (default: openrouter).", + ) parser.add_argument( "--kernel", type=str, @@ -266,16 +273,18 @@ def main(): if args.llm: from conductor.llm import run_scheduling_loop - from conductor.providers.openrouter import DEFAULT_MODEL - model = args.model or DEFAULT_MODEL - print(f"Running LLM scheduling loop (model={model})...", file=sys.stderr) + print( + f"Running LLM scheduling loop (provider={args.provider})...", + file=sys.stderr, + ) result = run_scheduling_loop( conductor, max_rounds=args.max_rounds, - model=model, + model=args.model, temperature=args.temperature, reasoning_effort=args.reasoning_effort, + provider=args.provider, ) print("\n=== LLM Scheduling Result ===") print(f" rounds: {result['rounds']}") diff --git a/conductor/llm.py b/conductor/llm.py index 8d524c488..cc649eb2d 100644 --- a/conductor/llm.py +++ b/conductor/llm.py @@ -4,12 +4,14 @@ import json import sys from collections.abc import Callable +from contextlib import nullcontext from conductor.providers.openrouter import ( DEFAULT_MODEL, + Counters, Message, Stats, - chat, + chat as openrouter_chat, ) from conductor.tools import Param, ToolRegistry @@ -99,12 +101,24 @@ def format_initial_prompt( return "\n".join(parts) +_NUDGE_NATIVE = ( + "You must use the evaluate_moves tool to test " + "scheduling ideas, or call done when finished. " + "Do not write tool calls as text." +) +_NUDGE_TEXT = ( + "You must output a ```json block to evaluate moves or call done. " + "Use the format described in the OUTPUT FORMAT section." +) + + def run_scheduling_loop( conductor, max_rounds: int = 10, - model: str = DEFAULT_MODEL, + model: str | None = None, temperature: float = 0.7, reasoning_effort: str | None = "high", + provider: str = "openrouter", log: Callable[[str], None] = _default_log, ) -> dict: """ @@ -114,8 +128,27 @@ def run_scheduling_loop( scheduling ideas. Conversation history (including reasoning) is preserved across rounds. + Args: + provider: "openrouter" (default) or "cursor". 
+ Returns dict with keys: metrics, commands, rounds, baseline_metrics, usage. """ + # Provider setup. + if provider == "cursor": + from conductor.providers import cursor_agent + + chat_fn = cursor_agent.chat + cursor_agent.reset() + model = model or cursor_agent.DEFAULT_MODEL + system_prompt = SYSTEM_PROMPT + cursor_agent.TOOL_CALL_FORMAT + nudge_msg = _NUDGE_TEXT + else: + chat_fn = openrouter_chat + model = model or DEFAULT_MODEL + system_prompt = SYSTEM_PROMPT + nudge_msg = _NUDGE_NATIVE + + log(f"Provider: {provider}, model: {model}\n") log("Computing baseline metrics...\n") baseline = conductor.baseline() log(f" baseline: {baseline}\n") @@ -210,7 +243,7 @@ def _done(summary: str) -> str: ) messages: list[Message] = [ - {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "system", "content": system_prompt}, { "role": "user", "content": format_initial_prompt( @@ -219,16 +252,19 @@ def _done(summary: str) -> str: }, ] - with Stats() as stats: + use_native_tools = provider != "cursor" + stats_ctx = Stats() if use_native_tools else nullcontext(None) + + with stats_ctx as stats: for round_num in range(1, max_rounds + 1): log(f"\n--- Round {round_num}/{max_rounds} ---\n") - response = chat( + response = chat_fn( messages, model=model, temperature=temperature, reasoning_effort=reasoning_effort, - tools=registry.definitions(), + tools=registry.definitions() if use_native_tools else None, log=log, ) messages.append(response) @@ -239,19 +275,8 @@ def _done(summary: str) -> str: tool_calls = response.get("tool_calls") if not tool_calls: - # Model should always call a tool (evaluate_moves or done). - # If it didn't, nudge it to use the proper tool interface. log(" [retry] No tool call, nudging model...\n") - messages.append( - { - "role": "user", - "content": ( - "You must use the evaluate_moves tool to test " - "scheduling ideas, or call done when finished. " - "Do not write tool calls as text." 
- ), - } - ) + messages.append({"role": "user", "content": nudge_msg}) continue for tc in tool_calls: @@ -270,12 +295,13 @@ def _done(summary: str) -> str: if finished: break - usage = stats.counters - log( - f"\n=== Usage ===\n" - f" tokens: {usage.tokens} (in={usage.input_tokens} out={usage.output_tokens})\n" - f" cost: ${usage.cost:.4f}\n" - ) + usage: Counters | None = stats.counters if stats else None + if usage: + log( + f"\n=== Usage ===\n" + f" tokens: {usage.tokens} (in={usage.input_tokens} out={usage.output_tokens})\n" + f" cost: ${usage.cost:.4f}\n" + ) ir_diff = _context_diff(initial_ir, current_ir) if ir_diff: diff --git a/conductor/providers/cursor_agent.py b/conductor/providers/cursor_agent.py new file mode 100644 index 000000000..9caf067b7 --- /dev/null +++ b/conductor/providers/cursor_agent.py @@ -0,0 +1,211 @@ +"""Cursor Agent CLI provider for Conductor. + +Uses ``cursor-agent --print`` as a local LLM backend. Multi-turn +conversations are maintained via ``--resume SESSION_ID``. + +The model does not receive native tool schemas. Instead, the system +prompt instructs it to output fenced JSON blocks which we parse into +synthetic ``tool_calls`` compatible with the OpenRouter provider +interface. +""" + +import json +import re +import shutil +import subprocess +from collections.abc import Callable +from typing import Any + +Message = dict[str, Any] +DEFAULT_MODEL = "sonnet-4.6" + +# Appended to the system prompt so the model outputs parseable JSON +# instead of relying on native function-calling. +TOOL_CALL_FORMAT = """ + +OUTPUT FORMAT — you do NOT have native function-calling tools. 
+Instead, to invoke a tool output EXACTLY ONE fenced JSON block per response: + +To evaluate moves: +```json +{"action": "evaluate_moves", "moves": ["move TAG_A after TAG_B"], "summary": "brief reason"} +``` + +To finish: +```json +{"action": "done", "summary": "what you tried and the outcome"} +``` + +Rules: +- One JSON block per response, then STOP and wait for the result. +- Do NOT use Shell, Read, Write, Grep, or any other built-in tools. +- Do NOT output additional commentary after the JSON block. +""" + + +def _find_binary() -> str: + """Locate the cursor-agent binary.""" + path = shutil.which("cursor-agent") + if path: + return path + raise FileNotFoundError( + "cursor-agent not found in PATH. " + "Install: curl https://cursor.com/install -fsS | bash" + ) + + +def _build_text(messages: list[Message], start: int) -> str: + """Format messages[start:] as plain text for cursor-agent. + + Assistant messages are skipped (already in the session history). + Tool results are prefixed with ``[Tool Result]``. + """ + parts: list[str] = [] + for msg in messages[start:]: + role = msg["role"] + content = msg.get("content", "") + if role == "assistant": + continue + if role == "tool": + parts.append(f"[Tool Result]\n{content}") + else: + parts.append(content) + return "\n\n".join(parts) + + +def _parse_tool_calls(text: str) -> list[dict] | None: + """Extract tool calls from fenced JSON blocks in model output.""" + # Fenced ```json ... ``` blocks. + pattern = r"```(?:json)?\s*\n(\{[^`]*?\"action\"[^`]*?\})\s*\n```" + matches = re.findall(pattern, text, re.DOTALL) + if not matches: + # Bare JSON on its own line. 
+ matches = re.findall(r'^(\{"action"\s*:.+\})$', text, re.MULTILINE) + tool_calls: list[dict] = [] + for i, raw in enumerate(matches): + try: + obj = json.loads(raw) + except json.JSONDecodeError: + continue + action = obj.pop("action", None) + if not action: + continue + tool_calls.append( + { + "id": f"cursor_{i}", + "type": "function", + "function": { + "name": action, + "arguments": json.dumps(obj), + }, + } + ) + return tool_calls or None + + +def _run( + prompt: str, + model: str, + session_id: str | None, + log: Callable[[str], None], +) -> tuple[str, str]: + """Invoke cursor-agent in headless mode. Returns (session_id, content).""" + cmd = [ + _find_binary(), + "--print", + "--output-format", + "stream-json", + "--model", + model, + "--force", + "--trust", + ] + if session_id: + cmd.extend(["--resume", session_id]) + cmd.append(prompt) + + log( + f" [cmd] cursor-agent --model {model} {'--resume ' + session_id if session_id else '--new'}\n" + ) + proc = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + + sid = session_id or "" + content = "" + for line in proc.stdout.strip().split("\n"): + line = line.strip() + if not line: + continue + try: + ev = json.loads(line) + except json.JSONDecodeError: + continue + etype = ev.get("type") + if etype == "system" and ev.get("subtype") == "init": + sid = ev.get("session_id", sid) + elif etype == "assistant": + # Last assistant message wins. 
+ for block in ev.get("message", {}).get("content", []): + if isinstance(block, dict): + content = block.get("text", content) + elif isinstance(block, str): + content = block + elif etype == "result": + if event_is_error(ev): + log(f" [cursor-agent error] {ev.get('result', '')}\n") + if not content: + content = ev.get("result", "") + return sid, content + + +def event_is_error(ev: dict) -> bool: + """Check if a stream-json event signals an error.""" + return ev.get("is_error", False) or ev.get("subtype") == "error" + + +# ---- Module-level session state ---- + +_session_id: str | None = None +_sent_count: int = 0 + + +def reset() -> None: + """Reset session state for a new scheduling loop.""" + global _session_id, _sent_count + _session_id = None + _sent_count = 0 + + +def chat( + messages: list[Message], + model: str, + temperature: float = 0.7, + max_tokens: int = 2048, + reasoning_effort: str | None = None, + tools: list[dict] | None = None, + log: Callable[[str], None] | None = None, +) -> Message: + """Send a chat turn via cursor-agent CLI. + + Maintains session state across calls. Tool calls are extracted from + the model's text output (fenced JSON blocks). + """ + global _session_id, _sent_count + if log is None: + log = lambda _: None + + prompt = _build_text(messages, _sent_count) + _sent_count = len(messages) + + sid, content = _run(prompt, model, _session_id, log) + _session_id = sid + + if content: + preview = content[:300] + "..." 
if len(content) > 300 else content + log(f" [response] {preview}\n") + + tool_calls = _parse_tool_calls(content) + result: Message = {"role": "assistant", "content": content} + if tool_calls: + result["tool_calls"] = tool_calls + + return result From 434c3df0981036236bbf6acf8d75f61d5a7c7b47 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 20:42:03 +0100 Subject: [PATCH 46/49] sessions refac Signed-off-by: Ivan Butygin --- conductor/llm.py | 4 +++- conductor/providers/cursor_agent.py | 34 ++++++++++++++--------------- conductor/providers/openrouter.py | 1 + 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/conductor/llm.py b/conductor/llm.py index cc649eb2d..157b1ab9e 100644 --- a/conductor/llm.py +++ b/conductor/llm.py @@ -138,7 +138,6 @@ def run_scheduling_loop( from conductor.providers import cursor_agent chat_fn = cursor_agent.chat - cursor_agent.reset() model = model or cursor_agent.DEFAULT_MODEL system_prompt = SYSTEM_PROMPT + cursor_agent.TOOL_CALL_FORMAT nudge_msg = _NUDGE_TEXT @@ -254,6 +253,7 @@ def _done(summary: str) -> str: use_native_tools = provider != "cursor" stats_ctx = Stats() if use_native_tools else nullcontext(None) + session = None # Opaque handle with stats_ctx as stats: for round_num in range(1, max_rounds + 1): @@ -266,7 +266,9 @@ def _done(summary: str) -> str: reasoning_effort=reasoning_effort, tools=registry.definitions() if use_native_tools else None, log=log, + session=session, ) + session = response.pop("session", session) messages.append(response) content = response.get("content", "") diff --git a/conductor/providers/cursor_agent.py b/conductor/providers/cursor_agent.py index 9caf067b7..137abcfa6 100644 --- a/conductor/providers/cursor_agent.py +++ b/conductor/providers/cursor_agent.py @@ -14,6 +14,7 @@ import shutil import subprocess from collections.abc import Callable +from dataclasses import dataclass from typing import Any Message = dict[str, Any] @@ -162,17 +163,12 @@ def event_is_error(ev: dict) 
-> bool: return ev.get("is_error", False) or ev.get("subtype") == "error" -# ---- Module-level session state ---- +@dataclass +class Session: + """Opaque session handle passed between caller and provider.""" -_session_id: str | None = None -_sent_count: int = 0 - - -def reset() -> None: - """Reset session state for a new scheduling loop.""" - global _session_id, _sent_count - _session_id = None - _sent_count = 0 + id: str | None = None + sent: int = 0 def chat( @@ -183,21 +179,24 @@ def chat( reasoning_effort: str | None = None, tools: list[dict] | None = None, log: Callable[[str], None] | None = None, + session: Session | None = None, ) -> Message: """Send a chat turn via cursor-agent CLI. - Maintains session state across calls. Tool calls are extracted from - the model's text output (fenced JSON blocks). + The caller owns the ``Session`` object and passes it back on each + call. Tool calls are extracted from fenced JSON blocks in the + model's text output. """ - global _session_id, _sent_count if log is None: log = lambda _: None + if session is None: + session = Session() - prompt = _build_text(messages, _sent_count) - _sent_count = len(messages) + prompt = _build_text(messages, session.sent) - sid, content = _run(prompt, model, _session_id, log) - _session_id = sid + sid, content = _run(prompt, model, session.id, log) + session.id = sid + session.sent = len(messages) if content: preview = content[:300] + "..." 
if len(content) > 300 else content @@ -207,5 +206,6 @@ def chat( result: Message = {"role": "assistant", "content": content} if tool_calls: result["tool_calls"] = tool_calls + result["session"] = session return result diff --git a/conductor/providers/openrouter.py b/conductor/providers/openrouter.py index 76751d3e8..b43a703d3 100644 --- a/conductor/providers/openrouter.py +++ b/conductor/providers/openrouter.py @@ -254,6 +254,7 @@ def chat( reasoning_effort: str | None = None, tools: list[dict] | None = None, log: Callable[[str], None] | None = None, + session: Any = None, ) -> Message: """Send a chat completion request. Returns the full response message dict. From 196ac79b94c8b72ab5cd0f7e807b705a75991769 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 22 Feb 2026 20:46:28 +0100 Subject: [PATCH 47/49] refac default model Signed-off-by: Ivan Butygin --- conductor/llm.py | 3 --- conductor/providers/cursor_agent.py | 3 ++- conductor/providers/openrouter.py | 3 ++- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/conductor/llm.py b/conductor/llm.py index 157b1ab9e..8501e34f4 100644 --- a/conductor/llm.py +++ b/conductor/llm.py @@ -7,7 +7,6 @@ from contextlib import nullcontext from conductor.providers.openrouter import ( - DEFAULT_MODEL, Counters, Message, Stats, @@ -138,12 +137,10 @@ def run_scheduling_loop( from conductor.providers import cursor_agent chat_fn = cursor_agent.chat - model = model or cursor_agent.DEFAULT_MODEL system_prompt = SYSTEM_PROMPT + cursor_agent.TOOL_CALL_FORMAT nudge_msg = _NUDGE_TEXT else: chat_fn = openrouter_chat - model = model or DEFAULT_MODEL system_prompt = SYSTEM_PROMPT nudge_msg = _NUDGE_NATIVE diff --git a/conductor/providers/cursor_agent.py b/conductor/providers/cursor_agent.py index 137abcfa6..ffb28b680 100644 --- a/conductor/providers/cursor_agent.py +++ b/conductor/providers/cursor_agent.py @@ -173,7 +173,7 @@ class Session: def chat( messages: list[Message], - model: str, + model: str | None = None, 
temperature: float = 0.7, max_tokens: int = 2048, reasoning_effort: str | None = None, @@ -187,6 +187,7 @@ def chat( call. Tool calls are extracted from fenced JSON blocks in the model's text output. """ + model = model or DEFAULT_MODEL if log is None: log = lambda _: None if session is None: diff --git a/conductor/providers/openrouter.py b/conductor/providers/openrouter.py index b43a703d3..62ae61bf8 100644 --- a/conductor/providers/openrouter.py +++ b/conductor/providers/openrouter.py @@ -248,7 +248,7 @@ def _stream_request( def chat( messages: list[Message], - model: str, + model: str | None = None, temperature: float = 0.7, max_tokens: int = 2048, reasoning_effort: str | None = None, @@ -261,6 +261,7 @@ def chat( Handles payload construction, retries on transient errors, usage recording, and streaming log output. """ + model = model or DEFAULT_MODEL if not API_KEY: raise RuntimeError( "OPENROUTER_API_KEY not set. Export it before running the LLM loop." From c736ade36e92ff41177b5ff650b9e01db15269fe Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 23 Feb 2026 16:57:16 +0100 Subject: [PATCH 48/49] conductor fix Signed-off-by: Ivan Butygin --- .../wave_asm/test/Transforms/apply-moves-after.mlir | 6 +++--- .../test/Transforms/apply-moves-error-dominance.mlir | 6 +++--- .../test/Transforms/apply-moves-error-pinned.mlir | 4 ++-- .../wave_asm/test/Transforms/apply-moves-swap.mlir | 8 ++++---- .../asm/wave_asm/test/Transforms/apply-moves.mlir | 6 +++--- .../tools/waveasm-conductor/waveasm-conductor.cpp | 12 +----------- 6 files changed, 16 insertions(+), 26 deletions(-) diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-after.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-after.mlir index 555dec184..7712080ee 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-after.mlir +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-after.mlir @@ -12,8 +12,8 @@ waveasm.program 
@test_move_after target = #waveasm.target<#waveasm.gfx942, 5> ab %c4 = waveasm.constant 4 : !waveasm.imm<4> %c1 = waveasm.constant 1 : !waveasm.imm<1> - %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg - %a1 = waveasm.v_add_u32 %v0, %c1 : !waveasm.pvreg<0>, !waveasm.imm<1> -> !waveasm.vreg + %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg loc("v_add_u32_0") + %a1 = waveasm.v_add_u32 %v0, %c1 : !waveasm.pvreg<0>, !waveasm.imm<1> -> !waveasm.vreg loc("v_add_u32_1") - waveasm.s_endpgm + waveasm.s_endpgm loc("s_endpgm_0") } diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-dominance.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-dominance.mlir index 7eaa9a751..da6d452bc 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-dominance.mlir +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-dominance.mlir @@ -11,8 +11,8 @@ waveasm.program @test_dominance target = #waveasm.target<#waveasm.gfx942, 5> abi %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> %c4 = waveasm.constant 4 : !waveasm.imm<4> - %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg - %a1 = waveasm.v_add_u32 %a0, %c4 : !waveasm.vreg, !waveasm.imm<4> -> !waveasm.vreg + %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg loc("v_add_u32_0") + %a1 = waveasm.v_add_u32 %a0, %c4 : !waveasm.vreg, !waveasm.imm<4> -> !waveasm.vreg loc("v_add_u32_1") - waveasm.s_endpgm + waveasm.s_endpgm loc("s_endpgm_0") } diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-pinned.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-pinned.mlir index d62fafb5d..d0620cb6e 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-pinned.mlir +++ 
b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-error-pinned.mlir @@ -8,6 +8,6 @@ waveasm.program @test_pinned target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> %c4 = waveasm.constant 4 : !waveasm.imm<4> - %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg - waveasm.s_endpgm + %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg loc("v_add_u32_0") + waveasm.s_endpgm loc("s_endpgm_0") } diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-swap.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-swap.mlir index f56ef1936..fb54ecd2f 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-swap.mlir +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves-swap.mlir @@ -12,9 +12,9 @@ waveasm.program @test_swap target = #waveasm.target<#waveasm.gfx942, 5> abi = #w %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> %c4 = waveasm.constant 4 : !waveasm.imm<4> - %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg - %a1 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg - %s0 = waveasm.v_lshlrev_b32 %c4, %v0 : !waveasm.imm<4>, !waveasm.pvreg<0> -> !waveasm.vreg + %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg loc("v_add_u32_0") + %a1 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg loc("v_add_u32_1") + %s0 = waveasm.v_lshlrev_b32 %c4, %v0 : !waveasm.imm<4>, !waveasm.pvreg<0> -> !waveasm.vreg loc("v_lshlrev_b32_0") - waveasm.s_endpgm + waveasm.s_endpgm loc("s_endpgm_0") } diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves.mlir index a5818db1f..d38773df1 100644 --- 
a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves.mlir +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/apply-moves.mlir @@ -12,8 +12,8 @@ waveasm.program @test_move_before target = #waveasm.target<#waveasm.gfx942, 5> a %c4 = waveasm.constant 4 : !waveasm.imm<4> %c1 = waveasm.constant 1 : !waveasm.imm<1> - %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg - %a1 = waveasm.v_add_u32 %v0, %c1 : !waveasm.pvreg<0>, !waveasm.imm<1> -> !waveasm.vreg + %a0 = waveasm.v_add_u32 %v0, %c4 : !waveasm.pvreg<0>, !waveasm.imm<4> -> !waveasm.vreg loc("v_add_u32_0") + %a1 = waveasm.v_add_u32 %v0, %c1 : !waveasm.pvreg<0>, !waveasm.imm<1> -> !waveasm.vreg loc("v_add_u32_1") - waveasm.s_endpgm + waveasm.s_endpgm loc("s_endpgm_0") } diff --git a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp index 9c7da35a8..8c0d6ccf2 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp +++ b/wave_lang/kernel/wave/asm/wave_asm/tools/waveasm-conductor/waveasm-conductor.cpp @@ -13,7 +13,6 @@ #include "waveasm/Dialect/WaveASMDialect.h" #include "waveasm/Transforms/ApplyMoves.h" -#include "waveasm/Transforms/Passes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectRegistry.h" @@ -21,7 +20,6 @@ #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Verifier.h" #include "mlir/Parser/Parser.h" -#include "mlir/Pass/PassManager.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/MemoryBuffer.h" @@ -103,15 +101,7 @@ int main(int argc, char **argv) { return 1; } - // Run tag-instructions pass to attach NameLoc tags. - PassManager pm(&context); - pm.addPass(waveasm::createWAVEASMTagInstructionsPass()); - if (failed(pm.run(*module))) { - llvm::errs() << "Tag-instructions pass failed\n"; - return 1; - } - - // Apply the move commands. 
+ // Apply the move commands (IR is expected to have NameLoc tags already). waveasm::MoveResult result = waveasm::applyMoves(*module, parseResult.commands); if (!result.success) { From 748862c1188090e3e02adb56d7da374129901a66 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 23 Feb 2026 17:51:55 +0100 Subject: [PATCH 49/49] pass agent text through stdin Signed-off-by: Ivan Butygin --- conductor/providers/cursor_agent.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/conductor/providers/cursor_agent.py b/conductor/providers/cursor_agent.py index ffb28b680..0ce879eb9 100644 --- a/conductor/providers/cursor_agent.py +++ b/conductor/providers/cursor_agent.py @@ -123,12 +123,14 @@ def _run( ] if session_id: cmd.extend(["--resume", session_id]) - cmd.append(prompt) log( - f" [cmd] cursor-agent --model {model} {'--resume ' + session_id if session_id else '--new'}\n" + f" [cmd] cursor-agent --model {model} {'--resume ' + session_id if session_id else '--new'} ({len(prompt)} chars)\n" + ) + # Pipe prompt via stdin to avoid OS argument length limits. + proc = subprocess.run( + cmd, input=prompt, capture_output=True, text=True, timeout=300 ) - proc = subprocess.run(cmd, capture_output=True, text=True, timeout=300) sid = session_id or "" content = ""