
Fix/global max scale#8

Open
florianmattana wants to merge 4 commits into ParagEkbote:main from florianmattana:fix/global_max_scale

Conversation

@florianmattana
Contributor

Applied a global K scale: instead of computing max_abs per K tile inside the loop,
we now scan all of K once before the loop to get a single scale.
This keeps quantization consistent across tiles for the online softmax.
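A minimal NumPy sketch of the difference (illustrative only; the real kernel does this on CUDA tiles, and the function names here are made up):

```python
import numpy as np

def quantize_dequantize_per_tile(K, tile=64):
    """Old behaviour: each K tile gets its own int8 scale, so the
    quantization grid changes from tile to tile, which is inconsistent
    with the online-softmax accumulation across tiles."""
    out = np.empty_like(K)
    for i in range(0, len(K), tile):
        t = K[i:i + tile]
        s = np.abs(t).max() / 127.0          # per-tile scale
        out[i:i + tile] = np.round(t / s).astype(np.int8) * s
    return out

def quantize_dequantize_global(K, tile=64):
    """Fix: scan all of K once before the loop to get a single scale,
    then quantize every tile with that same scale."""
    s = np.abs(K).max() / 127.0              # one global scale
    out = np.empty_like(K)
    for i in range(0, len(K), tile):
        t = K[i:i + tile]
        out[i:i + tile] = np.round(t / s).astype(np.int8) * s
    return out
```

Both variants reconstruct K to within half a quantization step, but only the global version uses the same step for every tile.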

Results with N=64 (1 tile, no real difference expected):
Mean diff: 0.163728 Max diff: 1.301880

Results with N=512 (8 tiles):
Mean diff: 0.063129 Max diff: 0.534424

Results with N=2048 (32 tiles):
Mean diff: 0.032088 Max diff: 0.288895

The longer the sequence, the more the fix helps.
N=2048 is more realistic for diffusion models, and we're under the 0.05 mean threshold.

I guess we should try with proper data now.

@ParagEkbote
Owner

Thanks for investigating the K-scale issue. I have a couple of queries before we move on to testing with a real example.

First, do the files in the cmake directory, along with CMakeLists.txt, help with building the kernel in pure PyTorch rather than Nix, or did you use these files together with Nix to build the kernel? Could you briefly explain the build/compile method you applied?

Secondly, could you please move the test_simple.py and compat.py scripts into a separate directory like test_kernel to keep the repo organized? Same with the metadata.json files into a metadata directory.

Could you also add a .pre-commit-config.yaml at the repo root to check the code and lint the files? For Python and TOML specifically, use the ruff hook plus the standard pre-commit hooks; for the CUDA and C++ binding code, feel free to add the hooks you use on a daily basis.

I will also post a real example here that you can try and experiment with to record benchmark data for the kernel, with figures.

cc: @florianmattana

@florianmattana
Contributor Author

For the build method: I didn't use Nix. The reason is that I wanted a setup that works with a standard Python venv + PyTorch, without requiring contributors (or myself) to install Nix, since I noticed it was something like a 30 GB install.

Here's what I did:

1. Installed build2cmake (a Rust tool) via cargo install build2cmake.
2. Ran build2cmake generate-torch build.toml --force. This reads build.toml and generates CMakeLists.txt, setup.py, pyproject.toml, and registration.h automatically.
3. Ran pip install --no-build-isolation -e . This triggers CMake + nvcc, compiles the CUDA kernel, and produces the .so that torch.ops loads at runtime.

The cmake/ directory files are all generated by step 2, not written manually.

I chose this approach because it keeps the build pure PyTorch: anyone with CUDA and PyTorch installed can build the kernel in one command without extra tooling. But if you'd prefer sticking with Nix or a different build system, I'm happy to adapt. What works best for you?

Regarding test_simple.py, compat.py and metadata.json: did I push them to any branch unintentionally?
I checked all branches and these files were never pushed to the repo, but I may be mistaken. If I didn't, let's just ignore them from now on as we move to real data; I just wanted a quick reference.

Let me know if that sounds OK to you, or I'll adapt.

@ParagEkbote
Owner


Well, since we will be using the kernel via the get_kernel method as defined in the kernels API, we would need to adapt the kernel repo to save the built kernels under a build directory, as seen for flash-attention3, to ensure they work correctly.

In my experience thus far, building with Nix is less than ideal, since it requires setup with cachix, the nix CLI and a host of other things, so if we are able to publish the fixed kernel built with torch, I think that can work as well.

If you view the file diff for this PR, it is larger than it should be, so if you could remove the extra files, or organize them if you consider them useful, that would be quite helpful.

@ParagEkbote
Owner

Here's an example script with flux-schnell that benchmarks the performance of the kernel against torch SDPA attention. I've also implemented a custom attention processor based on the flux_transformer; feel free to modify the script to improve the processor or add visual figures to the benchmarks:

import torch
import time
import numpy as np
from diffusers import DiffusionPipeline
from kernels import get_local_kernel
from pathlib import Path

# =====================================================
# CONFIG
# =====================================================
DEVICE = "cuda"
DTYPE = torch.bfloat16
PROMPT = "a futuristic city at sunset, ultra detailed"
NUM_STEPS = 8
RUNS = 4
WARMUP = 2

# =====================================================
# LOAD MODEL
# =====================================================
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=DTYPE
).to(DEVICE)

pipe.set_progress_bar_config(disable=True)

# =====================================================
# LOAD YOUR KERNEL
# =====================================================
repo_path = Path("./torch-ext/attention_int8")
kernel_mod = get_local_kernel(repo_path, "int8_attn")
kernel_fn = kernel_mod.int8_attention_forward


# =====================================================
# CUSTOM PROCESSOR
# =====================================================
class FluxInt8AttnProcessor:
    def __init__(self, kernel_fn):
        self.kernel_fn = kernel_fn

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        image_rotary_emb=None,
    ):
        # ---- projections ----
        from diffusers.models.transformers.transformer_flux import _get_qkv_projections, apply_rotary_emb

        query, key, value, eq, ek, ev = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )

        # ---- reshape ----
        query = query.unflatten(-1, (attn.heads, -1))
        key   = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        # ---- norm ----
        query = attn.norm_q(query)
        key   = attn.norm_k(key)

        # ---- joint attention ----
        if attn.added_kv_proj_dim is not None:
            eq = eq.unflatten(-1, (attn.heads, -1))
            ek = ek.unflatten(-1, (attn.heads, -1))
            ev = ev.unflatten(-1, (attn.heads, -1))

            eq = attn.norm_added_q(eq)
            ek = attn.norm_added_k(ek)

            query = torch.cat([eq, query], dim=1)
            key   = torch.cat([ek, key], dim=1)
            value = torch.cat([ev, value], dim=1)

        # ---- rotary ----
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key   = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        # ---- layout convert ----
        query = query.permute(0, 2, 1, 3).contiguous()
        key   = key.permute(0, 2, 1, 3).contiguous()
        value = value.permute(0, 2, 1, 3).contiguous()

        # ---- kernel ----
        out = self.kernel_fn(query, key, value)

        # ---- back ----
        out = out.permute(0, 2, 1, 3).contiguous()
        out = out.flatten(2, 3)

        # ---- split ----
        if encoder_hidden_states is not None:
            enc_len = encoder_hidden_states.shape[1]
            enc, hid = out.split([enc_len, out.shape[1] - enc_len], dim=1)

            hid = attn.to_out[0](hid)
            hid = attn.to_out[1](hid)
            enc = attn.to_add_out(enc)

            return hid, enc

        return out


# =====================================================
# BENCHMARK UTIL
# =====================================================
def benchmark(name, apply_processor=None):
    if apply_processor:
        pipe.transformer.set_attn_processor(apply_processor)
    else:
        pipe.transformer.set_default_attn_processor()

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # warmup
    for _ in range(WARMUP):
        _ = pipe(PROMPT, num_inference_steps=NUM_STEPS).images

    torch.cuda.synchronize()

    times = []
    for _ in range(RUNS):
        start = time.time()
        # fixed seed so both configurations denoise the same latents,
        # which makes the pixel-diff accuracy check below meaningful
        out = pipe(PROMPT, num_inference_steps=NUM_STEPS,
                   generator=torch.Generator(device=DEVICE).manual_seed(0)).images
        torch.cuda.synchronize()
        times.append(time.time() - start)

    times = np.array(times)
    mem = torch.cuda.max_memory_allocated() / 1e9

    print(f"\n{name}")
    print(f"Mean: {times.mean():.3f}s | P95: {np.percentile(times,95):.3f}s | Std: {times.std():.3f}s")
    print(f"Peak memory: {mem:.2f} GB")

    return out, times.mean()


# =====================================================
# RUN
# =====================================================
print("=== Running Benchmarks ===")

# SDPA baseline
out_sdpa, t_sdpa = benchmark("SDPA (baseline)", None)

# Custom kernel
processor = FluxInt8AttnProcessor(kernel_fn)
out_custom, t_custom = benchmark("Custom INT8 Kernel", processor)


# =====================================================
# ACCURACY CHECK
# =====================================================
def compare_images(img1, img2):
    t1 = torch.tensor(np.array(img1)).float()
    t2 = torch.tensor(np.array(img2)).float()

    diff = (t1 - t2).abs()

    return {
        "mean": diff.mean().item(),
        "max": diff.max().item()
    }

metrics = compare_images(out_sdpa[0], out_custom[0])

print("\n=== Accuracy ===")
print(f"Mean pixel diff: {metrics['mean']:.4f}")
print(f"Max pixel diff : {metrics['max']:.4f}")

If this script works as expected, we should be in good enough shape to get figure-based benchmarks, or to try benchmarking against SageAttention as well.

@florianmattana
Contributor Author


Hey Parag, I updated the PR. I removed the clang-format hook because that was the one reformatting the entire CUDA file and making the diff huge. The config now only has standard pre-commit hooks (trailing whitespace, end-of-file, yaml check) and ruff for Python/TOML. The diff is smaller now, just minor whitespace fixes.
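Roughly, that config looks like this (the rev pins here are illustrative, not the exact ones in the PR):

```yaml
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.8.4
    hooks:
      - id: ruff          # Python linting
      - id: ruff-format   # Python formatting
```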

To clarify the build: I used build2cmake to generate CMakeLists.txt from build.toml, then pip install --no-build-isolation -e ., which runs CMake + nvcc and produces the .so. No Nix involved. I went this route to keep things simple for anyone with just CUDA and PyTorch installed.
And honestly, I didn't know about the HF Kernels API and get_kernel before your message, so thanks for flagging that. I'll look into how kernel-builder produces the build/ directory with all the required variants, so we can publish the kernel on the Hub.
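If the flash-attention3 layout is anything to go by, each built variant lives under build/ in a directory named after the torch version, C++ ABI, CUDA version, arch and OS. A hypothetical helper showing that naming (the exact convention is my assumption, not taken from the docs):

```python
def variant_dir(torch_version: str, cuda_version: str,
                arch: str = "x86_64", os_name: str = "linux",
                cxx_abi: str = "cxx11") -> str:
    """Construct a per-variant directory name like the ones seen under
    build/ in flash-attention3 style kernel repos (naming assumed)."""
    torch_tag = "torch" + "".join(torch_version.split(".")[:2])  # "2.6.0" -> "torch26"
    cuda_tag = "cu" + cuda_version.replace(".", "")              # "12.4"  -> "cu124"
    return f"{torch_tag}-{cxx_abi}-{cuda_tag}-{arch}-{os_name}"

# e.g. variant_dir("2.6.0", "12.4") -> "torch26-cxx11-cu124-x86_64-linux"
```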

For the test files (test_simple.py, compat.py, metadata.json), I checked all branches and they were never pushed; they only existed on my previous machine. I can recreate them in a test_kernel/ directory if you want, but unless I am mistaken they are not needed anymore?

@ParagEkbote
Owner

Well, I can still see the test files (test_simple.py, compat.py, metadata.json) in the file diff; could you please re-check and remove these files? We'll write about the build/setup process in a separate .md file for users.


Would you mind resolving the merge conflicts in #8 and #10 due to the pre-commit changes which were pushed?

@ParagEkbote
Owner

Also, feel free to let me know if you would like to go with a different model or take a different approach in the benchmarking script.

@ParagEkbote
Owner

@florianmattana Hey, how is the testing/benchmarking of the kernel going? Would you like any additional resources or help from me?

@florianmattana
Contributor Author

> @florianmattana Hey, how is the testing/benchmarking of the kernel going? Would you like any additional resources or help from me?

Hi Parag, sorry, I am very busy these days; a lot of requests from everywhere. I won't commit to too much, but I'll try before the end of the week.

Besides testing, did you need help with anything else?

@ParagEkbote
Owner

Well, once we are able to benchmark the results, merge this PR and push the corrected kernel to the HF Hub, I would like to write an article on Medium documenting all of the fixes you have applied and, as a result, the performance gains we have achieved for diffusion models. You can view some of my previous technical content:

Article 1

Article 2

I will take care of the entire writing part; you would only need to provide the benchmark numbers and some cool figures we could use in it, and optionally give feedback if you'd like. WDYT?

cc: @florianmattana

@florianmattana
Contributor Author

OK, I'll try to do it ASAP, hopefully before the end of the week.

@florianmattana
Contributor Author

florianmattana commented Mar 26, 2026

Hey Parag, here's what was done today:

  1. Resolved merge conflicts on both PR Fix/global max scale #8 and PR add pre-commit config and apply minor whitespace fixes #10 by rebasing on main
  2. Removed test files (test_simple.py, compat.py, metadata-*.json) from the diff — they were never meant to be tracked
  3. Added benchmark.py to test the kernel against SDPA baseline on FLUX.1-schnell

For the benchmark: I tried running it locally (RTX 5070 Ti, 12 GB VRAM), but FLUX doesn't fit in VRAM. Sequential CPU offload works but gives ~50 s/inference, which makes the numbers meaningless. I tested the kernel directly with synthetic tensors (batch=1, heads=24, seq=4096, head_dim=128) and it runs correctly: right shape, device and dtype.
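The synthetic-tensor check can be sketched like this (a NumPy stand-in for the CUDA kernel with a smaller seq so it runs on CPU; the function names are illustrative):

```python
import numpy as np

def ref_attention(q, k, v):
    # naive O(N^2) softmax attention as a stand-in for the INT8 kernel,
    # using the same [B, H, N, D] layout the kernel expects
    scale = np.float32(q.shape[-1]) ** np.float32(-0.5)
    scores = q @ k.transpose(0, 1, 3, 2) * scale
    scores -= scores.max(axis=-1, keepdims=True)   # numerically stable softmax
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v

def smoke_test(attn_fn, b=1, h=24, n=256, d=128):
    # same shape/dtype check as described above; the real test used seq=4096
    rng = np.random.default_rng(0)
    q, k, v = (rng.standard_normal((b, h, n, d), dtype=np.float32)
               for _ in range(3))
    out = attn_fn(q, k, v)
    assert out.shape == q.shape and out.dtype == q.dtype
    return out
```

Swapping `ref_attention` for the loaded kernel function gives the same smoke test against the real .so.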


To get the repo in a clean state, here's what I'd suggest:

Merge order:

  1. Merge PR add pre-commit config and apply minor whitespace fixes #10 first (pre-commit config) — it's a clean chore with no logic changes
  2. Then merge PR Fix/global max scale #8 (global K scale fix) — the actual fix

To run the benchmark on your end:
git clone https://github.com/ParagEkbote/model-kernels.git
cd model-kernels/kernels-v1/attention-int8
cargo install build2cmake
build2cmake generate-torch build.toml --force
pip install --no-build-isolation -e .
pip install torch diffusers transformers accelerate kernels
python3 benchmark.py

@ParagEkbote
Owner

ParagEkbote commented Mar 26, 2026

Thanks for the detailed write-up.

We could consider switching the model to Sana, which is lighter and which you could test locally. I'll share an updated script: https://huggingface.co/collections/Efficient-Large-Model/sana

cc: @florianmattana

@florianmattana
Contributor Author

Sounds good, Sana fits perfectly in 12 GB VRAM. Looking forward to the script.
@ParagEkbote

@ParagEkbote
Owner

Here is the Sana script. It is a bit extended, since we are going to save visualizations of the recorded data and the performance results; feel free to edit or extend it as per your usage:

"""
benchmark_sana_int8.py
======================
Benchmarks the INT8 fused attention kernel against the default Sana attention
processors on Efficient-Large-Model/Sana_600M_512px.

Three configurations are compared
──────────────────────────────────
  1. default   — Original model (SanaLinearAttnProcessor2_0 for self-attn,
                 SanaAttnProcessor2_0 for cross-attn)
  2. sdpa_all  — All attention replaced with F.scaled_dot_product_attention
                 (fair O(N²) baseline for the INT8 kernel)
  3. int8      — Self-attention replaced with the INT8 kernel;
                 cross-attention kept as SDPA

Usage
─────
  python benchmark_sana_int8.py \
      --kernel-path ./torch-ext/attention_int8 \
      --steps 4 \
      --runs 5 \
      --warmup 2

Results are written to ./benchmark_results/
"""

import argparse
import os
import time
from pathlib import Path

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import FancyBboxPatch
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

# ─────────────────────────────────────────────────────────────────────────────
# CLI
# ─────────────────────────────────────────────────────────────────────────────

def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--kernel-path",  default="./torch-ext/attention_int8",
                   help="Local path to the int8_attn kernel repository")
    p.add_argument("--package-name", default="int8_attn",
                   help="Package name inside the kernel repo")
    p.add_argument("--model-id",     default="Efficient-Large-Model/Sana_600M_512px")
    p.add_argument("--dtype",        default="bf16",
                   choices=["fp16", "bf16", "fp32"])
    p.add_argument("--steps",        type=int, default=4)
    p.add_argument("--runs",         type=int, default=5)
    p.add_argument("--warmup",       type=int, default=2)
    p.add_argument("--prompt",
                   default="a futuristic city at sunset, ultra-detailed, photorealistic")
    p.add_argument("--out-dir",      default="./benchmark_results")
    p.add_argument("--device",       default="cuda")
    return p.parse_args()


# ─────────────────────────────────────────────────────────────────────────────
# Attention Processors
# ─────────────────────────────────────────────────────────────────────────────

class SanaSDPAProcessor:
    """
    Drop-in replacement that runs standard scaled_dot_product_attention for
    *both* self-attention and cross-attention blocks in SanaTransformerBlock.
    Used as the O(N²) baseline against which the INT8 kernel is compared.
    """

    def __call__(
        self,
        attn,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        batch_size = hidden_states.shape[0]

        if attention_mask is not None:
            seq_len = hidden_states.shape[1]
            attention_mask = attn.prepare_attention_mask(attention_mask, seq_len, batch_size)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        query = attn.to_q(hidden_states)
        kv_src = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        key   = attn.to_k(kv_src)
        value = attn.to_v(kv_src)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        head_dim = key.shape[-1] // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key   = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
        )

        hidden_states = (
            hidden_states.transpose(1, 2)
            .reshape(batch_size, -1, attn.heads * head_dim)
            .to(query.dtype)
        )
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states


class SanaInt8AttnProcessor:
    """
    Replaces self-attention with the INT8 fused kernel.
    Falls back to F.scaled_dot_product_attention for cross-attention because
    the kernel requires Q and K/V to share the same sequence length.

    Args:
        kernel_fn: the ``int8_attention`` symbol loaded from the kernel module.
            Signature (from the C++ binding):
                int8_attention(Q, K, V, timestep_scales, timestep, causal)
                    -> Tensor
            where Q/K/V are [B, H, N, D] float16/bf16/fp32 CUDA tensors.
    """

    def __init__(self, kernel_fn):
        self.kernel_fn = kernel_fn
        # Undefined tensor = no per-timestep scale (ts=1 inside the kernel)
        self._no_ts = torch.Tensor()

    def __call__(
        self,
        attn,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        batch_size = hidden_states.shape[0]

        # ── Cross-attention: fall back to SDPA (different seq lengths) ──────
        if encoder_hidden_states is not None:
            return self._sdpa_cross(attn, hidden_states, encoder_hidden_states,
                                    attention_mask, batch_size)

        # ── Self-attention: INT8 kernel path ─────────────────────────────────
        query = attn.to_q(hidden_states)
        key   = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        head_dim = key.shape[-1] // attn.heads

        # [B, N, H·D] → [B, H, N, D]  (kernel expects this layout)
        query = query.view(batch_size, -1, attn.heads, head_dim).permute(0, 2, 1, 3).contiguous()
        key   = key.view(batch_size, -1, attn.heads, head_dim).permute(0, 2, 1, 3).contiguous()
        value = value.view(batch_size, -1, attn.heads, head_dim).permute(0, 2, 1, 3).contiguous()

        # Kernel requires fp16; cast if necessary
        orig_dtype = query.dtype
        if orig_dtype != torch.float16:
            query = query.to(torch.float16)
            key   = key.to(torch.float16)
            value = value.to(torch.float16)

        out = self.kernel_fn(
            query, key, value,
            self._no_ts,   # timestep_scales  – undefined → ts=1
            0,             # timestep index
            False,         # causal=False for DiT self-attention
        )

        # Cast back to original dtype
        if orig_dtype != torch.float16:
            out = out.to(orig_dtype)

        # [B, H, N, D] → [B, N, H·D]
        out = out.permute(0, 2, 1, 3).contiguous().reshape(batch_size, -1, attn.heads * head_dim)

        out = attn.to_out[0](out)
        out = attn.to_out[1](out)
        out = out / attn.rescale_output_factor
        return out

    # ------------------------------------------------------------------
    def _sdpa_cross(self, attn, hidden_states, encoder_hidden_states,
                    attention_mask, batch_size):
        seq_len = hidden_states.shape[1]
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, seq_len, batch_size)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        query = attn.to_q(hidden_states)
        key   = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        head_dim = key.shape[-1] // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key   = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        out = F.scaled_dot_product_attention(
            query, key, value,
            attn_mask=attention_mask, dropout_p=0.0, is_causal=False,
        )
        out = out.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)
        out = attn.to_out[0](out)
        out = attn.to_out[1](out)
        out = out / attn.rescale_output_factor
        return out


# ─────────────────────────────────────────────────────────────────────────────
# Model loader
# ─────────────────────────────────────────────────────────────────────────────

def load_pipeline(model_id: str, dtype: torch.dtype, device: str):
    try:
        from diffusers import SanaPipeline
        pipe = SanaPipeline.from_pretrained(model_id, torch_dtype=dtype)
    except ImportError:
        from diffusers import DiffusionPipeline
        pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)

    pipe = pipe.to(device)
    pipe.set_progress_bar_config(disable=True)
    return pipe


# ─────────────────────────────────────────────────────────────────────────────
# Processor installation helpers
# ─────────────────────────────────────────────────────────────────────────────

def restore_default(pipe):
    """Reinstall the processors that ship with the model."""
    # reset_attn_processor() is available in recent diffusers versions;
    # older versions need no action: the model still has its default processors.
    if hasattr(pipe.transformer, "reset_attn_processor"):
        pipe.transformer.reset_attn_processor()


def set_sdpa_all(pipe):
    """Set SanaSDPAProcessor on every attention module."""
    pipe.transformer.set_attn_processor(SanaSDPAProcessor())


def set_int8(pipe, kernel_fn):
    """
    Set SanaInt8AttnProcessor on every attention module.
    The processor itself decides whether to use the kernel (self-attn)
    or fall back to SDPA (cross-attn).
    """
    pipe.transformer.set_attn_processor(SanaInt8AttnProcessor(kernel_fn))


# ─────────────────────────────────────────────────────────────────────────────
# Benchmark runner
# ─────────────────────────────────────────────────────────────────────────────

@torch.no_grad()
def run_benchmark(pipe, prompt: str, steps: int, runs: int, warmup: int,
                  device: str) -> dict:
    """
    Returns a dict with keys:
        times_s      – np.ndarray of per-run wall-clock seconds
        peak_mem_gb  – peak GPU memory in GB
        images       – list of PIL.Image (one per run)
    """
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)

    # Warm-up
    for _ in range(warmup):
        pipe(prompt, num_inference_steps=steps)
    torch.cuda.synchronize(device)

    times = []
    images = []

    for _ in range(runs):
        torch.cuda.synchronize(device)
        t0 = time.perf_counter()
        out = pipe(prompt, num_inference_steps=steps)
        torch.cuda.synchronize(device)
        times.append(time.perf_counter() - t0)
        images.append(out.images[0])

    peak_mem = torch.cuda.max_memory_allocated(device) / 1e9

    return {
        "times_s": np.array(times),
        "peak_mem_gb": peak_mem,
        "images": images,
    }


# ─────────────────────────────────────────────────────────────────────────────
# GPU info helper
# ─────────────────────────────────────────────────────────────────────────────

def gpu_name() -> str:
    if not torch.cuda.is_available():
        return "CPU"
    return torch.cuda.get_device_name(0)


# ─────────────────────────────────────────────────────────────────────────────
# Pixel-level accuracy helper
# ─────────────────────────────────────────────────────────────────────────────

def image_diff(img_a: Image.Image, img_b: Image.Image) -> np.ndarray:
    """Return absolute pixel difference array (H, W, 3) uint8."""
    a = np.array(img_a).astype(np.float32)
    b = np.array(img_b).astype(np.float32)
    return np.abs(a - b)


# ─────────────────────────────────────────────────────────────────────────────
# Plotting
# ─────────────────────────────────────────────────────────────────────────────

# ── Colour palette ──────────────────────────────────────────────────────────
PALETTE = {
    "default":  "#4C72B0",
    "sdpa_all": "#DD8452",
    "int8":     "#55A868",
}
LABELS = {
    "default":  "Default\n(LinearAttn + SDPA)",
    "sdpa_all": "SDPA Baseline\n(all SDPA)",
    "int8":     "INT8 Kernel\n(self: INT8, cross: SDPA)",
}


def plot_results(results: dict, out_dir: Path, gpu: str, steps: int):
    configs   = list(results.keys())
    colors    = [PALETTE[c] for c in configs]
    disp_lbls = [LABELS[c]  for c in configs]

    means     = np.array([results[c]["times_s"].mean() for c in configs])
    stds      = np.array([results[c]["times_s"].std()  for c in configs])
    p95s      = np.array([np.percentile(results[c]["times_s"], 95) for c in configs])
    mems      = np.array([results[c]["peak_mem_gb"]  for c in configs])

    # Reference = sdpa_all if present, else default
    ref_key  = "sdpa_all" if "sdpa_all" in results else "default"
    ref_mean = results[ref_key]["times_s"].mean()
    speedups = ref_mean / means  # >1 means faster than reference

    # ── Figure layout ────────────────────────────────────────────────────────
    fig = plt.figure(figsize=(18, 14), facecolor="#0F1117")
    fig.suptitle(
        f"Sana 600M  ·  INT8 Attention Kernel Benchmark\n"
        f"GPU: {gpu}   |   Steps: {steps}   |   Runs: {len(results[configs[0]]['times_s'])}",
        fontsize=15, fontweight="bold", color="white", y=0.98,
    )

    gs = gridspec.GridSpec(
        3, 3,
        figure=fig,
        hspace=0.52, wspace=0.38,
        top=0.91, bottom=0.07, left=0.07, right=0.97,
    )


    def styled_ax(ax):
        ax.set_facecolor("#1A1D27")
        ax.tick_params(colors="white")
        ax.xaxis.label.set_color("white")
        ax.yaxis.label.set_color("white")
        ax.title.set_color("white")
        for spine in ax.spines.values():
            spine.set_edgecolor("#3A3D4A")
        return ax

    x = np.arange(len(configs))
    bar_w = 0.5

    # ── 1. Mean latency (bar + error) ────────────────────────────────────────
    ax1 = styled_ax(fig.add_subplot(gs[0, 0]))
    bars = ax1.bar(x, means, bar_w, yerr=stds, capsize=5,
                   color=colors, edgecolor="white", linewidth=0.6,
                   error_kw={"ecolor": "white", "linewidth": 1.2})
    ax1.set_xticks(x); ax1.set_xticklabels(disp_lbls, fontsize=7.5)
    ax1.set_ylabel("Seconds"); ax1.set_title("Mean Latency ± 1σ")
    for i, (bar, val) in enumerate(zip(bars, means)):
        ax1.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + stds[i] + 0.01,
                 f"{val:.2f}s", ha="center", va="bottom", color="white", fontsize=8)

    # ── 2. P95 latency ───────────────────────────────────────────────────────
    ax2 = styled_ax(fig.add_subplot(gs[0, 1]))
    bars2 = ax2.bar(x, p95s, bar_w, color=colors, edgecolor="white", linewidth=0.6)
    ax2.set_xticks(x); ax2.set_xticklabels(disp_lbls, fontsize=7.5)
    ax2.set_ylabel("Seconds"); ax2.set_title("P95 Latency")
    for bar, val in zip(bars2, p95s):
        ax2.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                 f"{val:.2f}s", ha="center", va="bottom", color="white", fontsize=8)

    # ── 3. Peak GPU memory ───────────────────────────────────────────────────
    ax3 = styled_ax(fig.add_subplot(gs[0, 2]))
    bars3 = ax3.bar(x, mems, bar_w, color=colors, edgecolor="white", linewidth=0.6)
    ax3.set_xticks(x); ax3.set_xticklabels(disp_lbls, fontsize=7.5)
    ax3.set_ylabel("GB"); ax3.set_title("Peak GPU Memory")
    for bar, val in zip(bars3, mems):
        ax3.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.03,
                 f"{val:.2f}GB", ha="center", va="bottom", color="white", fontsize=8)

    # ── 4. Per-run latency line chart ────────────────────────────────────────
    ax4 = styled_ax(fig.add_subplot(gs[1, :2]))
    run_idx = np.arange(1, len(results[configs[0]]["times_s"]) + 1)
    for cfg in configs:
        ts = results[cfg]["times_s"]
        ax4.plot(run_idx, ts, marker="o", color=PALETTE[cfg],
                 label=LABELS[cfg].replace("\n", " "), linewidth=1.8, markersize=5)
    ax4.set_xlabel("Run #"); ax4.set_ylabel("Seconds")
    ax4.set_title("Per-Run Latency")
    ax4.legend(fontsize=7.5, framealpha=0.2, labelcolor="white",
               facecolor="#1A1D27", edgecolor="#3A3D4A")
    ax4.set_xticks(run_idx)

    # ── 5. Speedup vs. reference ─────────────────────────────────────────────
    ax5 = styled_ax(fig.add_subplot(gs[1, 2]))
    bar_colors_su = ["#55A868" if s >= 1.0 else "#C44E52" for s in speedups]
    bars5 = ax5.bar(x, speedups, bar_w, color=bar_colors_su, edgecolor="white", linewidth=0.6)
    ax5.axhline(1.0, color="white", linestyle="--", linewidth=0.9, alpha=0.5)
    ax5.set_xticks(x); ax5.set_xticklabels(disp_lbls, fontsize=7.5)
    ax5.set_ylabel("×"); ax5.set_title(f"Speedup vs. '{ref_key}'")
    for bar, val in zip(bars5, speedups):
        # speedups are ratios of positive times, so they are always > 0
        ax5.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                 f"{val:.2f}×", ha="center", va="bottom", color="white", fontsize=8)

    # ── 6 & 7. Generated image samples + diff heatmap ────────────────────────
    # Show the first generated image for each config in a row
    n_imgs = min(len(configs), 3)
    for i, cfg in enumerate(configs[:n_imgs]):
        ax_img = styled_ax(fig.add_subplot(gs[2, i]))
        img = results[cfg]["images"][0]
        ax_img.imshow(img)
        ax_img.set_title(f"{LABELS[cfg].replace(chr(10), ' ')}\nsample output",
                         fontsize=8)
        ax_img.axis("off")

    # If int8 is present, overlay mean-diff heatmap in the last panel
    if "int8" in results and n_imgs < 3:
        ref_img  = results[ref_key]["images"][0]
        int8_img = results["int8"]["images"][0]
        diff     = image_diff(ref_img, int8_img).mean(axis=-1)
        ax_diff  = styled_ax(fig.add_subplot(gs[2, n_imgs]))
        im = ax_diff.imshow(diff, cmap="inferno")
        ax_diff.set_title(f"Pixel diff\n(INT8 vs {ref_key})", fontsize=8)
        ax_diff.axis("off")
        cbar = fig.colorbar(im, ax=ax_diff, fraction=0.046, pad=0.04)
        cbar.ax.yaxis.set_tick_params(color="white")
        plt.setp(cbar.ax.yaxis.get_ticklabels(), color="white")

    out_path = out_dir / "benchmark_sana_int8.png"
    fig.savefig(out_path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
    plt.close(fig)
    print(f"[plot] Saved → {out_path}")
    return out_path


def plot_timing_distribution(results: dict, out_dir: Path):
    """Violin / box-plot of run-time distributions."""
    configs = list(results.keys())
    data    = [results[c]["times_s"] for c in configs]
    labels  = [LABELS[c].replace("\n", " ") for c in configs]
    colors  = [PALETTE[c] for c in configs]

    fig, ax = plt.subplots(figsize=(9, 5), facecolor="#0F1117")
    ax.set_facecolor("#1A1D27")
    ax.tick_params(colors="white"); ax.xaxis.label.set_color("white")
    ax.yaxis.label.set_color("white"); ax.title.set_color("white")
    for sp in ax.spines.values():
        sp.set_edgecolor("#3A3D4A")

    vp = ax.violinplot(data, positions=range(len(configs)), showmedians=True,
                       showextrema=True)
    for body, col in zip(vp["bodies"], colors):
        body.set_facecolor(col); body.set_alpha(0.7)
    vp["cmedians"].set_color("white")
    vp["cmins"].set_color("white"); vp["cmaxes"].set_color("white")
    vp["cbars"].set_color("white")

    ax.set_xticks(range(len(configs)))
    ax.set_xticklabels(labels, fontsize=8)
    ax.set_ylabel("Seconds")
    ax.set_title("Latency Distribution (violin)")
    fig.tight_layout()

    out_path = out_dir / "timing_distribution.png"
    fig.savefig(out_path, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
    plt.close(fig)
    print(f"[plot] Saved → {out_path}")
    return out_path


# ─────────────────────────────────────────────────────────────────────────────
# Numeric summary
# ─────────────────────────────────────────────────────────────────────────────

def print_summary(results: dict, ref_key: str):
    print("\n" + "=" * 72)
    print(f"{'Config':<22} {'Mean':>8} {'P50':>8} {'P95':>8} {'Std':>8} {'Mem(GB)':>9} {'Speedup':>9}")
    print("-" * 72)
    ref_mean = results[ref_key]["times_s"].mean()
    for cfg, res in results.items():
        ts = res["times_s"]
        su = ref_mean / ts.mean()
        print(
            f"{cfg:<22} {ts.mean():>7.3f}s {np.median(ts):>7.3f}s "
            f"{np.percentile(ts, 95):>7.3f}s {ts.std():>7.3f}s "
            f"{res['peak_mem_gb']:>8.2f}  {su:>8.2f}×"
        )
    print("=" * 72)
    if "int8" in results:
        su = ref_mean / results["int8"]["times_s"].mean()
        if su >= 1:
            print(f"\n  INT8 kernel is {su:.2f}× faster than '{ref_key}' on average.\n")
        else:
            print(f"\n  INT8 kernel is {1 / su:.2f}× slower than '{ref_key}' on average.\n")


# ─────────────────────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────────────────────

def main():
    args = parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
    dtype     = dtype_map[args.dtype]

    print(f"[init] Loading pipeline: {args.model_id}  dtype={args.dtype}  device={args.device}")
    pipe = load_pipeline(args.model_id, dtype, args.device)
    gpu  = gpu_name()
    print(f"[init] GPU: {gpu}")

    # ── Load INT8 kernel ──────────────────────────────────────────────────────
    from kernels import get_local_kernel
    kernel_path = Path(args.kernel_path)
    print(f"[init] Loading kernel from {kernel_path}")
    kernel_mod = get_local_kernel(kernel_path, args.package_name)
    kernel_fn  = kernel_mod.int8_attention   # C++ symbol

    # ── Benchmark configurations ──────────────────────────────────────────────
    results = {}

    # 1. Default (LinearAttn self-attn + SDPA cross-attn)
    print("\n[bench] Configuration: default (original model processors)")
    restore_default(pipe)
    results["default"] = run_benchmark(
        pipe, args.prompt, args.steps, args.runs, args.warmup, args.device
    )

    # 2. SDPA everywhere (O(N²) fair baseline)
    print("\n[bench] Configuration: sdpa_all (F.scaled_dot_product_attention everywhere)")
    set_sdpa_all(pipe)
    results["sdpa_all"] = run_benchmark(
        pipe, args.prompt, args.steps, args.runs, args.warmup, args.device
    )

    # 3. INT8 kernel for self-attention
    print("\n[bench] Configuration: int8 (INT8 kernel for self-attn, SDPA for cross-attn)")
    set_int8(pipe, kernel_fn)
    results["int8"] = run_benchmark(
        pipe, args.prompt, args.steps, args.runs, args.warmup, args.device
    )

    # ── Summary ───────────────────────────────────────────────────────────────
    ref_key = "sdpa_all"
    print_summary(results, ref_key)

    # ── Accuracy metrics ──────────────────────────────────────────────────────
    if "sdpa_all" in results and "int8" in results:
        diff = image_diff(results["sdpa_all"]["images"][0], results["int8"]["images"][0])
        print(f"\n[accuracy] INT8 vs SDPA_all — mean pixel diff: {diff.mean():.4f}  "
              f"max: {diff.max():.1f}  (range 0–255)")

    # ── Plots ─────────────────────────────────────────────────────────────────
    main_plot = plot_results(results, out_dir, gpu, args.steps)
    dist_plot = plot_timing_distribution(results, out_dir)

    # Save individual generated images
    for cfg, res in results.items():
        img_path = out_dir / f"sample_{cfg}.png"
        res["images"][0].save(img_path)
        print(f"[image] Saved → {img_path}")

    print(f"\n[done] All results in {out_dir.resolve()}")


if __name__ == "__main__":
    main()

@ParagEkbote
Owner

Secondly, since we would be pushing this corrected kernel to the HF Hub to be used with the get_kernel method, take a look at the Kernel Requirements page, which explains how to lay out a kernel directory on the Hub, and also explore an example kernel, kernels-community/flash-attn2, to see how to push builds to the HF Hub in a way that a broad set of users can consume them. WDYT?
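
For reference, the Kernel Requirements page expects prebuilt variants to live under `build/`, one directory per torch/CUDA/arch combination, each containing an importable Python package. A sketch of that layout (the repo name and variant names below are illustrative; check the Kernel Requirements page and kernels-community/flash-attn2 for the exact current naming):

```
int8-attention/                           # repo root on the HF Hub (name is hypothetical)
└── build/
    ├── torch26-cxx11-cu124-x86_64-linux/
    │   └── int8_attention/
    │       ├── __init__.py
    │       └── _int8_attention.abi3.so   # compiled extension
    └── torch26-cxx11-cu126-x86_64-linux/
        └── int8_attention/
            └── ...
```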

cc: @florianmattana

@florianmattana
Contributor Author

Hi @ParagEkbote, there are a couple of issues with the script slowing down the test results. I will do my best to fix them asap.

@florianmattana
Contributor Author

Hey Parag, the timestep_scales fix didn't resolve the segfault. I added some debug prints and the kernel is receiving valid inputs:

Q: torch.Size([2, 70, 1024, 32]) float16 cuda:0
timestep_scales: torch.Size([1]) float32 cuda:0

HEAD_DIM=32 is in the supported list, shapes look correct. The crash happens inside the kernel itself, likely a race condition in the online softmax accumulator — out_acc is written by multiple warps without explicit synchronization between them.

This needs a fix in attention_int8.cu before we can benchmark. Let me know if you want me to dig into it or if you'd prefer to handle it on your end.
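
While debugging the kernel, it may help to have a CPU ground truth for the tiled online-softmax path with the global K scale from this PR. Below is a minimal NumPy sketch, not the kernel's actual implementation: tile size, quantization details, and the single-head [N, D] layout are simplifying assumptions, but the pre-loop global-K-scale step mirrors this PR's fix.

```python
import numpy as np

def int8_attention_ref(q, k, v, tile=64):
    """q, k, v: [N, D] float32 arrays. Returns [N, D] attention output with K
    quantized to int8 using one global scale computed before the tile loop."""
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)

    # Global K scale: scan all of K once before the loop (the fix in this PR),
    # so quantization stays consistent across tiles for the online softmax.
    k_scale = np.abs(k).max() / 127.0
    k_q = np.clip(np.round(k / k_scale), -127, 127).astype(np.int8)

    out = np.zeros_like(q, dtype=np.float64)
    m = np.full(n, -np.inf)          # running max per query row
    denom = np.zeros(n)              # running softmax denominator

    for j0 in range(0, n, tile):
        k_tile = k_q[j0:j0 + tile].astype(np.float32) * k_scale  # dequantize
        s = (q @ k_tile.T) * scale                               # [N, tile] logits
        m_new = np.maximum(m, s.max(axis=1))
        p = np.exp(s - m_new[:, None])
        corr = np.exp(m - m_new)                 # rescale previous accumulators
        denom = denom * corr + p.sum(axis=1)
        out = out * corr[:, None] + p @ v[j0:j0 + tile]
        m = m_new
    return out / denom[:, None]
```

Comparing this against the kernel's output for the shapes in the debug prints (per head, per batch) should separate quantization error from the suspected accumulator race.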

@ParagEkbote
Owner

> Hey Parag, the timestep_scales fix didn't resolve the segfault. I added some debug prints and the kernel is receiving valid inputs:
>
> Q: torch.Size([2, 70, 1024, 32]) float16 cuda:0 timestep_scales: torch.Size([1]) float32 cuda:0
>
> HEAD_DIM=32 is in the supported list, shapes look correct. The crash happens inside the kernel itself, likely a race condition in the online softmax accumulator — out_acc is written by multiple warps without explicit synchronization between them.
>
> This needs a fix in attention_int8.cu before we can benchmark. Let me know if you want me to dig into it or if you'd prefer to handle it on your end.

Feel free to take a closer look and open a PR, the goal of publishing the correct kernel to HF Hub and the Medium article is not changed. I will also try to take a look if I can correct the issue.

cc: @florianmattana
