
torch-webgpu

Experimental WebGPU backend for PyTorch, which can compile and run LLMs on WebGPU!

12 Jan 2026: torch-webgpu reached version 0.0.1

Now supported:

  1. Run PyTorch on WebGPU with device="webgpu"
  2. Compile PyTorch code for WebGPU - torch.compile(m, backend=webgpu_backend)
  3. Many standard PyTorch operations are supported

Next steps:

  1. Compiler optimizations
  2. High performance without platform-specific (CUDA, MPS, ROCm) kernels. Five ingredients are enough to get there - PyTorch, Python, C++, WGSL shaders and a WebGPU runtime. Currently, torch-webgpu uses Google Dawn
  3. Implement missing ops

WebGPU logo by W3C

Coolest thing you can do with torch-webgpu now

Compile and run a real LLM on WebGPU: Qwen/Qwen2.5-0.5B-Instruct with torch.compile(backend=webgpu_backend)!

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch_webgpu.compiler.webgpu_compiler import webgpu_backend

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", torch_dtype=torch.float32
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model.eval()

compiled_model = torch.compile(model, backend=webgpu_backend)

with torch.no_grad():
    inputs = tokenizer("Hello, how are you?", return_tensors="pt")
    input_ids = inputs["input_ids"]
    generated_ids = input_ids.clone()
    outputs = compiled_model(input_ids)  # warm-up call triggers compilation
    for _ in range(10):  # greedy-decode 10 new tokens
        outputs = compiled_model(generated_ids)
        next_token = outputs.logits[0, -1].argmax().unsqueeze(0).unsqueeze(0)
        generated_ids = torch.cat([generated_ids, next_token], dim=1)
    print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))

Run tests: pytest tests/test_qwen_compile.py -s

Installation

pip install torch-webgpu

Supported platforms:

  • Linux (x86_64)
  • macOS (arm64)
  • Windows (x86_64)

From source (for development)

  1. Clone this repo
  2. Build Dawn: ./scripts/build-dawn.sh (or set DAWN_PREFIX to your Dawn installation)
  3. Build: ./build.sh

Use

In Python:

import torch_webgpu

And now you can use device="webgpu" and .to("webgpu") to run PyTorch on a real WebGPU device!
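
For example (a minimal sketch; it assumes the ops used here are among the implemented ones):

import torch
import torch_webgpu  # registers the "webgpu" device

x = torch.randn(4, 4, device="webgpu")  # allocate directly on WebGPU
y = torch.ones(4, 4).to("webgpu")  # or move an existing CPU tensor over
z = (x + y).to("cpu")  # compute on WebGPU, then bring the result back
print(z)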

FAQ

Why?

WebGPU promises to run everywhere - on any hardware - and is becoming well supported in web browsers. This project is a bridge between the PyTorch world and the WebGPU world

There is "web" in "WebGPU", so does it mean that I can run PyTorch in a browser now?

This is a step towards running PyTorch in a browser; actually running PyTorch inside a browser is the next step. I am actively researching how to do it - if this topic excites you too, contact me on Twitter or open an Issue in this GitHub repo

How serious are you about this project? Is it research or a PoC, or are you going to make it production quality?

Once we hit version 1.0.0, torch-webgpu will be a production-ready PyTorch backend. WebGPU is an exciting, emerging technology. As of Nov 2025 all major browsers support WebGPU. I think that it's highly important to build a bridge between PyTorch and WebGPU.

Will you upstream WebGPU backend to PyTorch or keep it out-of-tree forever?

We'll see. Ideally I'd like it to become part of PyTorch core, but we need to reach very high quality first before we can ask the PyTorch maintainers about it

Contributor policy

I have very little time and need to be picky about contributions, so please make sure the code you contribute is:

  • well thought out
  • covered with unit tests
  • fully understood by you - every line you wrote
  • as concise as possible - I can't handle overly big PRs, sorry!

Use LLMs at your discretion, but provide an exhaustive explanation of what you built and why. Write it yourself to show that you really understand it

I understand if that sounds too picky, but since I build this project after hours, I need to cut out any additional noise. Sorry, and thanks for understanding!

I don't like X about this project

That's ok. The main goal here is to build a bridge (for the community) and to learn ML compilers in depth (for me). The project moves regularly, at its own pace. Things improve, cover more use cases, get more tests, get rethought and rewritten. A journey, insights, and learning over raw development velocity - that's the tradeoff I choose

I wish you moved faster

You can fund the project to give me more spare time to work on it. My email: github@maczan.pl

Did AI build it?

The project started on 26 Oct 2025. I have been coding it by hand and learning a lot about PyTorch internals and ML compilation in general. Once I got the project to the point where you could compile and run an MLP on WebGPU, on 10 Jan 2026 I started generating many missing ops using AI agents. In just 2 days, AI boosted the project from compiling and running MLPs to compiling and running LLMs ❤️

Open a GitHub issue if you have more questions. Thanks and let's build this bridge!

Ops support

Most of the important ops are implemented. If any is missing, feel free to open a PR or an issue. Thanks!

Device / to

  • CPU <-> WebGPU
  • CUDA <-> WebGPU
  • MPS <-> WebGPU
  • Intel Gaudi <-> WebGPU
  • XLA <-> WebGPU
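
A quick sanity check for these transfers (a sketch; it assumes elementwise multiply is implemented, and the CUDA line of course requires a CUDA build):

import torch
import torch_webgpu

a = torch.arange(6, dtype=torch.float32).reshape(2, 3)
b = a.to("webgpu")  # CPU -> WebGPU
c = (b * 2.0).to("cpu")  # compute on WebGPU, then WebGPU -> CPU
assert torch.equal(c, a * 2.0)

# With a CUDA build, the same pattern applies:
# d = a.cuda().to("webgpu")  # CUDA -> WebGPU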

Rough edges

  • performance hasn't been a priority yet
  • only float32 is supported (see the sketch below)
  • wgpu::Queue.Submit() is handled synchronously
  • some ops may fall back to CPU
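
Given the float32-only limitation, casting up front is the safe pattern (a sketch; behavior for unsupported dtypes is not specified here):

import torch
import torch_webgpu

x16 = torch.randn(4, dtype=torch.float16)
x = x16.float().to("webgpu")  # cast to float32 before moving to WebGPU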

Resources

Note: This project is unrelated to webgpu-torch, which is a neat PyTorch reimplementation in TypeScript targeting WebGPU

Dev resources

C++ unit tests

  1. Remember to rebuild your code before testing - ./build.sh
  2. chmod +x build-ctests.sh run-ctests.sh
  3. Update build-ctests.sh with your paths
  4. rm -rf build/ctests
  5. ./build-ctests.sh
  6. ./run-ctests.sh

C++ benchmarks

  1. Remember to rebuild your code before benchmarking - ./build.sh - and optionally log in to your wandb.ai account
  2. chmod +x build-benchmark.sh run-benchmark.sh
  3. Update build-benchmark.sh with your paths
  4. rm -rf build/benchmarks
  5. ./build-benchmark.sh
  6. ./run-benchmark.sh

Python unit tests

  1. Remember to rebuild your code before testing - ./build.sh
  2. pytest tests to run all tests; pytest tests/ops/test_cos.py to run a single test file - here, the cosine op (a sketch of such a test follows)
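
For orientation, a minimal op test might look like this (a hypothetical sketch, not the repo's actual test; it checks the WebGPU result against the CPU reference):

import torch
import torch_webgpu

def test_cos_matches_cpu():
    x = torch.randn(8, dtype=torch.float32)
    expected = torch.cos(x)
    got = torch.cos(x.to("webgpu")).to("cpu")
    assert torch.allclose(got, expected, atol=1e-6)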

Releasing (for maintainers)

  1. Update version in pyproject.toml
  2. Create and push a git tag: git tag v0.0.2 && git push origin v0.0.2
  3. Create a GitHub Release from the tag
  4. The publish.yml workflow will automatically build wheels and publish to PyPI

For testing: Use "Actions" > "Publish to PyPI" > "Run workflow" > select "testpypi"

Cite

If you use this software, please cite it as below.

@software{Maczan_torch-webgpu_2025,
  author = {Maczan, Jędrzej Paweł},
  month = oct,
  title = {{torch-webgpu - PyTorch compiler and WebGPU runtime}},
  url = {https://github.com/jmaczan/torch-webgpu},
  version = {1.0.0},
  year = {2025}
}

Credits

Jędrzej Maczan, 2025 - ∞