Conversation
|
Thanks for investigating the k-scale issue. I have a couple of queries before we start testing with a real example. First, do the files in … Secondly, could you please move the … Could you also add … I will also post a real example here that you can experiment with, to record benchmark data for the kernel with figures. cc: @florianmattana |
|
For the build method: I didn't use Nix. The reason is I wanted a setup that works with a standard Python venv + PyTorch, without requiring contributors to install Nix (or myself, since I noticed it was something like 30 GB?). Here's what I did: installed build2cmake (a Rust tool) via cargo install build2cmake. I chose this approach because it keeps the build pure PyTorch: anyone with CUDA and PyTorch installed can build the kernel in one command, without extra tooling. But if you'd prefer sticking with Nix or a different build system, I'm happy to adapt. What works best for you? Regarding test_simple.py, compat.py and metadata.json: did I push them to any branch unintentionally? Let me know if this sounds OK to you, or I'll adapt. |
Well, since we would be using the kernel with the … In my experience thus far, building with Nix is less than ideal since it requires setup with cachix, the Nix CLI and a set of other issues, so if we are able to publish the fixed kernel with torch, I think that can work as well. If you view the file diff for this PR, it should not be this large, so if you could either remove the files or organize them, depending on whether you consider them useful, that would be quite helpful. |
|
Here's an example script with flux-schnell that benchmarks the performance of the kernel against torch SDPA attention. I've also implemented a custom attention processor based on the Flux transformer; feel free to modify the script to improve the processor or add visual figures to the benchmarks:

import torch
import time
import numpy as np
from diffusers import DiffusionPipeline
from kernels import get_local_kernel
from pathlib import Path
# =====================================================
# CONFIG
# =====================================================
DEVICE = "cuda"
DTYPE = torch.bfloat16
PROMPT = "a futuristic city at sunset, ultra detailed"
NUM_STEPS = 8
RUNS = 4
WARMUP = 2
# =====================================================
# LOAD MODEL
# =====================================================
pipe = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=DTYPE
).to(DEVICE)
pipe.set_progress_bar_config(disable=True)
# =====================================================
# LOAD YOUR KERNEL
# =====================================================
repo_path = Path("./torch-ext/attention_int8")
kernel_mod = get_local_kernel(repo_path, "int8_attn")
kernel_fn = kernel_mod.int8_attention_forward
# =====================================================
# CUSTOM PROCESSOR
# =====================================================
class FluxInt8AttnProcessor:
def __init__(self, kernel_fn):
self.kernel_fn = kernel_fn
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
image_rotary_emb=None,
):
# ---- projections ----
from diffusers.models.transformers.transformer_flux import _get_qkv_projections, apply_rotary_emb
query, key, value, eq, ek, ev = _get_qkv_projections(
attn, hidden_states, encoder_hidden_states
)
# ---- reshape ----
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
# ---- norm ----
query = attn.norm_q(query)
key = attn.norm_k(key)
# ---- joint attention ----
if attn.added_kv_proj_dim is not None:
eq = eq.unflatten(-1, (attn.heads, -1))
ek = ek.unflatten(-1, (attn.heads, -1))
ev = ev.unflatten(-1, (attn.heads, -1))
eq = attn.norm_added_q(eq)
ek = attn.norm_added_k(ek)
query = torch.cat([eq, query], dim=1)
key = torch.cat([ek, key], dim=1)
value = torch.cat([ev, value], dim=1)
# ---- rotary ----
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
# ---- layout convert ----
query = query.permute(0, 2, 1, 3).contiguous()
key = key.permute(0, 2, 1, 3).contiguous()
value = value.permute(0, 2, 1, 3).contiguous()
# ---- kernel ----
out = self.kernel_fn(query, key, value)
# ---- back ----
out = out.permute(0, 2, 1, 3).contiguous()
out = out.flatten(2, 3)
# ---- split ----
if encoder_hidden_states is not None:
enc_len = encoder_hidden_states.shape[1]
enc, hid = out.split([enc_len, out.shape[1] - enc_len], dim=1)
hid = attn.to_out[0](hid)
hid = attn.to_out[1](hid)
enc = attn.to_add_out(enc)
return hid, enc
return out
# =====================================================
# BENCHMARK UTIL
# =====================================================
def benchmark(name, apply_processor=None):
if apply_processor:
pipe.transformer.set_attn_processor(apply_processor)
else:
pipe.transformer.set_default_attn_processor()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# warmup
for _ in range(WARMUP):
_ = pipe(PROMPT, num_inference_steps=NUM_STEPS).images
torch.cuda.synchronize()
times = []
for _ in range(RUNS):
start = time.time()
out = pipe(PROMPT, num_inference_steps=NUM_STEPS).images
torch.cuda.synchronize()
times.append(time.time() - start)
times = np.array(times)
mem = torch.cuda.max_memory_allocated() / 1e9
print(f"\n{name}")
print(f"Mean: {times.mean():.3f}s | P95: {np.percentile(times,95):.3f}s | Std: {times.std():.3f}s")
print(f"Peak memory: {mem:.2f} GB")
return out, times.mean()
# =====================================================
# RUN
# =====================================================
print("=== Running Benchmarks ===")
# SDPA baseline
out_sdpa, t_sdpa = benchmark("SDPA (baseline)", None)
# Custom kernel
processor = FluxInt8AttnProcessor(kernel_fn)
out_custom, t_custom = benchmark("Custom INT8 Kernel", processor)
# =====================================================
# ACCURACY CHECK
# =====================================================
def compare_images(img1, img2):
t1 = torch.tensor(np.array(img1)).float()
t2 = torch.tensor(np.array(img2)).float()
diff = (t1 - t2).abs()
return {
"mean": diff.mean().item(),
"max": diff.max().item()
}
metrics = compare_images(out_sdpa[0], out_custom[0])
print("\n=== Accuracy ===")
print(f"Mean pixel diff: {metrics['mean']:.4f}")
print(f"Max pixel diff : {metrics['max']:.4f}")

If this script works as expected, we should have enough to produce figure-based benchmarks, or to benchmark against SageAttention as well. |
Hey Parag, I updated the PR. I removed the clang-format hook because that was the one that was reformatting the entire CUDA file and making the diff huge. The config now only has standard pre-commit hooks (trailing whitespace, end-of-file, YAML check) and ruff for Python/TOML. The diff is smaller now, just minor whitespace fixes. To clarify the build: I used build2cmake to generate CMakeLists.txt from build.toml, then pip install --no-build-isolation -e ., which runs CMake + nvcc and produces the .so. No Nix involved. I went this route to keep things simple for anyone with just CUDA and PyTorch installed. For the test files (test_simple.py, compat.py, metadata.json), I checked all branches and they were never pushed; they only existed on my previous machine. I can recreate them in a test_kernel/ directory if you want, but unless I am mistaken they are not needed anymore?
|
Well, I can still see the test files (test_simple.py, compat.py, metadata.json) in the file diff; could you please re-check and remove these files? We'll write about the build or setup process in a separate …
Would you mind resolving the merge conflicts in #8 and #10 caused by the pre-commit changes that were pushed?
Also, feel free to let me know if you would like to go with a different model or take a different approach in the benchmarking script.
|
@florianmattana Hey, how is the testing/benchmarking of the kernel going? Would you like any additional resources or help from me? |
Hi Parag, sorry, I am very busy these days, with a lot of requests from everywhere. I won't commit to too much, but I'll try before the end of the week. Besides testing, do you need help with anything else? |
|
Well, once we are able to benchmark the results, merge this PR and push the corrected kernel to the HF Hub, I would like to write an article on Medium documenting all of the fixes you have applied and, as a result, the performance gains we have achieved for diffusion models. You can view some of my previous technical content which I have posted: I will take care of the entire writing part; you would only need to provide me the benchmark numbers and some cool figures that we could use in it, and optionally give feedback if you'd like. WDYT? cc: @florianmattana |
|
OK, I'll try to do it ASAP, hopefully before the end of the week.
Force-pushed from a367f84 to 577bdee
|
Hey Parag, here's what was done today:
For the benchmark: I tried running it locally (RTX 5070 Ti, 12 GB VRAM), but FLUX doesn't fit in VRAM. Sequential CPU offload works but gives ~50 s/inference, which makes the numbers meaningless. I tested the kernel directly with synthetic tensors (batch=1, heads=24, seq=4096, head_dim=128) and it runs correctly: right shape, device and dtype. To get the repo in a clean state, here's the merge order I'd suggest:
To run the benchmark on your end: |
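The synthetic-tensor check described above can be sketched as follows, with a NumPy reference attention standing in for the CUDA kernel (illustrative only; the real check calls the kernel on fp16 CUDA tensors, and the sequence length here is reduced from 4096 to keep it quick):

```python
import numpy as np

def ref_attention(q, k, v):
    # q, k, v: [B, H, N, D]; plain softmax attention as a kernel stand-in
    scale = 1.0 / np.sqrt(q.shape[-1])
    scores = np.einsum("bhnd,bhmd->bhnm", q, k) * scale
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    p = np.exp(scores)
    p /= p.sum(axis=-1, keepdims=True)
    return np.einsum("bhnm,bhmd->bhnd", p, v)

# Shapes from the synthetic test above, with a shorter sequence
B, H, N, D = 1, 24, 256, 128
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((B, H, N, D)).astype(np.float32) for _ in range(3))
out = ref_attention(q, k, v)
assert out.shape == (B, H, N, D) and out.dtype == np.float32
```

Checking shape, device and dtype this way catches layout mistakes before any GPU debugging is needed.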
|
Thanks for the detailed write-up. We could consider switching the model to Sana, which is lighter and which you could test locally. I'll share an updated script: https://huggingface.co/collections/Efficient-Large-Model/sana cc: @florianmattana |
|
Sounds good, Sana fits perfectly in 12 GB VRAM. Looking forward to the script. |
|
Here is the Sana script. It is a bit longer, since we also save visualizations of the recorded data and the performance results; feel free to edit or extend it for your usage:

"""
benchmark_sana_int8.py
======================
Benchmarks the INT8 fused attention kernel against the default Sana attention
processors on Efficient-Large-Model/Sana_600M_512px.
Three configurations are compared
──────────────────────────────────
1. default — Original model (SanaLinearAttnProcessor2_0 for self-attn,
SanaAttnProcessor2_0 for cross-attn)
2. sdpa_all — All attention replaced with F.scaled_dot_product_attention
(fair O(N²) baseline for the INT8 kernel)
3. int8 — Self-attention replaced with the INT8 kernel;
cross-attention kept as SDPA
Usage
─────
python benchmark_sana_int8.py \
--kernel-path ./torch-ext/attention_int8 \
--steps 4 \
--runs 5 \
--warmup 2
Results are written to ./benchmark_results/
"""
import argparse
import os
import time
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import FancyBboxPatch
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
# ─────────────────────────────────────────────────────────────────────────────
# CLI
# ─────────────────────────────────────────────────────────────────────────────
def parse_args():
p = argparse.ArgumentParser()
p.add_argument("--kernel-path", default="./torch-ext/attention_int8",
help="Local path to the int8_attn kernel repository")
p.add_argument("--package-name", default="int8_attn",
help="Package name inside the kernel repo")
p.add_argument("--model-id", default="Efficient-Large-Model/Sana_600M_512px")
p.add_argument("--dtype", default="bf16",
choices=["fp16", "bf16", "fp32"])
p.add_argument("--steps", type=int, default=4)
p.add_argument("--runs", type=int, default=5)
p.add_argument("--warmup", type=int, default=2)
p.add_argument("--prompt",
default="a futuristic city at sunset, ultra-detailed, photorealistic")
p.add_argument("--out-dir", default="./benchmark_results")
p.add_argument("--device", default="cuda")
return p.parse_args()
# ─────────────────────────────────────────────────────────────────────────────
# Attention Processors
# ─────────────────────────────────────────────────────────────────────────────
class SanaSDPAProcessor:
"""
Drop-in replacement that runs standard scaled_dot_product_attention for
*both* self-attention and cross-attention blocks in SanaTransformerBlock.
Used as the O(N²) baseline against which the INT8 kernel is compared.
"""
def __call__(
self,
attn,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
) -> torch.Tensor:
batch_size = hidden_states.shape[0]
if attention_mask is not None:
seq_len = hidden_states.shape[1]
attention_mask = attn.prepare_attention_mask(attention_mask, seq_len, batch_size)
attention_mask = attention_mask.view(
batch_size, attn.heads, -1, attention_mask.shape[-1]
)
query = attn.to_q(hidden_states)
kv_src = hidden_states if encoder_hidden_states is None else encoder_hidden_states
key = attn.to_k(kv_src)
value = attn.to_v(kv_src)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
head_dim = key.shape[-1] // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
hidden_states = F.scaled_dot_product_attention(
query, key, value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
)
hidden_states = (
hidden_states.transpose(1, 2)
.reshape(batch_size, -1, attn.heads * head_dim)
.to(query.dtype)
)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class SanaInt8AttnProcessor:
"""
Replaces self-attention with the INT8 fused kernel.
Falls back to F.scaled_dot_product_attention for cross-attention because
the kernel requires Q and K/V to share the same sequence length.
Args:
kernel_fn: the ``int8_attention`` symbol loaded from the kernel module.
Signature (from the C++ binding):
int8_attention(Q, K, V, timestep_scales, timestep, causal)
-> Tensor
where Q/K/V are [B, H, N, D] float16/bf16/fp32 CUDA tensors.
"""
def __init__(self, kernel_fn):
self.kernel_fn = kernel_fn
# Undefined tensor = no per-timestep scale (ts=1 inside the kernel)
self._no_ts = torch.Tensor()
def __call__(
self,
attn,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
) -> torch.Tensor:
batch_size = hidden_states.shape[0]
# ── Cross-attention: fall back to SDPA (different seq lengths) ──────
if encoder_hidden_states is not None:
return self._sdpa_cross(attn, hidden_states, encoder_hidden_states,
attention_mask, batch_size)
# ── Self-attention: INT8 kernel path ─────────────────────────────────
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
head_dim = key.shape[-1] // attn.heads
# [B, N, H·D] → [B, H, N, D] (kernel expects this layout)
query = query.view(batch_size, -1, attn.heads, head_dim).permute(0, 2, 1, 3).contiguous()
key = key.view(batch_size, -1, attn.heads, head_dim).permute(0, 2, 1, 3).contiguous()
value = value.view(batch_size, -1, attn.heads, head_dim).permute(0, 2, 1, 3).contiguous()
# Kernel requires fp16; cast if necessary
orig_dtype = query.dtype
if orig_dtype != torch.float16:
query = query.to(torch.float16)
key = key.to(torch.float16)
value = value.to(torch.float16)
out = self.kernel_fn(
query, key, value,
self._no_ts, # timestep_scales – undefined → ts=1
0, # timestep index
False, # causal=False for DiT self-attention
)
# Cast back to original dtype
if orig_dtype != torch.float16:
out = out.to(orig_dtype)
# [B, H, N, D] → [B, N, H·D]
out = out.permute(0, 2, 1, 3).contiguous().reshape(batch_size, -1, attn.heads * head_dim)
out = attn.to_out[0](out)
out = attn.to_out[1](out)
out = out / attn.rescale_output_factor
return out
# ------------------------------------------------------------------
def _sdpa_cross(self, attn, hidden_states, encoder_hidden_states,
attention_mask, batch_size):
seq_len = hidden_states.shape[1]
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, seq_len, batch_size)
attention_mask = attention_mask.view(
batch_size, attn.heads, -1, attention_mask.shape[-1]
)
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
head_dim = key.shape[-1] // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
out = F.scaled_dot_product_attention(
query, key, value,
attn_mask=attention_mask, dropout_p=0.0, is_causal=False,
)
out = out.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)
out = attn.to_out[0](out)
out = attn.to_out[1](out)
out = out / attn.rescale_output_factor
return out
# ─────────────────────────────────────────────────────────────────────────────
# Model loader
# ─────────────────────────────────────────────────────────────────────────────
def load_pipeline(model_id: str, dtype: torch.dtype, device: str):
try:
from diffusers import SanaPipeline
pipe = SanaPipeline.from_pretrained(model_id, torch_dtype=dtype)
except ImportError:
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=True)
return pipe
# ─────────────────────────────────────────────────────────────────────────────
# Processor installation helpers
# ─────────────────────────────────────────────────────────────────────────────
def restore_default(pipe):
"""Reinstall the processors that ship with the model."""
# reset_attn_processor() is available in recent diffusers
if hasattr(pipe.transformer, "reset_attn_processor"):
pipe.transformer.reset_attn_processor()
else:
# Fallback: nothing to do; the model was loaded with its defaults
pass
def set_sdpa_all(pipe):
"""Set SanaSDPAProcessor on every attention module."""
pipe.transformer.set_attn_processor(SanaSDPAProcessor())
def set_int8(pipe, kernel_fn):
"""
Set SanaInt8AttnProcessor on every attention module.
The processor itself decides whether to use the kernel (self-attn)
or fall back to SDPA (cross-attn).
"""
pipe.transformer.set_attn_processor(SanaInt8AttnProcessor(kernel_fn))
# ─────────────────────────────────────────────────────────────────────────────
# Benchmark runner
# ─────────────────────────────────────────────────────────────────────────────
@torch.no_grad()
def run_benchmark(pipe, prompt: str, steps: int, runs: int, warmup: int,
device: str) -> dict:
"""
Returns a dict with keys:
times_s – np.ndarray of per-run wall-clock seconds
peak_mem_gb – peak GPU memory in GB
images – list of PIL.Image (one per run)
"""
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats(device)
# Warm-up
for _ in range(warmup):
pipe(prompt, num_inference_steps=steps)
torch.cuda.synchronize(device)
times = []
images = []
for _ in range(runs):
torch.cuda.synchronize(device)
t0 = time.perf_counter()
out = pipe(prompt, num_inference_steps=steps)
torch.cuda.synchronize(device)
times.append(time.perf_counter() - t0)
images.append(out.images[0])
peak_mem = torch.cuda.max_memory_allocated(device) / 1e9
return {
"times_s": np.array(times),
"peak_mem_gb": peak_mem,
"images": images,
}
# ─────────────────────────────────────────────────────────────────────────────
# GPU info helper
# ─────────────────────────────────────────────────────────────────────────────
def gpu_name() -> str:
if not torch.cuda.is_available():
return "CPU"
return torch.cuda.get_device_name(0)
# ─────────────────────────────────────────────────────────────────────────────
# Pixel-level accuracy helper
# ─────────────────────────────────────────────────────────────────────────────
def image_diff(img_a: Image.Image, img_b: Image.Image) -> np.ndarray:
"""Return absolute pixel difference array (H, W, 3) uint8."""
a = np.array(img_a).astype(np.float32)
b = np.array(img_b).astype(np.float32)
return np.abs(a - b)
# ─────────────────────────────────────────────────────────────────────────────
# Plotting
# ─────────────────────────────────────────────────────────────────────────────
# ── Colour palette ──────────────────────────────────────────────────────────
PALETTE = {
"default": "#4C72B0",
"sdpa_all": "#DD8452",
"int8": "#55A868",
}
LABELS = {
"default": "Default\n(LinearAttn + SDPA)",
"sdpa_all": "SDPA Baseline\n(all SDPA)",
"int8": "INT8 Kernel\n(self: INT8, cross: SDPA)",
}
def plot_results(results: dict, out_dir: Path, gpu: str, steps: int):
configs = list(results.keys())
colors = [PALETTE[c] for c in configs]
disp_lbls = [LABELS[c] for c in configs]
means = np.array([results[c]["times_s"].mean() for c in configs])
stds = np.array([results[c]["times_s"].std() for c in configs])
p95s = np.array([np.percentile(results[c]["times_s"], 95) for c in configs])
mems = np.array([results[c]["peak_mem_gb"] for c in configs])
# Reference = sdpa_all if present, else default
ref_key = "sdpa_all" if "sdpa_all" in results else "default"
ref_mean = results[ref_key]["times_s"].mean()
speedups = ref_mean / means # >1 means faster than reference
# ── Figure layout ────────────────────────────────────────────────────────
fig = plt.figure(figsize=(18, 14), facecolor="#0F1117")
fig.suptitle(
f"Sana 600M · INT8 Attention Kernel Benchmark\n"
f"GPU: {gpu} | Steps: {steps} | Runs: {len(results[configs[0]]['times_s'])}",
fontsize=15, fontweight="bold", color="white", y=0.98,
)
gs = gridspec.GridSpec(
3, 3,
figure=fig,
hspace=0.52, wspace=0.38,
top=0.91, bottom=0.07, left=0.07, right=0.97,
)
def styled_ax(ax):
ax.set_facecolor("#1A1D27")
ax.tick_params(colors="white")
ax.xaxis.label.set_color("white")
ax.yaxis.label.set_color("white")
ax.title.set_color("white")
for spine in ax.spines.values():
spine.set_edgecolor("#3A3D4A")
return ax
x = np.arange(len(configs))
bar_w = 0.5
# ── 1. Mean latency (bar + error) ────────────────────────────────────────
ax1 = styled_ax(fig.add_subplot(gs[0, 0]))
bars = ax1.bar(x, means, bar_w, yerr=stds, capsize=5,
color=colors, edgecolor="white", linewidth=0.6,
error_kw={"ecolor": "white", "linewidth": 1.2})
ax1.set_xticks(x); ax1.set_xticklabels(disp_lbls, fontsize=7.5)
ax1.set_ylabel("Seconds"); ax1.set_title("Mean Latency ± 1σ")
for i, (bar, val) in enumerate(zip(bars, means)):
    ax1.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + stds[i] + 0.01,
             f"{val:.2f}s", ha="center", va="bottom", color="white", fontsize=8)
# ── 2. P95 latency ───────────────────────────────────────────────────────
ax2 = styled_ax(fig.add_subplot(gs[0, 1]))
bars2 = ax2.bar(x, p95s, bar_w, color=colors, edgecolor="white", linewidth=0.6)
ax2.set_xticks(x); ax2.set_xticklabels(disp_lbls, fontsize=7.5)
ax2.set_ylabel("Seconds"); ax2.set_title("P95 Latency")
for bar, val in zip(bars2, p95s):
ax2.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
f"{val:.2f}s", ha="center", va="bottom", color="white", fontsize=8)
# ── 3. Peak GPU memory ───────────────────────────────────────────────────
ax3 = styled_ax(fig.add_subplot(gs[0, 2]))
bars3 = ax3.bar(x, mems, bar_w, color=colors, edgecolor="white", linewidth=0.6)
ax3.set_xticks(x); ax3.set_xticklabels(disp_lbls, fontsize=7.5)
ax3.set_ylabel("GB"); ax3.set_title("Peak GPU Memory")
for bar, val in zip(bars3, mems):
ax3.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.03,
f"{val:.2f}GB", ha="center", va="bottom", color="white", fontsize=8)
# ── 4. Per-run latency line chart ────────────────────────────────────────
ax4 = styled_ax(fig.add_subplot(gs[1, :2]))
run_idx = np.arange(1, len(results[configs[0]]["times_s"]) + 1)
for cfg in configs:
ts = results[cfg]["times_s"]
ax4.plot(run_idx, ts, marker="o", color=PALETTE[cfg],
label=LABELS[cfg].replace("\n", " "), linewidth=1.8, markersize=5)
ax4.set_xlabel("Run #"); ax4.set_ylabel("Seconds")
ax4.set_title("Per-Run Latency")
ax4.legend(fontsize=7.5, framealpha=0.2, labelcolor="white",
facecolor="#1A1D27", edgecolor="#3A3D4A")
ax4.set_xticks(run_idx)
# ── 5. Speedup vs. reference ─────────────────────────────────────────────
ax5 = styled_ax(fig.add_subplot(gs[1, 2]))
bar_colors_su = ["#55A868" if s >= 1.0 else "#C44E52" for s in speedups]
bars5 = ax5.bar(x, speedups, bar_w, color=bar_colors_su, edgecolor="white", linewidth=0.6)
ax5.axhline(1.0, color="white", linestyle="--", linewidth=0.9, alpha=0.5)
ax5.set_xticks(x); ax5.set_xticklabels(disp_lbls, fontsize=7.5)
ax5.set_ylabel("×"); ax5.set_title(f"Speedup vs. '{ref_key}'")
for bar, val in zip(bars5, speedups):
ax5.text(bar.get_x() + bar.get_width() / 2,
bar.get_height() + 0.01 if val >= 0 else bar.get_height() - 0.05,
f"{val:.2f}×", ha="center", va="bottom", color="white", fontsize=8)
# ── 6 & 7. Generated image samples + diff heatmap ────────────────────────
# Show the first generated image for each config in a row
n_imgs = min(len(configs), 3)
for i, cfg in enumerate(configs[:n_imgs]):
ax_img = styled_ax(fig.add_subplot(gs[2, i]))
img = results[cfg]["images"][0]
ax_img.imshow(img)
ax_img.set_title(f"{LABELS[cfg].replace(chr(10), ' ')}\nsample output",
fontsize=8)
ax_img.axis("off")
# If int8 is present, overlay mean-diff heatmap in the last panel
if "int8" in results and n_imgs < 3:
ref_img = results[ref_key]["images"][0]
int8_img = results["int8"]["images"][0]
diff = image_diff(ref_img, int8_img).mean(axis=-1)
ax_diff = styled_ax(fig.add_subplot(gs[2, n_imgs]))
im = ax_diff.imshow(diff, cmap="inferno")
ax_diff.set_title(f"Pixel diff\n(INT8 vs {ref_key})", fontsize=8)
ax_diff.axis("off")
cbar = fig.colorbar(im, ax=ax_diff, fraction=0.046, pad=0.04)
cbar.ax.yaxis.set_tick_params(color="white")
plt.setp(cbar.ax.yaxis.get_ticklabels(), color="white")
out_path = out_dir / "benchmark_sana_int8.png"
fig.savefig(out_path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
plt.close(fig)
print(f"[plot] Saved → {out_path}")
return out_path
def plot_timing_distribution(results: dict, out_dir: Path):
"""Violin / box-plot of run-time distributions."""
configs = list(results.keys())
data = [results[c]["times_s"] for c in configs]
labels = [LABELS[c].replace("\n", " ") for c in configs]
colors = [PALETTE[c] for c in configs]
fig, ax = plt.subplots(figsize=(9, 5), facecolor="#0F1117")
ax.set_facecolor("#1A1D27")
ax.tick_params(colors="white"); ax.xaxis.label.set_color("white")
ax.yaxis.label.set_color("white"); ax.title.set_color("white")
for sp in ax.spines.values():
sp.set_edgecolor("#3A3D4A")
vp = ax.violinplot(data, positions=range(len(configs)), showmedians=True,
showextrema=True)
for i, (body, col) in enumerate(zip(vp["bodies"], colors)):
body.set_facecolor(col); body.set_alpha(0.7)
vp["cmedians"].set_color("white")
vp["cmins"].set_color("white"); vp["cmaxes"].set_color("white")
vp["cbars"].set_color("white")
ax.set_xticks(range(len(configs)))
ax.set_xticklabels(labels, fontsize=8)
ax.set_ylabel("Seconds")
ax.set_title("Latency Distribution (violin)")
fig.tight_layout()
out_path = out_dir / "timing_distribution.png"
fig.savefig(out_path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
plt.close(fig)
print(f"[plot] Saved → {out_path}")
return out_path
# ─────────────────────────────────────────────────────────────────────────────
# Numeric summary
# ─────────────────────────────────────────────────────────────────────────────
def print_summary(results: dict, ref_key: str):
print("\n" + "=" * 72)
print(f"{'Config':<22} {'Mean':>8} {'P50':>8} {'P95':>8} {'Std':>8} {'Mem(GB)':>9} {'Speedup':>9}")
print("-" * 72)
ref_mean = results[ref_key]["times_s"].mean()
for cfg, res in results.items():
ts = res["times_s"]
su = ref_mean / ts.mean()
print(
f"{cfg:<22} {ts.mean():>7.3f}s {np.median(ts):>7.3f}s "
f"{np.percentile(ts, 95):>7.3f}s {ts.std():>7.3f}s "
f"{res['peak_mem_gb']:>8.2f} {su:>8.2f}×"
)
print("=" * 72)
if "int8" in results:
su = ref_mean / results["int8"]["times_s"].mean()
tag = "faster" if su > 1 else "slower"
print(f"\n INT8 kernel is {abs(su):.2f}× {tag} than '{ref_key}' on average.\n")
# ─────────────────────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────────────────────
def main():
args = parse_args()
out_dir = Path(args.out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
dtype = dtype_map[args.dtype]
print(f"[init] Loading pipeline: {args.model_id} dtype={args.dtype} device={args.device}")
pipe = load_pipeline(args.model_id, dtype, args.device)
gpu = gpu_name()
print(f"[init] GPU: {gpu}")
# ── Load INT8 kernel ──────────────────────────────────────────────────────
from kernels import get_local_kernel
kernel_path = Path(args.kernel_path)
print(f"[init] Loading kernel from {kernel_path}")
kernel_mod = get_local_kernel(kernel_path, args.package_name)
kernel_fn = kernel_mod.int8_attention # C++ symbol
# ── Benchmark configurations ──────────────────────────────────────────────
results = {}
# 1. Default (LinearAttn self-attn + SDPA cross-attn)
print("\n[bench] Configuration: default (original model processors)")
restore_default(pipe)
results["default"] = run_benchmark(
pipe, args.prompt, args.steps, args.runs, args.warmup, args.device
)
# 2. SDPA everywhere (O(N²) fair baseline)
print("\n[bench] Configuration: sdpa_all (F.scaled_dot_product_attention everywhere)")
set_sdpa_all(pipe)
results["sdpa_all"] = run_benchmark(
pipe, args.prompt, args.steps, args.runs, args.warmup, args.device
)
# 3. INT8 kernel for self-attention
print("\n[bench] Configuration: int8 (INT8 kernel for self-attn, SDPA for cross-attn)")
set_int8(pipe, kernel_fn)
results["int8"] = run_benchmark(
pipe, args.prompt, args.steps, args.runs, args.warmup, args.device
)
# ── Summary ───────────────────────────────────────────────────────────────
ref_key = "sdpa_all"
print_summary(results, ref_key)
# ── Accuracy metrics ──────────────────────────────────────────────────────
if "sdpa_all" in results and "int8" in results:
diff = image_diff(results["sdpa_all"]["images"][0], results["int8"]["images"][0])
print(f"\n[accuracy] INT8 vs SDPA_all — mean pixel diff: {diff.mean():.4f} "
f"max: {diff.max():.1f} (range 0–255)")
# ── Plots ─────────────────────────────────────────────────────────────────
main_plot = plot_results(results, out_dir, gpu, args.steps)
dist_plot = plot_timing_distribution(results, out_dir)
# Save individual generated images
for cfg, res in results.items():
img_path = out_dir / f"sample_{cfg}.png"
res["images"][0].save(img_path)
print(f"[image] Saved → {img_path}")
print(f"\n[done] All results in {out_dir.resolve()}")
if __name__ == "__main__":
main()
|
|
Secondly, since we would be pushing this corrected kernel to HF in order for it to be used with … cc: @florianmattana |
|
Hi @ParagEkbote , there are a couple of issues with the script slowing down the test results. I will do my best to fix them ASAP. |
|
Hey Parag, the timestep_scales fix didn't resolve the segfault. I added some debug prints and the kernel is receiving valid inputs: Q: torch.Size([2, 70, 1024, 32]) float16 cuda:0. HEAD_DIM=32 is in the supported list, and the shapes look correct. The crash happens inside the kernel itself, likely a race condition in the online softmax accumulator: out_acc is written by multiple warps without explicit synchronization between them. This needs a fix in attention_int8.cu before we can benchmark. Let me know if you want me to dig into it, or if you'd prefer to handle it on your end.
Feel free to take a closer look and open a PR, the goal of publishing the correct kernel to HF Hub and the Medium article is not changed. I will also try to take a look if I can correct the issue. cc: @florianmattana |


Applied a global K scale: instead of computing max_abs per K tile inside the loop, we now scan all of K once before the loop to get a single scale. This keeps quantization consistent across tiles for the online softmax.
Results with N=64 (1 tile, no real difference expected): mean diff 0.163728, max diff 1.301880
Results with N=512 (8 tiles): mean diff 0.063129, max diff 0.534424
Results with N=2048 (32 tiles): mean diff 0.032088, max diff 0.288895
The longer the sequence, the more the fix helps. N=2048 is more realistic for diffusion models, and we're under the 0.05 mean-diff threshold. I guess we should try with proper data now.
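The difference between per-tile and global quantization scales described above can be illustrated with a small NumPy sketch (illustrative only; the real change lives in attention_int8.cu, and the tile size here is an assumption):

```python
import numpy as np

def quant_dequant(x, scale):
    # symmetric int8 quantize/dequantize with the given scale
    return np.clip(np.round(x / scale), -127, 127).astype(np.int8) * scale

rng = np.random.default_rng(0)
K = rng.standard_normal((512, 32)).astype(np.float32)
TILE = 64

# Old behaviour: each K tile gets its own max_abs scale inside the loop,
# so dequantized values use a different step size per tile.
per_tile = np.concatenate([
    quant_dequant(t, np.abs(t).max() / 127.0)
    for t in np.split(K, K.shape[0] // TILE)
])

# Fix: scan all of K once before the loop to get a single global scale.
g = float(np.abs(K).max()) / 127.0
global_dq = quant_dequant(K, g)

# Both approximate K, but only the global version uses one consistent
# quantization step across tiles, which the online softmax requires.
assert per_tile.shape == K.shape
assert np.abs(global_dq - K).max() <= g  # rounding error bounded by one step
```

With a single scale, logits from different tiles are directly comparable in the running max/denominator, which is why the mean diff drops as the number of tiles grows.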