Draft
49 commits
3cb0568 draft (Hardcode84, Feb 21, 2026)
6b4ee26 update (Hardcode84, Feb 21, 2026)
7f5e72f parallelism (Hardcode84, Feb 21, 2026)
7d143a9 warnings (Hardcode84, Feb 21, 2026)
36c0414 fix gitignore (Hardcode84, Feb 21, 2026)
5bbc360 tag instructions pass (Hardcode84, Feb 21, 2026)
b2f5f1f test (Hardcode84, Feb 21, 2026)
fd89258 conductor (Hardcode84, Feb 21, 2026)
537e3e9 update lit (Hardcode84, Feb 21, 2026)
e9f40d3 update parsing (Hardcode84, Feb 21, 2026)
6b72fa0 remove "done" (Hardcode84, Feb 21, 2026)
978c126 verify in conductor (Hardcode84, Feb 21, 2026)
a11743b conductor (Hardcode84, Feb 21, 2026)
39d38b5 conductor LLM integration (Hardcode84, Feb 21, 2026)
565d2f0 model and remove done (Hardcode84, Feb 21, 2026)
d4ef635 logging (Hardcode84, Feb 22, 2026)
7c5fd97 prompt (Hardcode84, Feb 22, 2026)
a420f18 tool use (Hardcode84, Feb 22, 2026)
8fbe77e usage (Hardcode84, Feb 22, 2026)
dd4e3b2 refac (Hardcode84, Feb 22, 2026)
a2ff94b _with_retry (Hardcode84, Feb 22, 2026)
3efee7d dome tool (Hardcode84, Feb 22, 2026)
7759d6b abstract tools (Hardcode84, Feb 22, 2026)
f49d4aa print the updated IR (Hardcode84, Feb 22, 2026)
6c6fa57 print ir (Hardcode84, Feb 22, 2026)
6871a7f prompt tweak (Hardcode84, Feb 22, 2026)
716966e error logging (Hardcode84, Feb 22, 2026)
7c2babd scheduling (Hardcode84, Feb 22, 2026)
b38b793 mxfp kernel (Hardcode84, Feb 22, 2026)
fa2f3c6 fix rocm check (Hardcode84, Feb 22, 2026)
ca3da4a instructions (Hardcode84, Feb 22, 2026)
c802434 show diff (Hardcode84, Feb 22, 2026)
5994e92 move doc (Hardcode84, Feb 22, 2026)
5f96075 beter kernel (Hardcode84, Feb 22, 2026)
7c935de less logs (Hardcode84, Feb 22, 2026)
f0febd7 clang-format (Hardcode84, Feb 22, 2026)
8885ad0 prompt and less code (Hardcode84, Feb 22, 2026)
a49ab7a actually change default efforts (Hardcode84, Feb 22, 2026)
df7e2f0 proper dominance report (Hardcode84, Feb 22, 2026)
5db5160 stateful (Hardcode84, Feb 22, 2026)
02f0406 ssa names (Hardcode84, Feb 22, 2026)
ec9e39d print diff (Hardcode84, Feb 22, 2026)
7b0c013 eveluate_moved summary (Hardcode84, Feb 22, 2026)
4088c7c move openrouter to providers (Hardcode84, Feb 22, 2026)
60eb6e7 cursor agent as provider (Hardcode84, Feb 22, 2026)
434c3df sessions refac (Hardcode84, Feb 22, 2026)
196ac79 refac default model (Hardcode84, Feb 22, 2026)
c736ade conductor fix (Hardcode84, Feb 23, 2026)
748862c pass agent text through stdin (Hardcode84, Feb 23, 2026)
421 changes: 421 additions & 0 deletions conductor/CONDUCTOR_DESIGN.md

Large diffs are not rendered by default.

117 changes: 117 additions & 0 deletions conductor/README.md
@@ -0,0 +1,117 @@
# Conductor

LLM-guided instruction scheduling for WaveASM. Takes pre-scheduling WaveASM IR,
lets an LLM reorder instructions via move/swap commands, then compiles and
measures the result. See [CONDUCTOR_DESIGN.md](CONDUCTOR_DESIGN.md) for the
full design rationale.

## Pipeline

```
Python frontend (GEMM kernel)
→ Wave compiler → MLIR
→ waveasm-translate (pre-scheduling: CSE, peephole, mem-offset-opt)
→ Tag instructions (attach loc("tag_name") to each op)
→ LLM loop: present IR + metrics → get move commands → apply → compile → measure
→ Final assembly
```

## Modules

| File | Purpose |
|------|---------|
| `extract_ir.py` | Captures MLIR from a Wave kernel, runs pre-scheduling pipeline, computes baseline metrics. |
| `conductor.py` | `Conductor` class: tag, apply_moves, compile_to_asm, evaluate, baseline. CLI entry point. |
| `llm.py` | OpenRouter client, prompt formatting, command parsing, iterative scheduling loop. |

## Prerequisites

- `waveasm-translate` and `waveasm-conductor` binaries (build from `wave_lang/kernel/wave/asm/wave_asm/`).
- Python `requests` package.
- `OPENROUTER_API_KEY` env var for LLM mode.

## Usage

All commands should be run from the wave project root with `WAVE_CACHE_ON=0`.

### Baseline metrics (no LLM)

```bash
python -m conductor.conductor --metrics
```

### Show tagged IR

```bash
python -m conductor.conductor --tag-only
```

### Apply manual moves

```bash
python -m conductor.conductor \
--moves "swap buffer_load_dwordx4_0 v_mfma_f32_16x16x16_f16_0" \
--metrics
```

### Read moves from a file

```bash
python -m conductor.conductor --moves-file moves.txt --metrics
```

### LLM-guided scheduling

```bash
OPENROUTER_API_KEY=sk-... python -m conductor.conductor \
--llm --max-rounds 5 --model google/gemini-2.5-flash-preview --metrics
```

### Extract pre-scheduling IR only

```bash
python -m conductor.extract_ir -o /tmp/conductor_ir.mlir
python -m conductor.extract_ir --metrics
```

## Move commands

| Command | Example |
|---------|---------|
| `move TAG before TAG` | `move v_add_u32_1 before v_add_u32_0` |
| `move TAG after TAG` | `move buffer_load_dwordx4_0 after ds_read_b128_2` |
| `swap TAG TAG` | `swap v_mfma_f32_16x16x16_f16_0 ds_read_b128_1` |
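
The command grammar above is regular, so parsing can be sketched in a few
lines. This is a hypothetical helper for illustration; the real command
parsing lives in `llm.py` and may differ:

```python
import re

# Hypothetical regexes matching the three command forms in the table above.
MOVE_RE = re.compile(r"^move\s+(\S+)\s+(before|after)\s+(\S+)$")
SWAP_RE = re.compile(r"^swap\s+(\S+)\s+(\S+)$")


def parse_command(line: str) -> tuple:
    """Parse one move/swap command into a tuple, or raise ValueError."""
    line = line.strip()
    if m := MOVE_RE.match(line):
        return ("move", m.group(1), m.group(2), m.group(3))
    if m := SWAP_RE.match(line):
        return ("swap", m.group(1), m.group(2))
    raise ValueError(f"unrecognized command: {line!r}")


assert parse_command("swap a_0 b_1") == ("swap", "a_0", "b_1")
```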

## Environment variables

| Variable | Default | Purpose |
|----------|---------|---------|
| `OPENROUTER_API_KEY` | (none) | Required for `--llm` mode. |
| `WAVE_CACHE_ON` | `1` | Set to `0` to disable caching during development. |
| `WAVE_DEFAULT_ARCH` | `gfx942` | Target GPU architecture. |
| `WAVEASM_TRANSLATE` | (auto-detect) | Override path to `waveasm-translate` binary. |
| `WAVEASM_CONDUCTOR` | (auto-detect) | Override path to `waveasm-conductor` binary. |
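
The override-then-auto-detect pattern for the two binary paths might look
like the sketch below. This is an assumption about the lookup order, not the
actual implementation in `extract_ir.py`/`conductor.py`:

```python
import os
import shutil


def find_binary(env_var: str, name: str) -> str:
    """Resolve a tool path: explicit env-var override first, then PATH."""
    override = os.environ.get(env_var)
    if override:
        return override
    found = shutil.which(name)
    if found is None:
        raise FileNotFoundError(f"{name} not on PATH; set {env_var} to override")
    return found


# e.g. find_binary("WAVEASM_TRANSLATE", "waveasm-translate")
```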

## Programmatic usage

```python
from conductor import Conductor
from conductor.extract_ir import capture_kernel_mlir, run_pre_scheduling_pipeline
from conductor.llm import run_scheduling_loop

mlir_text, wg_size = capture_kernel_mlir()
waveasm_ir = run_pre_scheduling_pipeline(mlir_text, wg_size)

c = Conductor(waveasm_ir, wg_size)

# Baseline.
print(c.baseline())

# Manual moves.
print(c.evaluate(["swap buffer_load_dwordx4_0 v_mfma_f32_16x16x16_f16_0"]))

# LLM loop.
result = run_scheduling_loop(c, max_rounds=5, model="google/gemini-2.5-flash-preview")
print(result["metrics"], result["commands"])
```
182 changes: 182 additions & 0 deletions conductor/SCHEDULING_GUIDE.md
@@ -0,0 +1,182 @@
# CDNA4 Instruction Scheduling Guide

Reference for the Conductor LLM scheduling loop. Based on the AMD Instinct
CDNA4 ISA Reference Guide (gfx950).

## Architecture Overview

- Wave size: 64 threads.
- Register files: 256 arch VGPRs (V0-V255) + 256 AccVGPRs/AGPRs (A0-A255) = 512 total per wave.
- SGPRs: 16-102 per wave (VCC occupies SGPR 106-107).
- LDS: 160 KiB per CU, 64 banks x 32-bit, 1280-byte allocation granularity.
- Allocation granularity: VGPRs in groups of 8, SGPRs in groups of 16.
- Issue model: one instruction per cycle per wavefront (latency hiding via interleaving wavefronts).

## Instruction Latencies

| Instruction class | Latency (cycles) | Counter |
|---|---|---|
| Global load (buffer_load, global_load) | ~100 | vmcnt |
| LDS read (ds_read) | ~20 | lgkmcnt |
| LDS write (ds_write) | ~20 | lgkmcnt |
| MFMA F16/BF16 16x16x16 | 16 | — |
| MFMA F16/BF16 32x32x8 | 32 | — |
| MFMA F32 16x16x4 | 32 | — |
| MFMA F32 32x32x2 | 64 | — |
| MFMA F8/F4 16x16x128 | 16 (FP4/FP6) or 32 (FP8) | — |
| MFMA F8/F4 32x32x64 | 32 (FP4/FP6) or 64 (FP8) | — |
| scaled_mfma (MXFP4) | Same as F8/F4 above | — |
| MFMA F64 16x16x4 | 64 | — |
| VALU (non-transcendental) | 1 | — |
| VALU transcendental (exp, log, rcp, rsq, sqrt, sin, cos) | 2 | — |
| SALU | 1 | — |
| Scalar memory read | variable | lgkmcnt |
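
These numbers translate directly into a back-of-the-envelope count of how
many independent ops are needed to cover a load. A rough sketch using the
approximate latencies from the table above (names hypothetical):

```python
# Approximate latencies (cycles) taken from the table above.
LATENCY = {"global_load": 100, "ds_read": 20, "mfma_16x16x16_f16": 16}


def ops_to_hide(load: str, cover: str) -> int:
    """Independent `cover` ops needed to span `load`'s latency (ceil div)."""
    return -(-LATENCY[load] // LATENCY[cover])


assert ops_to_hide("global_load", "mfma_16x16x16_f16") == 7  # ~7 MFMAs per load
assert ops_to_hide("global_load", "ds_read") == 5
```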

## Waitcnt Counters

| Counter | Bits | Max outstanding | Tracked operations |
|---|---|---|---|
| vmcnt | 6 | 63 | Global/buffer loads and stores |
| lgkmcnt | 4 | 15 | LDS ops, scalar memory, sendmsg |

- VMEM reads/writes return in issue order.
- Scalar memory reads can return **out of order** — only `lgkmcnt(0)` is safe for SMEM.
- FLAT instructions increment both vmcnt and lgkmcnt — only `s_waitcnt 0` is safe after FLAT.
- `s_endpgm` implicitly executes `s_waitcnt 0`.
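
The vmcnt behavior can be modeled as an in-order queue of outstanding loads:
issuing a load increments the counter, and `s_waitcnt vmcnt(n)` blocks until
at most `n` loads remain outstanding. An illustrative toy model, not the real
hardware:

```python
def simulate_vmcnt(trace: list) -> list:
    """Toy vmcnt model. Trace items are "load" (issue a VMEM load) or
    ("wait", n) for s_waitcnt vmcnt(n). Returns the outstanding-load
    count after each event."""
    outstanding = 0
    history = []
    for event in trace:
        if event == "load":
            outstanding += 1
        else:
            _, n = event
            # VMEM loads return in issue order, so waiting drains the oldest.
            outstanding = min(outstanding, n)
        history.append(outstanding)
    return history


assert simulate_vmcnt(["load", "load", ("wait", 1), ("wait", 0)]) == [1, 2, 1, 0]
```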

## MFMA Dependency Rules

### Accumulator chains (SrcC = previous vDst, same opcode, same register range)

**Zero software NOPs required.** The hardware provides 2 implicit wait cycles.
This is the intended use pattern for matrix accumulation loops — chain MFMAs
back-to-back on the same accumulator with no stall.

### Cross-dependencies (reading MFMA output as SrcA/SrcB or in VALU)

These require the full output latency to elapse:

| Scenario | Wait (NOPs) |
|---|---|
| XDL write → SrcA/B of any MFMA | 5 / 8 / 12 / 20 |
| XDL write → VALU read/write (RAW+WAW) | 5 / 8 / 12 / 20 |
| XDL write → VMEM/LDS/FLAT overlap | 5 / 8 / 12 / 20 |
| SGEMM write → SrcA/B of any MFMA | 4 / 6 / 10 / 18 |
| DGEMM 16x16x4 write → SrcA/B or VALU | 19 |

The multiple values correspond to different output register overlap depths.

### Cross-type forwarding

| Scenario | Wait (NOPs) |
|---|---|
| SGEMM write → XDL SrcC (overlapping) | **0** (XDL reads SrcC 2x faster) |
| XDL write → SGEMM SrcC (overlapping) | 3 |
| v_cmpx (writes EXEC) → any V_MFMA | **4** (no exec forwarding to matrix core) |

## Software Hazards (Table 11)

The hardware does **not** detect these. The compiler must insert s_nop or
independent instructions.

| First instruction | Second instruction | NOPs required |
|---|---|---|
| VALU writes SGPR | VMEM reads that SGPR | **5** |
| VALU sets VCC or EXEC | VALU reads EXECZ/VCCZ as data | **5** |
| VALU writes EXEC | VALU DPP op | **5** |
| VALU writes SGPR/VCC | v_readlane/v_writelane lane select | **4** |
| VALU writes VCC | v_div_fmas | **4** |
| v_cmpx writes EXEC | v_readlane/v_readfirstlane/v_writelane | **4** |
| VALU writes SGPR/VCC | VALU reads SGPR as constant | **2** |
| v_cmpx writes EXEC | VALU reads EXEC as constant | **2** |
| S_SETREG MODE.vskip | Any vector op | **2** |
| Trans op (exp, log, rcp, ...) | Non-trans VALU consuming result | **1** |
| VALU writes VGPR | VALU DPP reads that VGPR | **1** |
| SALU writes M0 | LDS add-TID / buffer_store_LDS | **1** |
| Mixed VCC alias access | VALU reads VCC as constant | **1** |
| FLAT/BUFFER_STORE x3/x4 | Write VGPRs holding writedata | **1** |

Note: `s_nop N` inserts N+1 idle cycles. So `s_nop 4` = 5 NOPs.

## Occupancy and Register Pressure

Waves per SIMD as a function of VGPR usage (512-slot pool, 8-register granularity):

| VGPRs used | Waves/SIMD | Waves/CU (4 SIMDs) |
|---|---|---|
| ≤64 | 8 | 32 |
| 65-72 | 7 | 28 |
| 73-80 | 6 | 24 |
| 81-96 | 5 | 20 |
| 97-128 | 4 | 16 |
| 129-168 | 3 | 12 |
| 169-256 | 2 | 8 |

AGPRs use an independent pool with the same formula. Final occupancy is
`min(vgpr_waves, agpr_waves, sgpr_waves, lds_waves)`.

**Key breakpoints:** Dropping below 128, 96, 80, or 64 VGPRs each adds one
wave per SIMD. For MFMA-dominant kernels, 2-4 waves/SIMD is often optimal.

## Scheduling Strategy

### 1. Issue global loads early

Global loads have ~100 cycle latency. Move them as far above their consumers
as possible. Every MFMA (16-64 cycles) or LDS op placed between the load and
its `s_waitcnt vmcnt` hides latency for free.

### 2. Interleave LDS reads with MFMA

LDS reads take ~20 cycles. A single MFMA 16x16x16 (16 cycles) nearly covers
one LDS read. Interleaving `ds_read → mfma → ds_read → mfma` hides LDS
latency much better than `ds_read, ds_read, ..., mfma, mfma, ...`.

### 3. Fill MFMA pipeline bubbles

Between two MFMAs with an accumulator dependency (B.SrcC = A.vDst), the
hardware stalls until A completes. Fill this gap with:
- Independent LDS reads/writes (for the next iteration).
- Independent VALU (address computation, data formatting).
- Independent global loads (prefetch for next tile).

### 4. Minimize live ranges

Moving a producer closer to its consumer (or deferring a def until just
before use) reduces the number of simultaneously live VGPRs, potentially
lowering peak register pressure and enabling higher occupancy.
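
The pressure effect can be illustrated with a sweep over live intervals (a
standard liveness technique, not Conductor's actual register model):

```python
def peak_live(intervals: list) -> int:
    """Peak number of simultaneously live values, given (def, last_use)
    position pairs. Ends sort before starts at the same position, so a
    freed register can be reused by a def at that point."""
    events = []
    for start, end in intervals:
        events.append((start, 1))
        events.append((end, -1))
    events.sort(key=lambda e: (e[0], e[1]))
    live = peak = 0
    for _, delta in events:
        live += delta
        peak = max(peak, live)
    return peak


# Three overlapping values need three registers; serializing them needs one:
assert peak_live([(0, 10), (1, 10), (2, 10)]) == 3
assert peak_live([(0, 3), (3, 6), (6, 10)]) == 1
```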

### 5. Double-buffering pattern

The ideal software-pipelined loop body:
```
barrier
ds_read current tile from LDS
global_load next tile (prefetch)
mfma on current tile (hides global_load latency)
s_waitcnt vmcnt(0)
barrier
ds_write next tile to LDS
```

### 6. Metric priorities (lexicographic)

1. **peak_vgpr** — lower = higher occupancy.
2. **s_waitcnt** — fewer waits = better latency hiding.
3. **s_nop** — fewer NOPs = fewer pipeline hazards.
4. **total_instructions** — fewer = less instruction cache pressure.
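
Lexicographic comparison falls out naturally from Python tuple ordering. A
sketch assuming metrics arrive as a dict (the key names here are
hypothetical):

```python
METRIC_ORDER = ("peak_vgpr", "s_waitcnt", "s_nop", "total_instructions")


def metric_key(metrics: dict) -> tuple:
    # Tuples compare element-by-element, matching the priority list above.
    return tuple(metrics[name] for name in METRIC_ORDER)


baseline = {"peak_vgpr": 96, "s_waitcnt": 14, "s_nop": 6, "total_instructions": 420}
candidate = {"peak_vgpr": 96, "s_waitcnt": 12, "s_nop": 9, "total_instructions": 425}

# Fewer waits wins despite more s_nops and more instructions:
assert metric_key(candidate) < metric_key(baseline)
```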

## LDS Details

- 160 KiB per CU, 64 banks, each 32-bit wide.
- Best case latency (no bank conflicts): 2 cycles.
- Worst case (64-way bank conflict): 64 cycles.
- A wavefront (64 threads) is dispatched over 4 sub-cycles (16 threads each).
- DS_READ2/DS_WRITE2 (64-bit extended) double per-op bandwidth.
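
With 64 banks of 32 bits each, the word-to-bank mapping explains why
power-of-two strides conflict. A sketch assuming the usual
`(address / 4) mod 64` mapping:

```python
BANKS = 64  # each bank is 32 bits (4 bytes) wide


def lds_bank(byte_addr: int) -> int:
    """Bank index for a 32-bit LDS access at the given byte address."""
    return (byte_addr // 4) % BANKS


# Stride of 256 bytes (64 floats) maps every access to bank 0 — worst case:
assert len({lds_bank(i * 256) for i in range(16)}) == 1
# Unit float stride spreads accesses across distinct banks:
assert len({lds_bank(i * 4) for i in range(16)}) == 16
```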

## Barrier Semantics

- `s_barrier` synchronizes all wavefronts in a workgroup.
- It does NOT implicitly wait for memory — issue `s_waitcnt` before `s_barrier`
if protecting memory operations.
- Treat barriers as hard scheduling boundaries. Never move ops across barriers.
39 changes: 39 additions & 0 deletions conductor/__init__.py
@@ -0,0 +1,39 @@
# Copyright 2025 The Wave Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Conductor: LLM-guided instruction scheduling for WaveASM."""

from conductor.conductor import Conductor, find_waveasm_conductor
from conductor.extract_ir import (
find_waveasm_translate,
run_waveasm_translate,
run_pre_scheduling_pipeline,
run_full_pipeline,
count_asm_metrics,
capture_kernel_mlir,
capture_mxfp4_kernel_mlir,
)
from conductor.llm import run_scheduling_loop
from conductor.providers.openrouter import Stats, Counters
from conductor.tools import Param, ToolDef, ToolRegistry

__all__ = [
"Conductor",
"find_waveasm_conductor",
"find_waveasm_translate",
"run_waveasm_translate",
"run_pre_scheduling_pipeline",
"run_full_pipeline",
"count_asm_metrics",
"capture_kernel_mlir",
"capture_mxfp4_kernel_mlir",
"run_scheduling_loop",
"Stats",
"Counters",
"Param",
"ToolDef",
"ToolRegistry",
]