Experimental WebGPU backend for PyTorch, which can compile and run LLMs on WebGPU!
12 Jan 2026: torch-webgpu reached 0.0.1
Now supported:
- Run PyTorch on WebGPU: `device="webgpu"`
- Compile PyTorch code for WebGPU: `torch.compile(m, backend=webgpu_backend)`
- Many standard PyTorch operations
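A minimal sketch of both modes above (it reuses the `webgpu_backend` import path from the full example below; op coverage may vary):

```python
import torch
import torch_webgpu  # registers the "webgpu" device with PyTorch
from torch_webgpu.compiler.webgpu_compiler import webgpu_backend

# Compiled mode: lower a module through the WebGPU backend
m = torch.nn.Linear(4, 2)  # float32 by default, matching current dtype support
compiled = torch.compile(m, backend=webgpu_backend)
y_compiled = compiled(torch.randn(1, 4))

# Eager mode: run directly on the WebGPU device
m.to("webgpu")
y_eager = m(torch.randn(1, 4, device="webgpu"))
```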
Next steps:
- Compiler optimizations
- High performance without platform-specific (CUDA, MPS, ROCm) kernels. Five ingredients are enough to get there: PyTorch, Python, C++, WGSL shaders, and a WebGPU runtime. Currently, torch-webgpu uses Google Dawn
- Implement missing ops
Compile and run a real LLM on WebGPU: Qwen/Qwen2.5-0.5B-Instruct with `torch.compile(model, backend=webgpu_backend)`!
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from torch_webgpu.compiler.webgpu_compiler import webgpu_backend

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", torch_dtype=torch.float32
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model.eval()

compiled_model = torch.compile(model, backend=webgpu_backend)

with torch.no_grad():
    inputs = tokenizer("Hello, how are you?", return_tensors="pt")
    input_ids = inputs["input_ids"]
    generated_ids = input_ids.clone()
    outputs = compiled_model(input_ids)  # first call triggers compilation
    for _ in range(10):
        outputs = compiled_model(generated_ids)
        # Greedy decoding: pick the most likely next token
        next_token = outputs.logits[0, -1].argmax().unsqueeze(0).unsqueeze(0)
        generated_ids = torch.cat([generated_ids, next_token], dim=1)
    print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
```

Run tests: `pytest tests/test_qwen_compile.py -s`
`pip install torch-webgpu`

Supported platforms:
- Linux (x86_64)
- macOS (arm64)
- Windows (x86_64)
- Clone this repo
- Build Dawn: `./scripts/build-dawn.sh` (or set `DAWN_PREFIX` to your Dawn installation)
- Build: `./build.sh`
In Python:

```python
import torch_webgpu
```

Now you can use `device="webgpu"` and `.to("webgpu")` to run PyTorch on a real WebGPU device!
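For example, a minimal eager-mode sketch (assuming elementwise ops such as addition are among the supported ops):

```python
import torch
import torch_webgpu  # registers the "webgpu" device with PyTorch

x = torch.randn(4, 4, device="webgpu")  # allocate directly on WebGPU
y = torch.randn(4, 4).to("webgpu")      # or copy an existing CPU tensor over
z = x + y                               # dispatched to a WGSL compute shader
print(z.cpu())                          # copy back to CPU to inspect
```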
WebGPU promises to run everywhere, on all hardware, and it is becoming well supported in web browsers. This project is a bridge between the PyTorch world and the WebGPU world.
This project is a step towards running PyTorch in the browser; the next step is to actually run PyTorch inside a browser. I am actively researching how to do it. If this topic excites you too, contact me on Twitter or open an issue in this GitHub repo.
How serious are you about this project? Is it research or a PoC, or are you going to make it production quality?
Once we hit version 1.0.0, torch-webgpu will be a production-ready PyTorch backend. WebGPU is an exciting, emerging technology. As of November 2025, all major browsers support WebGPU. I think it's highly important to build a bridge between PyTorch and WebGPU.
We'll see. Ideally I'd like it to become part of PyTorch core, but we need to reach very high quality first before we can ask the PyTorch maintainers about it.
I have very little time and need to be picky about contributions, so please make sure the code you contribute is:
- well thought out
- covered with unit tests
- fully understood by you
- as concise as possible: I can't handle overly big PRs, sorry!
Use LLMs at your discretion, but provide an exhaustive explanation of what you built and why. Write it yourself to show that you really understand it.
I understand if that sounds too picky, but since I build this project after hours, I need to cut out any additional noise. Sorry, and thanks for understanding!
That's ok. The main goals here are to build a bridge (for the community) and to learn ML compilers in depth (for me). The project moves regularly, at its own pace. Things improve, cover more use cases, gain more tests, and get rethought and rewritten. A journey, insights, and learning over raw development velocity: that's the tradeoff I choose.
You can fund the project to give me more spare time to work on it. My email: github@maczan.pl
The project started on 26 Oct 2025. I coded it by hand, learning a lot about PyTorch internals and ML compilation in general. Once the project reached the point where you could compile and run an MLP on WebGPU, on 10 Jan 2026 I started generating many of the missing ops using AI agents. In just two days, AI boosted the project from compiling and running MLPs to compiling and running LLMs ❤️
Most of the important ops are implemented. If any is missing, feel free to open a PR or an issue. Thanks!
- CPU <-> WebGPU (see the sketch after this list)
- CUDA <-> WebGPU
- MPS <-> WebGPU
- Intel Gaudi <-> WebGPU
- XLA <-> WebGPU
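For the CPU <-> WebGPU pair, a round trip might look like this (a sketch; it assumes copies in both directions are already supported):

```python
import torch
import torch_webgpu  # registers the "webgpu" device

cpu_t = torch.arange(6, dtype=torch.float32)
gpu_t = cpu_t.to("webgpu")       # CPU -> WebGPU copy
back = gpu_t.to("cpu")           # WebGPU -> CPU copy
assert torch.equal(cpu_t, back)  # data survives the round trip
```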
- performance wasn't a priority yet
- only float32 is supported (see the sketch below)
- `wgpu::Queue.Submit()` is handled synchronously
- some ops might fall back to CPU
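Since only float32 is supported for now, cast tensors before moving them to the device, for example:

```python
import torch
import torch_webgpu  # registers the "webgpu" device

t64 = torch.randn(8, dtype=torch.float64)
t = t64.to(torch.float32).to("webgpu")  # cast first: only float32 is supported
```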
- Ascend's NPU backend for PyTorch: https://github.com/ascend/pytorch
- Elie Michel's WebGPU guide: https://eliemichel.github.io/LearnWebGPU/index.html
- WGSL spec: https://www.w3.org/TR/WGSL/
- PyTorch PrivateUse1 custom backend docs as a reference: https://docs.pytorch.org/tutorials/advanced/privateuseone.html, https://docs.pytorch.org/tutorials/advanced/extend_dispatcher.html, https://docs.pytorch.org/tutorials/advanced/dispatcher
- Optimizing a WebGPU matmul kernel: https://www.nuss-and-bolts.com/p/optimizing-a-webgpu-matmul-kernel
- WebGPU compute shaders: https://webgpufundamentals.org/webgpu/lessons/webgpu-compute-shaders.html
Note: This project is unrelated to webgpu-torch, which is a neat PyTorch reimplementation in TypeScript targeting WebGPU
- Remember to rebuild your code before testing: `./build.sh`
- `chmod +x build-ctests.sh run-ctests.sh`
- Update `build-ctests.sh` with your paths
- `rm -rf build/ctests`
- `./build-ctests.sh`
- `./run-ctests.sh`
- Remember to rebuild your code before testing: `./build.sh`, and optionally log in to your wandb.ai account
- `chmod +x build-benchmark.sh run-benchmark.sh`
- Update `build-benchmark.sh` with your paths
- `rm -rf build/benchmarks`
- `./build-benchmark.sh`
- `./run-benchmark.sh`
- Remember to rebuild your code before testing: `./build.sh`
- `pytest tests` to run all tests
- `pytest tests/ops/test_cos.py` to run a chosen test file; here we test cosine
- Update the version in `pyproject.toml`
- Create and push a git tag: `git tag v0.0.2 && git push origin v0.0.2`
- Create a GitHub Release from the tag
- The `publish.yml` workflow will automatically build wheels and publish to PyPI

For testing: use "Actions" > "Publish to PyPI" > "Run workflow" > select "testpypi"
If you use this software, please cite it as below.
```bibtex
@software{Maczan_torch-webgpu_2025,
  author = {Maczan, Jędrzej Paweł},
  month = oct,
  title = {{torch-webgpu - PyTorch compiler and WebGPU runtime}},
  url = {https://github.com/jmaczan/torch-webgpu},
  version = {1.0.0},
  year = {2025}
}
```