diff --git a/.cargo/config.toml b/.cargo/config.toml
index ca9d853b60..fb5b664ed6 100644
--- a/.cargo/config.toml
+++ b/.cargo/config.toml
@@ -2,7 +2,10 @@
rustflags = ["-C", "target-cpu=native"]
[target.wasm32-unknown-unknown]
-rustflags = ["-C", "target-feature=+simd128"]
+rustflags = ["-C", "target-feature=+simd128", "--cfg", 'getrandom_backend="wasm_js"']
[target.x86_64-apple-darwin]
-rustflags = ["-C", "target-feature=-avx,-avx2"]
\ No newline at end of file
+rustflags = ["-C", "target-feature=-avx,-avx2"]
+
+[alias]
+xtask = "run --manifest-path ../xtask/Cargo.toml --"
\ No newline at end of file
diff --git a/.github/workflows/book-cd.yml b/.github/workflows/book-cd.yml
deleted file mode 100644
index e8149e3832..0000000000
--- a/.github/workflows/book-cd.yml
+++ /dev/null
@@ -1,40 +0,0 @@
-name: Deploy Rust book
-on:
- push:
- branches:
- - main
-
-jobs:
- deploy:
- runs-on: ubuntu-latest
- permissions:
- contents: write # To push a branch
- pull-requests: write # To create a PR from that branch
- steps:
- - uses: actions/checkout@v3
- with:
- fetch-depth: 0
- - name: Install latest mdbook
- run: |
- tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
- url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
- mkdir mdbook
- curl -sSL $url | tar -xz --directory=./mdbook
- echo `pwd`/mdbook >> $GITHUB_PATH
- - name: Deploy GitHub Pages
- run: |
- # This assumes your book is in the root of your repository.
- # Just add a `cd` here if you need to change to another directory.
- cd candle-book
- mdbook build
- git worktree add gh-pages
- git config user.name "Deploy from CI"
- git config user.email ""
- cd gh-pages
- # Delete the ref to avoid keeping history.
- git update-ref -d refs/heads/gh-pages
- rm -rf *
- mv ../book/* .
- git add .
- git commit -m "Deploy $GITHUB_SHA to gh-pages"
- git push --force --set-upstream origin gh-pages
diff --git a/.github/workflows/book.yml b/.github/workflows/book.yml
deleted file mode 100644
index bb4d0494fb..0000000000
--- a/.github/workflows/book.yml
+++ /dev/null
@@ -1,29 +0,0 @@
-name: CI
-on:
- pull_request:
-
-jobs:
- test:
- name: Test candle-book
- runs-on: ubuntu-latest
- permissions:
- contents: write # To push a branch
- pull-requests: write # To create a PR from that branch
- steps:
- - uses: actions/checkout@master
- - name: Install Rust
- run: |
- rustup set profile minimal
- rustup toolchain install stable
- rustup default stable
- - name: Install latest mdbook
- run: |
- tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
- url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
- mkdir bin
- curl -sSL $url | tar -xz --directory=bin
- echo "$(pwd)/bin" >> $GITHUB_PATH
- - name: Run tests
- run: cd candle-book && cargo build && mdbook test -L ../target/debug/deps/
-
-
diff --git a/.github/workflows/ci_cuda.yaml b/.github/workflows/ci_cuda.yaml
index fc07f112c7..b886d6fc3e 100644
--- a/.github/workflows/ci_cuda.yaml
+++ b/.github/workflows/ci_cuda.yaml
@@ -10,10 +10,9 @@ jobs:
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
runs-on:
- group: aws-g4dn-2xlarge
+ group: aws-g5-4xlarge-cache
container:
- image: nvidia/cuda:12.3.1-devel-ubuntu22.04
- options: --gpus 0
+ image: nvidia/cuda:13.0.2-cudnn-devel-ubuntu24.04
if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
permissions:
contents: write
@@ -22,13 +21,15 @@ jobs:
# with sigstore/fulcio when running outside of PRs.
id-token: write
security-events: write
+ env:
+ CUDA_COMPUTE_CAP: 86
steps:
- name: Checkout repository
- uses: actions/checkout@v3
+ uses: actions/checkout@v6
- name: Install dependencies
- run: apt-get update && apt install curl build-essential libssl-dev protobuf-compiler pkg-config -y
+ run: apt update && apt install curl build-essential libssl-dev protobuf-compiler pkg-config -y
- name: Install Rust Stable
- uses: actions-rust-lang/setup-rust-toolchain@v1
+ uses: dtolnay/rust-toolchain@stable
- uses: Swatinem/rust-cache@v2
- name: Test (cuda)
run: cargo test --features cuda
diff --git a/.github/workflows/maturin.yml b/.github/workflows/maturin.yml
index 46bdb903da..d58fbdd616 100644
Binary files a/.github/workflows/maturin.yml and b/.github/workflows/maturin.yml differ
diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index 68e2eee31e..f8bf3ad002 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -20,21 +20,19 @@ jobs:
os: [ubuntu-latest] # For now, only test on Linux
steps:
- name: Checkout repository
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
- name: Install Rust
- uses: actions-rs/toolchain@v1
- with:
- toolchain: stable
+ uses: dtolnay/rust-toolchain@stable
- name: Install Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v6
with:
- python-version: 3.11
+ python-version: 3.13
architecture: "x64"
- name: Cache Cargo Registry
- uses: actions/cache@v1
+ uses: actions/cache@v5
with:
path: ~/.cargo/registry
key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}
@@ -42,8 +40,8 @@ jobs:
- name: Install Protoc
uses: arduino/setup-protoc@v2
with:
- version: "25.0"
- repo-token: ${{ secrets.GITHUB_TOKEN }}
+ version: "25.0"
+ repo-token: ${{ secrets.GITHUB_TOKEN }}
- name: Install
working-directory: ./candle-pyo3
diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml
index ee480c474c..375ea7d362 100644
--- a/.github/workflows/rust-ci.yml
+++ b/.github/workflows/rust-ci.yml
@@ -11,20 +11,38 @@ jobs:
name: Check
runs-on: ${{ matrix.os }}
strategy:
+ fail-fast: false
matrix:
- os: [ubuntu-latest, windows-latest, macOS-latest]
- rust: [stable]
+ os: [ubuntu-latest, ubuntu-24.04, windows-latest, macOS-latest, ubuntu-24.04-arm]
steps:
- - uses: actions/checkout@v4
- - uses: actions-rs/toolchain@v1
- with:
- profile: minimal
- toolchain: ${{ matrix.rust }}
- override: true
- - uses: actions-rs/cargo@v1
+ - uses: actions/checkout@v6
+ - uses: actions/setup-python@v6
with:
- command: check
- args: --workspace
+ python-version: "3.13"
+ - name: Remove cargo config (macOS ring crate fix)
+ if: runner.os == 'macOS'
+ run: rm -f .cargo/config.toml
+ - uses: dtolnay/rust-toolchain@stable
+
+ - name: Run macos with metal
+ if: matrix.os == 'macOS-latest'
+ run: cargo check --workspace --features metal
+
+ - name: Run normal cpu
+ if: matrix.os == 'ubuntu-latest' || matrix.os == 'windows-latest'
+ run: cargo check --workspace
+
+ - name: Run with avx2
+ if: matrix.os == 'ubuntu-24.04'
+ run: |
+        export RUSTFLAGS="-C target-feature=+avx2"
+ cargo check --workspace
+
+ - name: Run with arm neon
+ if: matrix.os == 'ubuntu-24.04-arm'
+ run: |
+        export RUSTFLAGS="-C target-feature=+neon"
+ cargo check --workspace
test:
name: Test Suite
@@ -32,47 +50,52 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macOS-latest]
- rust: [stable]
steps:
- - uses: actions/checkout@v4
- - uses: actions-rs/toolchain@v1
- with:
- profile: minimal
- toolchain: ${{ matrix.rust }}
- override: true
- - uses: actions-rs/cargo@v1
+ - name: Free disk space (Linux)
+ if: runner.os == 'Linux'
+ run: |
+ sudo rm -rf /opt/hostedtoolcache
+ sudo rm -rf /usr/share/dotnet
+ sudo rm -rf /usr/local/lib/android
+ sudo rm -rf /opt/ghc
+ df -h
+ - uses: actions/checkout@v6
+ - uses: actions/setup-python@v6
with:
- command: test
- args: --workspace
+ python-version: "3.13"
+ - name: Remove cargo config (macOS ring crate fix)
+ if: runner.os == 'macOS'
+ run: rm -f .cargo/config.toml
+ - uses: dtolnay/rust-toolchain@stable
+ - name: Install lld (Linux only)
+ if: runner.os == 'Linux'
+ run: sudo apt-get update && sudo apt-get install -y lld
+ - name: Run tests (with lld on Linux)
+ if: runner.os == 'Linux'
+ env:
+ RUSTFLAGS: "-C link-arg=-fuse-ld=lld"
+ run: cargo test --workspace
+ - name: Run tests (Windows & macOS)
+ if: runner.os != 'Linux'
+ run: cargo test --workspace
fmt:
name: Rustfmt
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v4
- - uses: actions-rs/toolchain@v1
+ - uses: actions/checkout@v6
+ - uses: dtolnay/rust-toolchain@stable
with:
- profile: minimal
- toolchain: stable
- override: true
- - run: rustup component add rustfmt
- - uses: actions-rs/cargo@v1
- with:
- command: fmt
- args: --all -- --check
+ components: rustfmt
+ - run: cargo fmt --all -- --check
clippy:
name: Clippy
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v4
- - uses: actions-rs/toolchain@v1
- with:
- profile: minimal
- toolchain: stable
- override: true
- - run: rustup component add clippy
- - uses: actions-rs/cargo@v1
+ - uses: actions/checkout@v6
+ - uses: dtolnay/rust-toolchain@stable
with:
- command: clippy
- args: --workspace --tests --examples -- -D warnings
+ components: clippy
+ - run: cargo clippy --workspace --tests --examples --benches -- -D warnings
+
\ No newline at end of file
diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml
index 9cbbf68037..7ae4f55793 100644
--- a/.github/workflows/trufflehog.yml
+++ b/.github/workflows/trufflehog.yml
@@ -7,9 +7,9 @@ jobs:
trufflehog:
runs-on: ubuntu-latest
steps:
- - name: Checkout code
- uses: actions/checkout@v4
- with:
- fetch-depth: 0
- - name: Secret Scanning
- uses: trufflesecurity/trufflehog@main
+ - name: Checkout code
+ uses: actions/checkout@v6
+ with:
+ fetch-depth: 0
+ - name: Secret Scanning
+ uses: trufflesecurity/trufflehog@main
diff --git a/.gitignore b/.gitignore
index 4dfbcc1663..b90dab0291 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,6 +12,7 @@ Cargo.lock
# editor config
.helix
.vscode
+.zed
# These are backup files generated by rustfmt
**/*.rs.bk
@@ -46,3 +47,11 @@ out.wav
bria.mp3
bria.safetensors
bria.wav
+
+#generated wgpu shader files
+**/generated/*.pwgsl_generated_*.wgsl
+wgpu_*_test_*_measurements.json
+wgpu_*_test_*_shaders.json
+wgpu_*_test_*_used_consts.json
+wgpu_*_test_*_used_pipelines.json
+wgpu_loader_indices.toml
diff --git a/.gitmodules b/.gitmodules
deleted file mode 100644
index 12631cbc27..0000000000
--- a/.gitmodules
+++ /dev/null
@@ -1,3 +0,0 @@
-[submodule "candle-examples/examples/flash-attn/cutlass"]
- path = candle-flash-attn/cutlass
- url = https://github.com/NVIDIA/cutlass.git
diff --git a/.vscode/settings.json b/.vscode/settings.json
index b2dbd68012..da1628fc5f 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -1,11 +1,5 @@
{
- "[python]": {
- "editor.defaultFormatter": "ms-python.black-formatter"
- },
- "python.formatting.provider": "none",
- "python.testing.pytestArgs": [
- "candle-pyo3"
- ],
- "python.testing.unittestEnabled": false,
- "python.testing.pytestEnabled": true
+ "files.associations": {
+ "*.pwgsl": "wgsl"
+ }
}
\ No newline at end of file
diff --git a/Cargo.toml b/Cargo.toml
index 17e7e4ba57..ac5a9a4330 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -3,24 +3,31 @@ members = [
"candle-core",
"candle-datasets",
"candle-examples",
- "candle-book",
"candle-nn",
"candle-pyo3",
"candle-transformers",
+ "candle-ug",
"candle-wasm-examples/*",
"candle-wasm-tests",
"tensor-tools",
+ "candle-wgpu-kernels",
+ "wgpu-compute-layer/wgpu-compute-layer-pwgsl",
+ "wgpu-compute-layer",
]
exclude = [
- "candle-flash-attn",
- "candle-kernels",
- "candle-metal-kernels",
- "candle-onnx",
+ "candle-book",
+ "candle-flash-attn-build",
+ "candle-flash-attn",
+ "candle-flash-attn-v3",
+ "candle-kernels",
+
+ "candle-metal-kernels",
+ "candle-onnx",
]
resolver = "2"
[workspace.package]
-version = "0.8.0"
+version = "0.9.2"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@@ -32,50 +39,88 @@ license = "MIT OR Apache-2.0"
ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
+async-lock = "3.4.2"
byteorder = "1.4.3"
-candle = { path = "./candle-core", package = "candle-core", version = "0.8.0" }
-candle-datasets = { path = "./candle-datasets", version = "0.8.0" }
-candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.0" }
-candle-kernels = { path = "./candle-kernels", version = "0.8.0" }
-candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.0" }
-candle-nn = { path = "./candle-nn", version = "0.8.0" }
-candle-onnx = { path = "./candle-onnx", version = "0.8.0" }
-candle-transformers = { path = "./candle-transformers", version = "0.8.0" }
+candle = { path = "./candle-core", package = "candle-core", version = "0.9.2" }
+candle-datasets = { path = "./candle-datasets", version = "0.9.2" }
+candle-flash-attn-build = { path = "candle-flash-attn-build", version = "0.9.2" }
+candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.2" }
+candle-flash-attn-v3 = { path = "./candle-flash-attn-v3", version = "0.9.2" }
+candle-kernels = { path = "./candle-kernels", version = "0.9.2" }
+candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.2" }
+candle-wgpu-kernels = { path = "./candle-wgpu-kernels", version = "0.9.2" }
+wgpu-compute-layer = { path = "./wgpu-compute-layer", version = "0.9.2" }
+candle-nn = { path = "./candle-nn", version = "0.9.2" }
+candle-onnx = { path = "./candle-onnx", version = "0.9.2" }
+candle-transformers = { path = "./candle-transformers", version = "0.9.2" }
+candle-ug = { path = "./candle-ug", version = "0.9.2" }
clap = { version = "4.2.4", features = ["derive"] }
-criterion = { version = "0.5.1", default-features=false }
-cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
-fancy-regex = "0.13.0"
-gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
-hf-hub = { version = "0.3.3", package = "candle-hf-hub" }
-half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
+criterion = { version = "0.8", default-features = false }
+cudarc = { version = "0.19.1", features = [
+ "std",
+ "cublas",
+ "cublaslt",
+ "curand",
+ "driver",
+ "nvrtc",
+ "f16",
+ "f8",
+ "cuda-version-from-build-system",
+ "dynamic-linking",
+], default-features = false }
+fancy-regex = "0.17.0"
+gemm = { version = "0.19.0", features = ["wasm-simd128-enable"] }
+hf-hub = "0.4.1"
+half = { version = "2.5.0", features = [
+ "num-traits",
+ "use-intrinsics",
+ "rand_distr",
+] }
+float8 = { version = "0.7.0", features = ["num-traits", "rand_distr"] }
hound = "3.5.1"
-image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
-imageproc = { version = "0.24.0", default-features = false }
+image = { version = "0.25.2", default-features = false, features = [
+ "jpeg",
+ "png",
+] }
+imageproc = { version = "0.26.0", features = [
+ "text",
+], default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }
+libm = { version = "0.2.15" }
log = "0.4"
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
num_cpus = "1.15.0"
num-traits = "0.2.15"
-parquet = { version = "51.0.0" }
-rand = "0.8.5"
-rand_distr = "0.4.3"
+parquet = "57"
+rand = "0.9.0"
+rand_distr = "0.5.1"
rayon = "1.7.0"
-safetensors = "0.4.1"
+safetensors = "0.7.0"
serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2"
serde_json = "1.0.99"
-thiserror = "1"
-tokenizers = { version = "0.19.1", default-features = false }
+thiserror = "2"
+tokenizers = { version = "0.22.0", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
-ug = "0.0.2"
-ug-cuda = "0.0.2"
-ug-metal = "0.0.2"
-yoke = { version = "0.7.2", features = ["derive"] }
-zip = { version = "1.1.1", default-features = false }
-metal = { version = "0.27.0", features = ["mps"]}
+ug = "0.5.0"
+ug-cuda = "0.5.0"
+ug-metal = "0.5.0"
+yoke = { version = "0.8.1", features = ["derive"] }
+zip = { version = "7.2.0", default-features = false }
+objc2-metal = { version = "0.3.1" }
+objc2-foundation = { version = "0.3.1" }
+
+#wgpu dependencies:
+wgpu = {version="28.0.0", features=["fragile-send-sync-non-atomic-wasm"]}
+bytemuck = { version = "1.24.0", features = [ "derive" ] }
+pollster = "0.4.0"
+flume = "0.11.1"
+tracing-mutex = "0.3.2"
+web-time = "=1.1.0"
+rustc-hash = "2.1.1"
[profile.release-with-debug]
inherits = "release"
diff --git a/README.md b/README.md
index 246e2844ad..399a589e8b 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
# candle
-[](https://discord.gg/hugging-face-879548962464493619)
+[](https://discord.gg/hugging-face-879548962464493619)
[](https://crates.io/crates/candle-core)
[](https://docs.rs/candle-core)
[](https://github.com/huggingface/candle/blob/main/LICENSE-MIT)
@@ -59,12 +59,12 @@ These online demos run entirely in your browser:
- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation.
- [BLIP](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning.
-We also provide a some command line based examples using state of the art models:
+We also provide some command line based examples using state of the art models:
- [LLaMA v1, v2, and v3](./candle-examples/examples/llama/): general LLM, includes
the SOLAR-10.7B variant.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
-- [Codegeex4](./candle-examples/examples/codegeex4-9b/): Code completion,code interpreter,web search,fuction calling,repository-level
+- [Codegeex4](./candle-examples/examples/codegeex4-9b/): Code completion, code interpreter, web search, function calling, repository-level
- [GLM4](./candle-examples/examples/glm4/): Open Multilingual Multimodal Chat LMs by THUDM
- [Gemma v1 and v2](./candle-examples/examples/gemma/): 2b and 7b+/9b general LLMs from Google Deepmind.
- [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b
@@ -92,6 +92,7 @@ We also provide a some command line based examples using state of the art models
- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
the LLaMA model using the same quantization techniques as
[llama.cpp](https://github.com/ggerganov/llama.cpp).
+- [Quantized Qwen3 MoE](./candle-examples/examples/quantized-qwen3-moe/): support for GGUF-quantized Qwen3 MoE models.
@@ -189,6 +190,8 @@ And then head over to
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
- [`candle-einops`](https://github.com/tomsanbear/candle-einops): A pure rust implementation of the python [einops](https://github.com/arogozhnikov/einops) library.
- [`atoma-infer`](https://github.com/atoma-network/atoma-infer): A Rust library for fast inference at scale, leveraging FlashAttention2 for efficient attention computation, PagedAttention for efficient KV-cache memory management, and multi-GPU support. It is OpenAI api compatible.
+- [`llms-from-scratch-rs`](https://github.com/nerdai/llms-from-scratch-rs): A comprehensive Rust translation of the code from Sebastian Raschka's Build an LLM from Scratch book.
+- [`vllm.rs`](https://github.com/guoqingbao/vllm.rs): A minimalist vLLM implementation in Rust based on Candle.
If you have an addition to this list, please submit a pull request.
@@ -204,6 +207,7 @@ If you have an addition to this list, please submit a pull request.
- Backends.
- Optimized CPU backend with optional MKL support for x86 and Accelerate for macs.
- CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
+    - Wgpu backend for GPU execution (e.g. when CUDA is not available, or in the browser).
- WASM support, run your models in a browser.
- Included models.
- Language Models.
@@ -219,7 +223,7 @@ If you have an addition to this list, please submit a pull request.
- Replit-code-v1.5-3B.
- Bert.
- Yi-6B and Yi-34B.
- - Qwen1.5, Qwen1.5 MoE.
+ - Qwen1.5, Qwen1.5 MoE, Qwen3 MoE.
- RWKV v5 and v6.
- Quantized LLMs.
- Llama 7b, 13b, 70b, as well as the chat and code variants.
@@ -227,6 +231,7 @@ If you have an addition to this list, please submit a pull request.
- Mixtral 8x7b.
- Zephyr 7b a and b (Mistral-7b based).
- OpenChat 3.5 (Mistral-7b based).
+ - Qwen3 MoE (16B-A3B, 32B-A3B)
- Text to text.
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
- Marian MT (Machine Translation).
@@ -280,6 +285,7 @@ Cheatsheet:
- [candle-nn](./candle-nn/): Tools to build real models
- [candle-examples](./candle-examples/): Examples of using the library in realistic settings
- [candle-kernels](./candle-kernels/): CUDA custom kernels
+- [candle-wgpu-kernels](./candle-wgpu-kernels/): wgpu custom kernels
- [candle-datasets](./candle-datasets/): Datasets and data loaders.
- [candle-transformers](./candle-transformers): transformers-related utilities.
- [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer.
@@ -289,6 +295,8 @@ Cheatsheet:
### Why should I use Candle?
+
+
Candle's core goal is to *make serverless inference possible*. Full machine learning frameworks like PyTorch
are very large, which makes creating instances on a cluster slow. Candle allows deployment of lightweight
binaries.
@@ -298,6 +306,7 @@ and the [GIL](https://www.backblaze.com/blog/the-python-gil-past-present-and-fut
Finally, Rust is cool! A lot of the HF ecosystem already has Rust crates, like [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers).
+
### Other ML frameworks
diff --git a/candle-book/CONTRIBUTING.md b/candle-book/CONTRIBUTING.md
new file mode 100644
index 0000000000..02120ec13d
--- /dev/null
+++ b/candle-book/CONTRIBUTING.md
@@ -0,0 +1,13 @@
+# Candle Book
+
+The book uses [mdBook](https://github.com/rust-lang/mdBook) for building.
+
+## Installation
+
+To install mdBook, run `cargo install mdbook`. More instructions can be found [here](https://rust-lang.github.io/mdBook/guide/installation.html).
+
+## Viewing the book
+
+To view the book, run `mdbook serve --open candle-book`. More instructions can be found [here](https://rust-lang.github.io/mdBook/guide/creating.html).
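+
+For reference, the two commands together:
+
+```bash
+cargo install mdbook
+mdbook serve --open candle-book
+```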
+
+The book is built automatically in GitHub CI.
\ No newline at end of file
diff --git a/candle-book/Cargo.toml b/candle-book/Cargo.toml
index dee55f2061..9c5ea2df3f 100644
--- a/candle-book/Cargo.toml
+++ b/candle-book/Cargo.toml
@@ -25,7 +25,7 @@ cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true, optional = true }
anyhow = { workspace = true }
-tokio = "1.29.1"
+tokio = "1.48.0"
[dev-dependencies]
byteorder = { workspace = true }
diff --git a/candle-book/src/README.md b/candle-book/src/README.md
index be352dc101..b7481b642c 100644
--- a/candle-book/src/README.md
+++ b/candle-book/src/README.md
@@ -1,6 +1,7 @@
# Introduction
-{{#include ../../README.md:features}}
+{{#include ../../README.md:goals}}
+{{#include ../../README.md:features}}
-This book will introduce step by step how to use `candle`.
+This book will introduce step by step how to use `candle`.
\ No newline at end of file
diff --git a/candle-book/src/SUMMARY.md b/candle-book/src/SUMMARY.md
index 59831af26b..ac5ff7c173 100644
--- a/candle-book/src/SUMMARY.md
+++ b/candle-book/src/SUMMARY.md
@@ -5,7 +5,10 @@
# User Guide
- [Installation](guide/installation.md)
-- [Hello World - MNIST](guide/hello_world.md)
+- [Tutorial - MNIST](guide/mnist/intro.md)
+ - [Modeling](guide/mnist/modeling.md)
+ - [Training](guide/mnist/training.md)
+ - [Saving And Loading](guide/mnist/saving_loading.md)
- [PyTorch cheatsheet](guide/cheatsheet.md)
# Reference Guide
@@ -13,11 +16,17 @@
- [Running a model](inference/inference.md)
- [Using the hub](inference/hub.md)
- [Error management](error_manage.md)
+- [Tracing](tracing.md)
- [Training](training/training.md)
- [Simplified](training/simplified.md)
- [MNIST](training/mnist.md)
- [Fine-tuning]()
- [Serialization]()
+- [Wgpu Usage](wgpu/readme.md)
+ - [Custom Wgpu Shader](wgpu/custom_shader.md)
+ - [Wgpu Implementation Detail](wgpu/implementation_detail.md)
+ - [Benchmark Performance](wgpu/debug_performance.md)
+ - [Debug Shader logic](wgpu/debug_shader_logic.md)
- [Advanced Cuda usage]()
- [Writing a custom kernel]()
- [Porting a custom kernel]()
diff --git a/candle-book/src/guide/installation.md b/candle-book/src/guide/installation.md
index ca8b79680e..982a6d6058 100644
--- a/candle-book/src/guide/installation.md
+++ b/candle-book/src/guide/installation.md
@@ -1,8 +1,23 @@
# Installation
-**With Cuda support**:
+## 1. Create a new Rust app or library
-1. First, make sure that Cuda is correctly installed.
+```bash
+cargo new myapp
+cd myapp
+```
+
+## 2. Add the correct candle version
+
+### Standard
+
+```bash
+cargo add --git https://github.com/huggingface/candle.git candle-core
+```
+
+### CUDA
+
+First, make sure that Cuda is correctly installed.
- `nvcc --version` should print information about your Cuda compiler driver.
- `nvidia-smi --query-gpu=compute_cap --format=csv` should print your GPUs compute capability, e.g. something
like:
@@ -17,43 +32,41 @@ You can also compile the Cuda kernels for a specific compute cap using the
If any of the above commands errors out, please make sure to update your Cuda version.
-2. Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) with Cuda support.
-
-Start by creating a new cargo:
+Add the `candle-core` crate with the cuda feature:
```bash
-cargo new myapp
-cd myapp
+cargo add --git https://github.com/huggingface/candle.git candle-core --features "cuda"
```
-Make sure to add the `candle-core` crate with the cuda feature:
+### MKL
-```bash
-cargo add --git https://github.com/huggingface/candle.git candle-core --features "cuda"
-```
+You can also use the `mkl` feature, which can speed up inference on CPU.
-Run `cargo build` to make sure everything can be correctly built.
+Add the `candle-core` crate with the mkl feature:
```bash
-cargo build
+cargo add --git https://github.com/huggingface/candle.git candle-core --features "mkl"
```
-**Without Cuda support**:
+### Metal
-Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) as follows:
+Metal is only available on macOS.
+
+Add the `candle-core` crate with the metal feature:
```bash
-cargo new myapp
-cd myapp
-cargo add --git https://github.com/huggingface/candle.git candle-core
+cargo add --git https://github.com/huggingface/candle.git candle-core --features "metal"
```
-Finally, run `cargo build` to make sure everything can be correctly built.
+## 3. Building
+
+Run `cargo build` to make sure everything can be correctly built.
```bash
cargo build
```
-**With mkl support**
-You can also see the `mkl` feature which could be interesting to get faster inference on CPU. [Using mkl](./advanced/mkl.md)
+**With wgpu support**
+
+You can also use the `wgpu` feature, which enables GPU inference through Vulkan, DX12, Metal, or WebGPU. See [Using wgpu](../wgpu/).
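+
+Add the `candle-core` crate with the wgpu feature (mirroring the commands above; this assumes the backend feature is simply named `wgpu`):
+
+```bash
+cargo add --git https://github.com/huggingface/candle.git candle-core --features "wgpu"
+```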
diff --git a/candle-book/src/guide/mnist/intro.md b/candle-book/src/guide/mnist/intro.md
new file mode 100644
index 0000000000..06d56a1b2f
--- /dev/null
+++ b/candle-book/src/guide/mnist/intro.md
@@ -0,0 +1,17 @@
+# Candle MNIST Tutorial
+
+## Introduction
+
+This tutorial provides an introduction to Candle by implementing and training a neural network for MNIST digit classification from scratch.
+
+Throughout this tutorial, you will learn the basics of:
+
+- Tensor operations and model construction
+- Creating and implementing neural network layers
+- Parameter initialization
+- Training loop implementation
+- Saving and loading trained models
+
+## Getting Started
+
+Before proceeding, please ensure that you have properly installed Candle by following the instructions in the [Installation](../installation.md) guide.
\ No newline at end of file
diff --git a/candle-book/src/guide/mnist/modeling.md b/candle-book/src/guide/mnist/modeling.md
new file mode 100644
index 0000000000..f34e89a92f
--- /dev/null
+++ b/candle-book/src/guide/mnist/modeling.md
@@ -0,0 +1,172 @@
+# Candle MNIST Tutorial
+
+## Modeling
+
+Open `src/main.rs` in your project folder and insert the following code:
+
+```rust
+use candle_core::{Device, Result, Tensor};
+
+struct Model {
+ first: Tensor,
+ second: Tensor,
+}
+
+impl Model {
+    fn forward(&self, image: &Tensor) -> Result<Tensor> {
+ let x = image.matmul(&self.first)?;
+ let x = x.relu()?;
+ x.matmul(&self.second)
+ }
+}
+
+fn main() -> Result<()> {
+ // Use Device::new_cuda(0)?; to utilize GPU acceleration.
+ let device = Device::Cpu;
+
+ let first = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
+ let second = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
+ let model = Model { first, second };
+
+ let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
+
+ let digit = model.forward(&dummy_image)?;
+ println!("Digit {digit:?} digit");
+ Ok(())
+}
+```
+
+Execute the program with:
+
+```bash
+$ cargo run --release
+
+> Digit Tensor[dims 1, 10; f32] digit
+```
+
+Since random inputs are provided, expect an incoherent output.
+
+## Implementing a `Linear` Layer
+
+To create a more sophisticated layer type, add a `bias` to the weight to construct the standard `Linear` layer.
+
+Replace the entire content of `src/main.rs` with:
+
+```rust
+use candle_core::{Device, Result, Tensor};
+
+struct Linear {
+ weight: Tensor,
+ bias: Tensor,
+}
+
+impl Linear {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let x = x.matmul(&self.weight)?;
+ x.broadcast_add(&self.bias)
+ }
+}
+
+struct Model {
+ first: Linear,
+ second: Linear,
+}
+
+impl Model {
+    fn forward(&self, image: &Tensor) -> Result<Tensor> {
+ let x = self.first.forward(image)?;
+ let x = x.relu()?;
+ self.second.forward(&x)
+ }
+}
+
+fn main() -> Result<()> {
+ // Use Device::new_cuda(0)?; for GPU acceleration.
+ // Use Device::Cpu; for CPU computation.
+ let device = Device::cuda_if_available(0)?;
+
+ // Initialize model parameters
+ let weight = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
+ let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
+ let first = Linear { weight, bias };
+ let weight = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
+ let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
+ let second = Linear { weight, bias };
+ let model = Model { first, second };
+
+ let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
+
+ // Perform inference
+ let digit = model.forward(&dummy_image)?;
+ println!("Digit {digit:?} digit");
+ Ok(())
+}
+```
+
+Execute again with:
+
+```bash
+$ cargo run --release
+
+> Digit Tensor[dims 1, 10; f32] digit
+```
+
+## Utilizing `candle_nn`
+
+Many classical layers (such as [Linear](https://github.com/huggingface/candle/blob/main/candle-nn/src/linear.rs)) are already implemented in [candle-nn](https://github.com/huggingface/candle/tree/main/candle-nn).
+
+This `Linear` implementation follows PyTorch conventions for improved compatibility with existing models: it stores the weight matrix transposed (e.g. `(100, 784)` instead of `(784, 100)`) rather than using the weights directly.
+
+Let's simplify our implementation. First, add `candle-nn` as a dependency:
+
+```bash
+$ cargo add --git https://github.com/huggingface/candle.git candle-nn
+```
+
+Now, replace the entire content of `src/main.rs` with:
+
+```rust
+use candle_core::{Device, Result, Tensor};
+use candle_nn::{Linear, Module};
+
+struct Model {
+ first: Linear,
+ second: Linear,
+}
+
+impl Model {
+    fn forward(&self, image: &Tensor) -> Result<Tensor> {
+ let x = self.first.forward(image)?;
+ let x = x.relu()?;
+ self.second.forward(&x)
+ }
+}
+
+fn main() -> Result<()> {
+ // Use Device::new_cuda(0)?; for GPU acceleration.
+ let device = Device::Cpu;
+
+ // Note the dimension change: (784, 100) -> (100, 784)
+ let weight = Tensor::randn(0f32, 1.0, (100, 784), &device)?;
+ let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
+ let first = Linear::new(weight, Some(bias));
+ let weight = Tensor::randn(0f32, 1.0, (10, 100), &device)?;
+ let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
+ let second = Linear::new(weight, Some(bias));
+ let model = Model { first, second };
+
+ let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
+
+ let digit = model.forward(&dummy_image)?;
+ println!("Digit {digit:?} digit");
+ Ok(())
+}
+```
+
+Execute the final version:
+
+```bash
+$ cargo run --release
+
+> Digit Tensor[dims 1, 10; f32] digit
+```
\ No newline at end of file
diff --git a/candle-book/src/guide/mnist/saving_loading.md b/candle-book/src/guide/mnist/saving_loading.md
new file mode 100644
index 0000000000..4511f068e0
--- /dev/null
+++ b/candle-book/src/guide/mnist/saving_loading.md
@@ -0,0 +1,158 @@
+# Candle MNIST Tutorial
+
+## Saving and Loading Models
+
+After training a model, it is useful to save and subsequently load the model parameters. In Candle, this functionality is managed through the `VarMap` data structure, with parameters stored on disk using the [safetensors](https://huggingface.co/docs/safetensors/index) format.
+
+### Saving Model Parameters
+
+Let's modify our `training_loop` function to include functionality for saving weights:
+
+```rust
+fn training_loop(
+ m: candle_datasets::vision::Dataset,
+) -> anyhow::Result<()> {
+ let dev = Device::cuda_if_available(0)?;
+
+ let train_labels = m.train_labels;
+ let train_images = m.train_images.to_device(&dev)?;
+ let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
+
+ // Initialize a VarMap for trainable parameters
+ let varmap = VarMap::new();
+ let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
+ let model = Model::new(vs.clone())?;
+
+ let learning_rate = 0.05;
+ let epochs = 10;
+
+ // Initialize stochastic gradient descent optimizer
+ let mut sgd = candle_nn::SGD::new(varmap.all_vars(), learning_rate)?;
+ let test_images = m.test_images.to_device(&dev)?;
+ let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
+
+ for epoch in 1..epochs {
+ // Standard MNIST forward pass
+ let logits = model.forward(&train_images)?;
+ let log_sm = ops::log_softmax(&logits, D::Minus1)?;
+
+ // Compute Negative Log Likelihood loss
+ let loss = loss::nll(&log_sm, &train_labels)?;
+
+ // Perform backward pass and update weights
+ sgd.backward_step(&loss)?;
+
+ // Evaluate model on test set
+ let test_logits = model.forward(&test_images)?;
+ let sum_ok = test_logits
+ .argmax(D::Minus1)?
+ .eq(&test_labels)?
+ .to_dtype(DType::F32)?
+ .sum_all()?
+            .to_scalar::<f32>()?;
+ let test_accuracy = sum_ok / test_labels.dims1()? as f32;
+ println!(
+ "{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
+            loss.to_scalar::<f32>()?,
+ test_accuracy
+ );
+ }
+
+ // Save model weights to disk
+ varmap.save("model_weights.safetensors")?;
+ Ok(())
+}
+```
+
+```bash
+$ cargo run --release
+
+> 1 train loss: 2.40485 test acc: 0.11%
+> 2 train loss: 2.34161 test acc: 0.14%
+> 3 train loss: 2.28841 test acc: 0.17%
+> 4 train loss: 2.24158 test acc: 0.19%
+> 5 train loss: 2.19898 test acc: 0.23%
+> 6 train loss: 2.15927 test acc: 0.26%
+> 7 train loss: 2.12161 test acc: 0.29%
+> 8 train loss: 2.08549 test acc: 0.32%
+> 9 train loss: 2.05053 test acc: 0.35%
+```
+
+### Loading Model Parameters
+
+Now that we have saved our model parameters, we can modify the code to load them. The primary change required is to make the `varmap` variable mutable:
+
+```rust
+fn training_loop(
+ m: candle_datasets::vision::Dataset,
+) -> anyhow::Result<()> {
+ let dev = Device::cuda_if_available(0)?;
+
+ let train_labels = m.train_labels;
+ let train_images = m.train_images.to_device(&dev)?;
+ let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
+
+ // Create a mutable VarMap for trainable parameters
+ let mut varmap = VarMap::new();
+ let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
+ let model = Model::new(vs.clone())?;
+
+ // Load pre-trained weights from file
+ varmap.load("model_weights.safetensors")?;
+
+ let learning_rate = 0.05;
+ let epochs = 10;
+
+ // Initialize stochastic gradient descent optimizer
+ let mut sgd = candle_nn::SGD::new(varmap.all_vars(), learning_rate)?;
+ let test_images = m.test_images.to_device(&dev)?;
+ let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
+
+ for epoch in 1..epochs {
+ // Standard MNIST forward pass
+ let logits = model.forward(&train_images)?;
+ let log_sm = ops::log_softmax(&logits, D::Minus1)?;
+
+ // Compute Negative Log Likelihood loss
+ let loss = loss::nll(&log_sm, &train_labels)?;
+
+ // Perform backward pass and update weights
+ sgd.backward_step(&loss)?;
+
+ // Evaluate model on test set
+ let test_logits = model.forward(&test_images)?;
+ let sum_ok = test_logits
+ .argmax(D::Minus1)?
+ .eq(&test_labels)?
+ .to_dtype(DType::F32)?
+ .sum_all()?
+            .to_scalar::<f32>()?;
+ let test_accuracy = sum_ok / test_labels.dims1()? as f32;
+ println!(
+ "{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
+            loss.to_scalar::<f32>()?,
+ test_accuracy
+ );
+ }
+
+ // Save updated weights back to disk
+ varmap.save("model_weights.safetensors")?;
+ Ok(())
+}
+```
+
+```bash
+$ cargo run --release
+
+> 1 train loss: 2.01645 test acc: 0.38%
+> 2 train loss: 1.98300 test acc: 0.41%
+> 3 train loss: 1.95008 test acc: 0.44%
+> 4 train loss: 1.91754 test acc: 0.47%
+> 5 train loss: 1.88534 test acc: 0.50%
+> 6 train loss: 1.85349 test acc: 0.53%
+> 7 train loss: 1.82198 test acc: 0.56%
+> 8 train loss: 1.79077 test acc: 0.59%
+> 9 train loss: 1.75989 test acc: 0.61%
+```
+
+Note that loading the weights will fail if the specified file does not exist or is incompatible with the current model architecture. Implementing file existence checks and appropriate error handling is left to the user.
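+
+A minimal sketch of such a check (using `std::path::Path`; the file name matches the one used above):
+
+```rust
+use std::path::Path;
+
+// Only load previously saved weights if the file is actually there;
+// otherwise keep the freshly initialized parameters.
+if Path::new("model_weights.safetensors").exists() {
+    varmap.load("model_weights.safetensors")?;
+}
+```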
\ No newline at end of file
diff --git a/candle-book/src/guide/mnist/training.md b/candle-book/src/guide/mnist/training.md
new file mode 100644
index 0000000000..054806955f
--- /dev/null
+++ b/candle-book/src/guide/mnist/training.md
@@ -0,0 +1,134 @@
+# Candle MNIST Tutorial
+
+## Training Implementation
+
+First, let's create a utility function `make_linear` that accepts a `VarBuilder` and returns an initialized linear layer. The `VarBuilder` is created from a `VarMap`, which is the data structure that stores our trainable parameters.
+
+```rust
+use candle_core::{Device, Result, Tensor};
+use candle_nn::{Linear, Module, VarBuilder, VarMap};
+
+fn make_linear(vs: VarBuilder, in_dim: usize, out_dim: usize) -> Result<Linear> {
+ let ws = vs.get_with_hints(
+ (out_dim, in_dim),
+ "weight",
+ candle_nn::init::DEFAULT_KAIMING_NORMAL,
+ )?;
+ let bound = 1. / (in_dim as f64).sqrt();
+ let bs = vs.get_with_hints(
+ out_dim,
+ "bias",
+ candle_nn::Init::Uniform {
+ lo: -bound,
+ up: bound,
+ },
+ )?;
+ Ok(Linear::new(ws, Some(bs)))
+}
+```
+
+Next, let's implement a `new` method for our model class to accept a `VarBuilder` and initialize the model. We use `VarBuilder::pp` to "push prefix" so that the parameter names are organized hierarchically: the first layer weights as `first.weight` and `first.bias`, and the second layer weights as `second.weight` and `second.bias`.
+
+```rust
+impl Model {
+    fn new(vs: VarBuilder) -> Result<Self> {
+ const IMAGE_DIM: usize = 784;
+ const HIDDEN_DIM: usize = 100;
+ const LABELS: usize = 10;
+
+ let first = make_linear(vs.pp("first"), IMAGE_DIM, HIDDEN_DIM)?;
+ let second = make_linear(vs.pp("second"), HIDDEN_DIM, LABELS)?;
+
+ Ok(Self { first, second })
+ }
+
+    fn forward(&self, image: &Tensor) -> Result<Tensor> {
+ let x = self.first.forward(image)?;
+ let x = x.relu()?;
+ self.second.forward(&x)
+ }
+}
+```
+
+Now, let's add the `candle-datasets` package to our project to access the MNIST dataset:
+
+```bash
+$ cargo add --git https://github.com/huggingface/candle.git candle-datasets
+```
+
+With the dataset available, we can implement our training loop:
+
+```rust
+use candle_core::{DType, Device, Result, Tensor, D};
+use candle_nn::{loss, ops, Linear, Module, Optimizer, VarBuilder, VarMap};
+
+fn training_loop(
+ m: candle_datasets::vision::Dataset,
+) -> anyhow::Result<()> {
+ let dev = Device::cuda_if_available(0)?;
+
+ let train_labels = m.train_labels;
+ let train_images = m.train_images.to_device(&dev)?;
+ let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
+
+ // Initialize a VarMap to store trainable parameters
+ let varmap = VarMap::new();
+ let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
+ let model = Model::new(vs.clone())?;
+
+ let learning_rate = 0.05;
+ let epochs = 10;
+
+ // Initialize a stochastic gradient descent optimizer to update parameters
+ let mut sgd = candle_nn::SGD::new(varmap.all_vars(), learning_rate)?;
+ let test_images = m.test_images.to_device(&dev)?;
+ let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
+
+ for epoch in 1..epochs {
+ // Perform forward pass on MNIST data
+ let logits = model.forward(&train_images)?;
+ let log_sm = ops::log_softmax(&logits, D::Minus1)?;
+
+ // Compute Negative Log Likelihood loss
+ let loss = loss::nll(&log_sm, &train_labels)?;
+
+ // Perform backward pass and update weights
+ sgd.backward_step(&loss)?;
+
+ // Evaluate model on test set
+ let test_logits = model.forward(&test_images)?;
+ let sum_ok = test_logits
+ .argmax(D::Minus1)?
+ .eq(&test_labels)?
+ .to_dtype(DType::F32)?
+ .sum_all()?
+            .to_scalar::<f32>()?;
+ let test_accuracy = sum_ok / test_labels.dims1()? as f32;
+ println!(
+ "{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
+            loss.to_scalar::<f32>()?,
+ test_accuracy
+ );
+ }
+ Ok(())
+}
+```
+
+Finally, let's implement our main function:
+
+```rust
+pub fn main() -> anyhow::Result<()> {
+ let m = candle_datasets::vision::mnist::load()?;
+ return training_loop(m);
+}
+```
+
+Let's execute the training process:
+
+```bash
+$ cargo run --release
+
+> 1 train loss: 2.35449 test acc: 0.12%
+> 2 train loss: 2.30760 test acc: 0.15%
+> ...
+```
\ No newline at end of file
diff --git a/candle-book/src/inference/hub.md b/candle-book/src/inference/hub.md
index fb6f9e51f6..e8d8b267db 100644
--- a/candle-book/src/inference/hub.md
+++ b/candle-book/src/inference/hub.md
@@ -11,8 +11,8 @@ Then let's start by downloading the [model file](https://huggingface.co/bert-bas
```rust
# extern crate candle_core;
-# extern crate candle_hf_hub;
-use candle_hf_hub::api::sync::Api;
+# extern crate hf_hub;
+use hf_hub::api::sync::Api;
use candle_core::Device;
let api = Api::new().unwrap();
@@ -50,8 +50,8 @@ Now that we have our weights, we can use them in our bert architecture:
```rust
# extern crate candle_core;
# extern crate candle_nn;
-# extern crate candle_hf_hub;
-# use candle_hf_hub::api::sync::Api;
+# extern crate hf_hub;
+# use hf_hub::api::sync::Api;
#
# let api = Api::new().unwrap();
# let repo = api.model("bert-base-uncased".to_string());
diff --git a/candle-book/src/tracing.md b/candle-book/src/tracing.md
new file mode 100644
index 0000000000..dbaa80f012
--- /dev/null
+++ b/candle-book/src/tracing.md
@@ -0,0 +1,68 @@
+# Tracing
+
+Tracing is a powerful tool for identifying performance issues and bottlenecks in code.
+
+> Profiling on GPUs is trickier due to asynchronous execution, see the [GPU section](#gpu).
+
+## Overview
+
+Candle uses the [tracing](https://docs.rs/tracing/latest/tracing/) crate for instrumentation.
+
+To try it out, run an example in `candle-examples` with the `--tracing` flag.
+This generates a trace file, typically named `trace-<timestamp>.json`.
+You can view the trace in Chrome by navigating to `chrome://tracing/`, clicking **Load**, and selecting the generated trace file.
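+
+For example, one of the examples can be run like this (the example name and prompt are illustrative; any example exposing the `--tracing` flag works):
+
+```bash
+cargo run --example llama --release -- --tracing --prompt "hello world"
+```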
+
+## Adding Tracing
+
+Candle includes built-in tracing for many internal operations, using [spans](https://docs.rs/tracing/latest/tracing/struct.Span.html) to mark key points of execution.
+
+To add custom tracing in your code, you can define a span like this:
+
+```rust
+let span = tracing::span!(tracing::Level::TRACE, "name");
+```
+
+Then, to record the span during execution, create a guard:
+
+```rust
+let _enter = span.enter();
+```
+
+This guard will record the span's duration, from when it is created to when it is dropped, into a global data structure managed by the tracing crate.
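+
+Putting the two snippets together, a minimal sketch of an instrumented function could look like this (the span name is illustrative):
+
+```rust
+fn my_op() {
+    // The span is recorded from its creation here until `_enter` is dropped
+    // at the end of the function.
+    let span = tracing::span!(tracing::Level::TRACE, "my_op");
+    let _enter = span.enter();
+    // ... the actual work goes here ...
+}
+```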
+
+## Recording and Saving a Trace
+
+To capture and save trace data, you need to configure the tracing system with an output format. Candle uses the [tracing_subscriber](https://docs.rs/tracing-subscriber/latest/tracing_subscriber/) and [tracing_chrome](https://docs.rs/tracing-chrome/latest/tracing_chrome/) crates.
+
+The snippet below sets up a Chrome compatible recorder that logs all tracing activity between creation and drop of the guard:
+
+```rust
+use tracing_chrome::ChromeLayerBuilder;
+use tracing_subscriber::prelude::*;
+
+let _guard = {
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ guard
+};
+```
+
+## GPU
+
+When using CUDA, Metal, or other asynchronous GPU backends, tracing may produce misleading timing data because operations are queued rather than executed immediately.
+
+### CUDA
+
+For CUDA-specific profiling, you have two options:
+
+1. Set the environment variable `CUDA_LAUNCH_BLOCKING=1` which forces synchronous execution. This makes trace timings more accurate, at the cost of reduced performance.
+2. Use [NVIDIA's Nsight Systems](https://developer.nvidia.com/nsight-systems) (`nsys profile` and `nsys-ui`) which are designed specifically for profiling asynchronous CUDA executions.
+
+We recommend using NVIDIA's Nsight Systems when possible, as it offers accurate performance data without altering typical execution patterns. In contrast, setting the `CUDA_LAUNCH_BLOCKING` environment variable forces synchronous execution, which can significantly alter execution behavior.
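+
+For a quick synchronous trace, the variable can be set directly on the command line (the example name and flags are illustrative):
+
+```bash
+CUDA_LAUNCH_BLOCKING=1 cargo run --example llama --release --features cuda -- --tracing
+```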
+
+#### Performance Profiling with NVIDIA Nsight Systems
+
+1. Generate an `.nsys-rep` file containing performance data ([docs](https://docs.nvidia.com/nsight-systems/UserGuide/index.html#example-single-command-lines))
+ - Run `nsys profile --trace cuda,nvtx,osrt --gpu-metrics-device=all --output profile_run ./target/debug/... --prompt "whatever "`
+1. Open the generated `.nsys-rep` report file in Nsight Systems GUI
+ - File > Open
\ No newline at end of file
diff --git a/candle-book/src/wgpu/README.md b/candle-book/src/wgpu/README.md
new file mode 100644
index 0000000000..47dcb235d2
--- /dev/null
+++ b/candle-book/src/wgpu/README.md
@@ -0,0 +1,111 @@
+# Wgpu Installation
+To use the wgpu backend, you must enable the wgpu feature.
+In the code, use the `new_wgpu` function to create a new wgpu device.
+```rust
+// In the browser, the async method must be used.
+let device = Device::new_wgpu_async(0).await?;
+
+// or
+
+let device = Device::new_wgpu(0)?;
+
+// or
+
+// Pass additional configuration, e.g. the wgpu backend to be used (Vulkan, DX12 or Metal).
+let device = Device::new_wgpu_config_async(0, config).await?;
+
+// or
+
+let device = Device::new_wgpu_config(0, config)?;
+```
+
+## GPU Storage Query Limitation in WGPU
+
+Currently, WGPU does not provide a way to query the available storage of a GPU device. As a result, the Candle implementation for WGPU cannot determine the number of buffers that can be cached or when existing buffers should be deleted.
+
+To address this limitation, the `buffer_cached_max_allowed_size` property is included in the device creation configuration. This property allows users to specify the maximum amount of memory, in bytes, that Candle is permitted to allocate for buffers. By default, this value is set to 8 GB.
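+
+As a rough sketch (the configuration type name below is illustrative; only the `buffer_cached_max_allowed_size` field and `new_wgpu_config` come from the description above), the limit could be lowered like this:
+
+```rust
+// Hypothetical sketch: `WgpuDeviceConfig` stands in for the real configuration type
+// passed to `Device::new_wgpu_config`; the field value is in bytes.
+let config = WgpuDeviceConfig {
+    buffer_cached_max_allowed_size: 4u64 * 1024 * 1024 * 1024, // cache at most 4 GB of buffers
+    ..Default::default()
+};
+let device = Device::new_wgpu_config(0, config)?;
+```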
+
+## Feature Support Table
+
+| Feature | Support Status | Notes |
+|-----------------------------|-------------------------------------------------|--------------------------------------------------------------------|
+| **Data Types** | | |
+| f32 | ✅ Supported | |
+| u32 | ✅ Supported | |
+| u8                          | ⚠️ Only output of Cmp                            | Only f32, i32 and u32 are available in a WebGPU shader               |
+| i64 | ⚠️ Supported Native | |
+| f64 | ⚠️ Supported Native | |
+| f16 | ⚠️ Only in Quantized Matrices | |
+| bf16 | ❌ Not Supported | |
+| **Operations** | | All operations support non-contiguous arrays |
+| Unary Operations | ✅ Supported | |
+| Binary Operations | ✅ Supported | |
+| MatMul | ✅ Supported | |
+| Reduce Operations           | ✅ Supported                                      | Sum, Min, Max (ArgMax, ArgMin work only if contiguous dimensions are reduced) |
+| Conv2d | ✅ Supported | |
+| Conv2dTranspose | ✅ Supported | |
+| Conv1d | ✅ Supported | |
+| Conv1dTranspose | ✅ Supported | |
+| Index Select | ✅ Supported | |
+| Where_cond | ✅ Supported | |
+| Pool2dMax | ✅ Supported | |
+| Pool2dAvg | ✅ Supported | |
+| Upsample | ✅ Supported | |
+| Gather | ✅ Supported | |
+| Scatter_add | ✅ Supported | |
+| Index_add | ✅ Supported | |
+| Quantized Matrices | ✅ Supported | |
+| **Not Implemented** | | |
+| ArgSort | ❌ Not Implemented | |
+
+
+# Usage in the Browser
+In the browser it is not possible to synchronously request a device, read GPU memory or synchronize the device.
+Synchronous methods exist for these operations, but they simply block the current thread, so in the browser they will not work and will fail.
+If you want to target the browser, you need to use the async methods.
+
+The following code demonstrates how to use wgpu in the browser:
+```rust
+use candle_core::{Device, Tensor};
+
+// Use the async method to create the device; in the browser it must be awaited.
+let device = Device::new_wgpu_async(0).await?;
+
+let a = Tensor::randn(0f32, 1., (2, 3), &device)?;
+let b = Tensor::randn(0f32, 1., (3, 4), &device)?;
+
+let c = a.matmul(&b)?;
+
+// If we want to synchronize with the device, we must use the async function.
+device.synchronize_async().await?;
+
+//We need to asynchronously copy the gpu buffer back to the cpu,
+//to_device() will not work
+let c = c.to_device_async(&Device::Cpu).await?;
+//or c.to_vec2_async().await?
+console_log!("{c}");
+Ok(())
+```
+Note that the above limitations only apply when targeting the browser; a native program can still use the same synchronous functions.
+
+# Example Projects
+All example projects, as well as the WASM examples, can be used with the wgpu backend.
+
+In order to use **WGPU**, add `--features wgpu` to the example command line, e.g.:
+```bash
+cargo run --example stable-diffusion --release --features="wgpu" -- --prompt "Anthropomorphic cat dressed as a fire fighter" --sd-version v1-5
+```
+
+
+# Known problems
+- Not all dtypes are supported: f32 and u32 are implemented for most operations, and u8 only for cmp and where_cond.
+  f64 and i64 are supported in native programs only; WebGPU has no support for the f64, i64 or u8 dtypes.
+  (There is an f16 extension in the WebGPU spec, but it is currently not supported by wgpu: https://github.com/gfx-rs/wgpu/issues/4384)
+- Reduce implementation limitation: ArgMin/ArgMax with non-contiguous reduction dimensions will probably not work, e.g. if dims 0 and 2 are reduced. The current implementation first reduces dim 2 and then dim 0; this does not work for ArgMin/ArgMax because the dtype and source values change after the first reduction.
+- Buffer size limitation:
+ Depending on the driver used, it may not be possible to create a large enough buffer.
+ Also, you may be able to create a large buffer, but not be able to bind to the entire buffer in a single operation.
+- Browser performance worse than native:
+ The shaders have been optimized for an NVIDIA GPU using a native Vulkan driver.
+ Performance may not be optimal on other platforms or GPUs. Browser performance has been shown to be slower than native.
diff --git a/candle-book/src/wgpu/custom_shader.md b/candle-book/src/wgpu/custom_shader.md
new file mode 100644
index 0000000000..a8da378457
--- /dev/null
+++ b/candle-book/src/wgpu/custom_shader.md
@@ -0,0 +1,86 @@
+# Custom shader
+
+The example `candle-core/examples/wgpu_basics.rs` shows how to write a custom WGPU shader.
+
+1. Define a LoaderIndex.
+When you add your custom pipeline to the queue, you need a unique identifier to distinguish your shader from the default shaders provided with Candle or the shaders of other modules.
+The macro `create_loader` will generate a unique index for this purpose at compile time.
+```rust
+wgpu_compute_layer::create_loader!(MyCustomLoader);
+```
+
+2. Define a ShaderLoader
+A ShaderLoader is an object that implements the `ShaderLoader` trait.
+It is responsible for returning the source code of a .wgsl shader file, as well as the name of the entry point.
+
+```rust
+impl wgpu_compute_layer::ShaderLoader for MyCustomLoader{
+ fn load(&self, _ : wgpu_compute_layer::ShaderIndex) -> &str {
+ return "YOUR SHADER CODE GOES HERE";
+ }
+
+ fn get_entry_point(&self, _ : wgpu_compute_layer::PipelineIndex) -> &str {
+ return "ENTRY POINT NAME GOES HERE"
+ }
+}
+```
+Instead of creating multiple ShaderLoaders, your ShaderLoader can handle multiple files using the index parameter. (Up to 65536 can be handled).
+Each file can also have multiple compute entry points, which you can differentiate using the PipelineIndex parameter.
+
+3. Add the ShaderLoader to the WGPU device.
+You can get a reference to the WgpuDevice from WgpuStorage.device() (e.g. inside a CustomOp),
+or by pattern matching the candle device.
+```rust
+ wgpu_device.add_wgpu_shader_loader(MyCustomLoader::LOADER_INDEX, || {MyCustomLoader{}});
+```
+This will add your shader loader at the specified index.
+For example, suppose the index created by the `create_loader` macro is 13.
+Later, when we enqueue a custom shader, we use this index to tell the wgpu backend that we want one of our custom shaders.
+If you enqueue (Loader=13, Shader=0, EntryPoint=0), the wgpu system looks up that pipeline in a hashmap; if it is not found, it asks the shader loader at index 13 for its first shader and the name of that shader's first entry point.
+
+4. Queue your shader:
+
+ To add a pipeline to the queue, we need to use the following commands:
+ 1. Define a reference to the metastructure.
+ Here we can pass additional meta information for the operation
+ ```rust
+ let mut queue = wgpu_device.get_queue();
+ queue.add(42);
+ queue.add(13);
+ ..
+ ```
+ 2. Define the pipeline to use.
+ Use your ShaderLoaderIndex to define which pipeline and entry point to use.
+ ```rust
+ let pipeline = queue.get_pipeline(PipelineIndex::new(ShaderIndex::new(MyCustomLoader::LOADER_INDEX, 0), 0));
+ //or
+ let pipeline = queue.get_pipeline_const(PipelineIndex::new(ShaderIndex::new(MyCustomLoader::LOADER_INDEX, 0), 0), [42, 12]); //CONSTV_0 = 42, CONSTV_1 = 12
+ ```
+    It is also possible to define WebGPU override const values using the `get_pipeline_const` function.
+    Each time such a const value changes, a new shader is compiled; since the constant is baked into the shader, this can improve performance.
+    But remember that shader compilation also takes time, so if the constant value changes frequently, you may want to pass the value as a meta parameter instead.
+
+    The const parameters must be named `CONSTV_{N}`.
+
+ 3. Define the Bindgroup
+ The bindgroup defines the input and output buffers for your operations.
+ ```rust
+    let bind_group = wgpu_device.create_bind_group_input0(*output_buffer.buffer(), candle_core::DType::U32.into());
+ ```
+ In general, there are 4 possible Bindgroup types:
+ - Bindgroup0 - V_Dest(Binding 0), V_Meta(Binding 1)
+ - Bindgroup1 - V_Dest(Binding 0), V_Meta(Binding 1), V_Input1(Binding 2)
+ - Bindgroup2 - V_Dest(Binding 0), V_Meta(Binding 1), V_Input1(Binding 2), V_Input2(Binding 3)
+ - Bindgroup3 - V_Dest(Binding 0), V_Meta(Binding 1), V_Input1(Binding 2), V_Input2(Binding 3), V_Input3(Binding 4)
+
+    4. Add the command to the queue:
+ ```rust
+ queue.enqueue_workgroups(
+ pipeline,
+ bind_group,
+ x,
+ y,
+ z,
+ workloadestimate //e.g. m*k*n for a matmul
+ );
+ ```
\ No newline at end of file
diff --git a/candle-book/src/wgpu/debug_performance.md b/candle-book/src/wgpu/debug_performance.md
new file mode 100644
index 0000000000..db96bf421e
--- /dev/null
+++ b/candle-book/src/wgpu/debug_performance.md
@@ -0,0 +1,154 @@
+# Debugging WGPU Performance
+
+There are three main ways to measure and debug the performance of `wgpu` devices:
+
+1. Tracing
+2. Benchmarking
+3. Recording all queued operations
+
+---
+
+## 1. Tracing
+
+You can use tracing to measure internal durations, for example:
+
+- How long it takes to create buffers
+- How long it takes to create bind groups
+- How long it takes to encode commands
+- How long it takes to wait for the GPU to be ready
+
+To add tracing to a project, see the [`tracing`](../tracing.md) page.
+
+In addition:
+
+- `wgpu_compute_layer` has a dependency on `tracing` with
+ `features = ["release_max_level_off"]` in its `Cargo.toml`.
+- With this configuration, tracing is effectively disabled in release builds.
+- To use tracing, you must either:
+ - run a debug build, **or**
+ - remove the `features = ["release_max_level_off"]` entry from the `tracing` dependency in `Cargo.toml`.
+
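+For example, a minimal setup that writes a Chrome trace file might look like the following sketch (assuming the `tracing-subscriber` and `tracing-chrome` crates):
+
+```rust
+use tracing_chrome::ChromeLayerBuilder;
+use tracing_subscriber::prelude::*;
+
+fn main() {
+    // Keep the guard alive for the whole run so the trace file is flushed on drop.
+    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
+    tracing_subscriber::registry().with(chrome_layer).init();
+
+    // ... create the wgpu device and run the model here ...
+}
+```
+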
+---
+
+## 2. Benchmarking
+
+The directory:
+
+```text
+candle-core/benches/benchmarks/...
+```
+
+contains various benchmarks.
+
+For example, `matmul_wgpu` can be used to:
+
+* override the `matmul` implementation used, and
+* test the performance of different `matmul` implementations under different scenarios.
+
+Use these benches to compare implementations and understand performance characteristics on your hardware.
+
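+For example, to run only the matmul benchmarks against the wgpu backend, an invocation along these lines should work (the exact filter and feature flags depend on your setup):
+
+```bash
+cargo bench --bench bench_main --features wgpu -- matmul
+```
+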
+---
+
+## 3. Recording and Replaying WGPU Operations
+
+To debug and optimize the performance of WGPU shaders in a model, you can **record**, **inspect**, and **replay** all WGPU commands used during execution.
+
+### Step 1: Enable the `wgpu_debug` / `wgpu_debug_serialize` features
+
+There are two related features:
+
+#### `wgpu_debug`
+
+Enable this feature to record all WGPU commands executed during the model’s runtime.
+
+```bash
+# Example
+cargo build --features wgpu_debug
+```
+
+When `wgpu_debug` is enabled:
+
+* All queued WGPU commands (pipelines, bind groups, buffers, dispatches, etc.) are recorded.
+* At the end of execution, you can dump them to disk using `log_debuginfo_to_file` (see Step 2).
+* The recorded data can later be **replayed** to benchmark or debug performance.
+
+#### `wgpu_debug_serialize`
+
+This feature is more lightweight:
+
+* It **does not** record any commands at runtime.
+* Instead, it adds `serde::Serialize` derives (and related metadata) to pipelines, bind groups, shader info, etc.
+* This is useful when you want to **load and work with** the files produced by a `wgpu_debug` run (for example, to simulate or analyze them in another process or crate), without enabling full command recording again.
+
+Typical workflow:
+
+1. Run your model once with `wgpu_debug` enabled to generate the debug files.
+2. In another tool/binary/crate, enable `wgpu_debug_serialize` to **deserialize and inspect** those recorded files, replay commands, or run simulations.
+
+---
+
+### Step 2: Log debugging information to files
+
+At the end of the model execution (with `wgpu_debug` enabled), call `log_debuginfo_to_file` to write all recorded information into a set of files:
+
+```rust
+#[cfg(feature = "wgpu_debug")]
+{
+ device
+ .as_wgpu_device()
+ .unwrap()
+ .log_debuginfo_to_file("{OUTPUT_PATH}", "MODEL_NAME", "VERSION_NAME")?;
+ // Example:
+ // log_debuginfo_to_file("", "llama2c", "5.0")?;
+}
+```
+
+This will create four files:
+
+* `*_measurements.json`
+ Contains performance metrics for all used shaders and pipelines.
+
+* `*_shaders.json`
+ Contains all created shaders and their used constants.
+ This is useful to detect situations where **many slightly different shaders** are created (shader creation is expensive) and where using **pipeline parameters instead of constants** might be more efficient.
+
+* `*_used_consts.json`
+ Maps `ConstsId` to the actual constants used.
+ This file is required when **replaying** the recorded commands.
+
+* `*_used_pipelines.json`
+ Contains all pipeline, bind group, constants, and buffer information needed to **replay** the recorded commands.
+
+---
+
+### Step 3: Analyze and Replay the Generated Debug Files
+
+You can:
+
+* Inspect the generated JSON files manually, **or**
+* Use the provided helper script for automated benchmarking and analysis.
+
+The script is located at:
+
+```text
+candle-wasm-examples/candle-test/src/bin/candle-test.rs
+```
+
+Example invocations:
+
+```bash
+# Run natively
+cargo run --bin candle-test --release
+
+# Run in the browser (via wasm)
+cargo xtask run-wasm -- --release --bin candle-test
+```
+
+At the top of `candle-test.rs`, configure the relevant `DEBUG` constants to point to your generated `*_used_consts.json` and `*_used_pipelines.json` files. Once configured, the script will:
+
+* Replay and benchmark each recorded command, and
+* Print all commands sorted by **total execution duration in descending order** (slowest first).
+
+This makes it straightforward to spot performance bottlenecks and problematic shaders/pipelines.
+
+> Depending on your setup, you can run this analysis either natively or in the browser, using the same recorded debug data.
diff --git a/candle-book/src/wgpu/debug_shader_logic.md b/candle-book/src/wgpu/debug_shader_logic.md
new file mode 100644
index 0000000000..f1eed6bd87
--- /dev/null
+++ b/candle-book/src/wgpu/debug_shader_logic.md
@@ -0,0 +1,77 @@
+# Debugging WGPU Shader Logic
+
+This page describes how to debug **incorrect results and shader logic errors** in the WGPU backend by recording a **complete execution trace** of GPU commands.
+
+This mechanism is intended for **correctness debugging** (unit tests, small reproductions), not for performance profiling.
+
+Unlike `wgpu_debug`, which records only **pipeline statistics and timing**, this mechanism records:
+
+* Full shader source code
+* All dispatched pipelines
+* Copies of all input and output buffers
+* Dispatch order and parameters
+
+Because it records full buffers and shaders, the generated data can be **very large** and should only be used with small test cases.
+
+---
+
+## Enabling
+
+Full command recording is enabled via the `wgpu_debug` feature:
+
+```bash
+cargo build --features="wgpu wgpu_debug"
+```
+
+---
+
+## Recording Commands
+
+Wrap the code you want to debug:
+
+```rust
+#[cfg(feature = "wgpu_debug")]
+{
+ let wgpu = device.as_wgpu_device()?;
+ wgpu.inner_device().start_recording_commands();
+
+ // Run the operations you want to debug here (prefer small unit tests).
+
+ wgpu.inner_device()
+ .stop_recording_commands(&"PATH TO ZIP FILE TO WRITE ALL DISPATCHES TO")?;
+}
+```
+
+Everything executed between `start_recording_commands` and `stop_recording_commands` is recorded into a ZIP file.
+
+---
+
+## Synchronization
+
+Make sure all work has completed before stopping the recording:
+
+* Synchronize the device, or
+* Read back a buffer from the GPU
+
+Otherwise, the recording may be incomplete.
+
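+A minimal sketch (assuming `result` is a tensor produced by the recorded operations):
+
+```rust
+// Either synchronize the whole device ...
+device.synchronize()?;
+// ... or force a read-back of an output tensor before stopping the recording.
+let _host_copy = result.to_vec1::<f32>()?;
+```
+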
+---
+
+## Difference from `wgpu_debug`
+
+* `wgpu_debug` (performance):
+
+ * Records pipeline names, call counts, timing
+ * No shader code or buffers
+
+* Full command recording (this page):
+
+ * Records full shader code and all buffers
+ * Intended for debugging incorrect results
+
+---
+
+## Notes
+
+* Use only for **small tests** — recordings can become very large.
+* Intended for native debugging (not supported in the browser).
\ No newline at end of file
diff --git a/candle-book/src/wgpu/implementation_detail.md b/candle-book/src/wgpu/implementation_detail.md
new file mode 100644
index 0000000000..91e6d984e2
--- /dev/null
+++ b/candle-book/src/wgpu/implementation_detail.md
@@ -0,0 +1,58 @@
+# Implementation Details
+
+## Kernels
+
+This implementation utilizes a custom WGSL kernel system provided by `candle-wgpu-kernels`.
+
+For details on the syntax, see the [`.pwgsl` files](./pwgsl_files.md) page.
+
+### Kernel Preprocessing
+
+At compile time, `.pwgsl` files are processed by `build.rs` and included with the following **`DTYPE` variants**, which define the `TYPE` name as global preprocessor defines:
+- `"F32"`
+- `"U32"`
+- `"I64"`
+- `"F64"`
+- `"F16"`
+- `"U8"`
+
+### Rust Module Generation
+
+For each `.pwgsl` shader file, a corresponding **Rust module** is automatically generated.
+This module contains metadata about the compute shader functions defined in the file.
+When a kernel is invoked from `candle_backend`, these auto-generated mappings ensure the correct function is called.
+
+### WGSL Optimization
+
+Additionally, `build.rs` performs **WGSL optimization**, which includes:
+- **Whitespace removal** – Stripping unnecessary spaces to reduce file size.
+- **Variable name truncation** – Shortening variable names, constant overrides, and function names.
+- **Dead code elimination** – Removing unused global variables and functions.
+
+## Cache System
+
+Calls into the wgpu backend are not executed directly; they are first placed in an internal queue inside the `WgpuDevice` object.
+The queued commands are only flushed to the GPU when a buffer is read back, the device is synchronized, or data is copied from the CPU to the wgpu device.
+
+When flushing, previously created buffers and bind groups are reused through a custom cache system.
+For example, generating an image with Wuerstchen queues more than 2_000_000 commands; the cache system creates only about 8000 buffers and 100_000 bind groups for these commands, instead of a separate bind group and output buffer per command.
+
+The cache is built around three object types:
+
+- `BufferReference` – represents a virtual buffer, which may or may not currently be backed by an actual `CachedBuffer`
+- `CachedBuffer` – wraps a `wgpu::Buffer`
+- `CachedBindgroup` – wraps a `wgpu::BindGroup`
+
+Each of these object types is held in its own `Vec`-based storage.
+Objects are read or written through a reference consisting of an index into that `Vec` and a timestamp value.
+When an entry is deleted, the timestamp at that index is incremented, so references that still hold the old timestamp can no longer access the slot.
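+
+A minimal sketch of this generational-index pattern (illustrative only; these are not the actual type names used in the backend):
+
+```rust
+/// Reference into a storage: slot index plus the timestamp the slot had when handed out.
+#[derive(Clone, Copy)]
+struct StorageRef {
+    index: usize,
+    timestamp: u64,
+}
+
+struct Storage<T> {
+    /// Each slot stores its current timestamp and, if occupied, a value.
+    slots: Vec<(u64, Option<T>)>,
+}
+
+impl<T> Storage<T> {
+    fn get(&self, r: StorageRef) -> Option<&T> {
+        let (ts, value) = self.slots.get(r.index)?;
+        // A stale reference (old timestamp) no longer resolves.
+        if *ts == r.timestamp { value.as_ref() } else { None }
+    }
+
+    fn delete(&mut self, r: StorageRef) {
+        if let Some((ts, value)) = self.slots.get_mut(r.index) {
+            if *ts == r.timestamp {
+                *value = None;
+                *ts += 1; // invalidate all outstanding references to this slot
+            }
+        }
+    }
+}
+```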
\ No newline at end of file
diff --git a/candle-book/src/wgpu/pwgsl_files.md b/candle-book/src/wgpu/pwgsl_files.md
new file mode 100644
index 0000000000..c8086be22b
--- /dev/null
+++ b/candle-book/src/wgpu/pwgsl_files.md
@@ -0,0 +1,119 @@
+## Kernel Files (`.pwgsl`)
+
+The WGPU kernels are located in the `candle-wgpu-kernels` crate as `.pwgsl` files.
+
+A `.pwgsl` file is a WGSL shader file that is preprocessed using a C-like preprocessor.
+This allows writing reusable and configurable shader code using includes, macros, and conditional compilation.
+
+---
+
+## File Inclusion
+
+```c
+#include "FILENAME"
+```
+
+Inserts the contents of another file at this location.
+
+---
+
+## Macro Definitions
+
+### Simple Defines
+
+```c
+#define KEY 42
+```
+
+Replaces all occurrences of `KEY` with `42`. Only whole identifiers are replaced:
+
+* ✅ `if KEY == 42` → `if 42 == 42`
+* ❌ `if KEYS == 42` (unchanged, `KEYS` is not a full match)
+
+---
+
+### Function-Like Defines
+
+```c
+#define MAX(a, b) (a > b) ? a : b
+```
+
+Allows function-like macros.
+
+Example:
+
+```c
+MAX(2+2, 42)
+```
+
+Expands to:
+
+```c
+(2+2 > 42) ? 2+2 : 42
+```
+
+---
+
+### Computed Defines
+
+```c
+#definec DEFINE_NAME EXPRESSION
+```
+
+A computed define evaluates a mathematical expression at preprocess time.
+
+Example:
+
+```c
+#define KEY 42
+#definec MyDefine (42 + KEY)
+let x = MyDefine;
+```
+
+Expands to:
+
+```c
+let x = 84;
+```
+
+---
+
+## Conditional Compilation
+
+The preprocessor supports the following C-like conditional directives:
+
+```c
+#if CONDITION // True if CONDITION evaluates to non-zero
+#ifdef DEFINE_NAME // True if DEFINE_NAME is defined
+#ifndef DEFINE_NAME // True if DEFINE_NAME is NOT defined
+#elif CONDITION // Alternative condition
+#elifdef DEFINE_NAME // Alternative if DEFINE_NAME is defined
+#elifndef DEFINE_NAME // Alternative if DEFINE_NAME is NOT defined
+#endif // Ends the conditional block
+```
+
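+An illustrative example (the `USE_VEC4` / `SCALAR_ONLY` defines and the buffer name are made up for demonstration):
+
+```c
+#ifdef USE_VEC4
+    // Process four elements per invocation.
+    let v = input_buffer[4u * id];
+#elifndef SCALAR_ONLY
+    let v = input_buffer[2u * id];
+#else
+    let v = input_buffer[id];
+#endif
+```
+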
+---
+
+## Multi-Line Preprocessor Blocks (`#pp_begin` / `#pp_end`)
+
+In addition to normal `#define`, the preprocessor supports **multi-line macro blocks** using:
+
+```c
+#pp_begin DEFINENAME(data, tid)
+ ...
+#pp_end
+```
+
+This defines a macro named `DEFINENAME` with parameters (`data`, `tid`) whose body can span **multiple lines**.
+
+Key properties:
+
+* Works like a function-like `#define`, but allows multi-line bodies.
+* The body may contain **other preprocessor directives** (`#if`, `#define`, `#include`, etc.).
+* All inner preprocessing is executed using the **context and state at the time the macro is expanded**, not when it is defined.
+
+This makes `#pp_begin` useful for:
+
+* Defining reusable shader kernels or code templates
+* Generating complex control flow or binding logic
+* Sharing blocks of code that depend on compile-time configuration
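+
+An illustrative sketch (the macro name, defines, and variables are made up for demonstration):
+
+```c
+#pp_begin REDUCE_PAIR(data, tid)
+    #ifdef USE_F16
+        data[tid] = data[tid] + data[tid + STRIDE];
+    #else
+        data[tid] += data[tid + STRIDE];
+    #endif
+#pp_end
+
+// Later, expanded with whatever USE_F16 and STRIDE are defined as at this point:
+REDUCE_PAIR(shared_sum, local_id)
+```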
\ No newline at end of file
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 4ffc869ff8..f9dd08fdd8 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -14,12 +14,15 @@ accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { workspace = true, optional = true }
candle-metal-kernels = { workspace = true, optional = true }
-metal = { workspace = true, optional = true}
+objc2-metal = { workspace = true, optional = true }
+objc2-foundation = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
+float8 = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
libc = { workspace = true, optional = true }
+libm = { workspace = true }
memmap2 = { workspace = true }
num-traits = { workspace = true }
num_cpus = { workspace = true }
@@ -28,25 +31,57 @@ rand_distr = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
thiserror = { workspace = true }
-ug = { workspace = true }
-ug-cuda = { workspace = true, optional = true }
-ug-metal = { workspace = true, optional = true }
yoke = { workspace = true }
zip = { workspace = true }
+#Wgpu Dependencies:
+#wgpu = { workspace = true, optional = true }
+bytemuck = { workspace = true, optional = true }
+tracing = {workspace = true, optional = true, features = ["release_max_level_off"]}
+
+pollster = { workspace = true, optional = true }
+candle-wgpu-kernels = { workspace = true, optional = true }
+wgpu-compute-layer = { workspace = true, optional = true }
+
+[target.'cfg(all(not(target_arch = "wasm32"), not(target_os = "ios")))'.dependencies]
+candle-ug = { workspace = true, optional = true }
+
[dev-dependencies]
anyhow = { workspace = true }
clap = { workspace = true }
criterion = { workspace = true }
-
[features]
default = []
-cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
+cuda = ["cudarc", "dep:candle-kernels", "candle-ug?/cuda"]
cudnn = ["cuda", "cudarc/cudnn"]
+nccl = ["cuda", "cudarc/nccl"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
-metal = ["dep:metal", "dep:candle-metal-kernels", "dep:ug-metal"]
+metal = [
+ "dep:objc2-metal",
+ "dep:objc2-foundation",
+ "dep:candle-metal-kernels",
+ "candle-ug?/metal",
+]
+ug = ["dep:candle-ug"]
+wgpu = [
+ "dep:bytemuck",
+ "dep:tracing",
+ "dep:pollster",
+ "dep:candle-wgpu-kernels",
+ "dep:wgpu-compute-layer",
+]
+wgpu_debug = [
+ "wgpu",
+ "candle-wgpu-kernels/wgpu_debug_serialize",
+ "wgpu-compute-layer/wgpu_debug",
+]
+wgpu_debug_serialize = [
+ "wgpu",
+ "candle-wgpu-kernels/wgpu_debug_serialize",
+ "wgpu-compute-layer/wgpu_debug_serialize",
+]
[[bench]]
name = "bench_main"
@@ -55,3 +90,12 @@ harness = false
[[example]]
name = "metal_basics"
required-features = ["metal"]
+
+[[example]]
+name = "cuda_basics"
+required-features = ["cuda"]
+
+
+[[example]]
+name = "wgpu_basics"
+required-features = ["wgpu"]
diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs
index 2e1816fd71..36f0289d61 100644
--- a/candle-core/benches/bench_main.rs
+++ b/candle-core/benches/bench_main.rs
@@ -1,12 +1,25 @@
mod benchmarks;
use criterion::criterion_main;
+
criterion_main!(
benchmarks::affine::benches,
+ benchmarks::broadcast::benches,
benchmarks::matmul::benches,
+ benchmarks::matmul_wgpu::benches,
+ benchmarks::matmul_quantized::benches,
benchmarks::random::benches,
+ benchmarks::reduce::benches,
benchmarks::where_cond::benches,
benchmarks::conv_transpose2d::benches,
+ benchmarks::conv2d::benches,
benchmarks::qmatmul::benches,
- benchmarks::unary::benches
+ benchmarks::unary::benches,
+ benchmarks::binary::benches,
+ benchmarks::copy::benches
);
diff --git a/candle-core/benches/benchmarks/affine.rs b/candle-core/benches/benchmarks/affine.rs
index c1004c6c6c..08379d76b9 100644
--- a/candle-core/benches/benchmarks/affine.rs
+++ b/candle-core/benches/benchmarks/affine.rs
@@ -1,6 +1,7 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
-use criterion::{black_box, criterion_group, Criterion, Throughput};
+use criterion::{criterion_group, Criterion, Throughput};
+use std::hint::black_box;
use std::time::Instant;
fn run(a: &Tensor) {
@@ -35,8 +36,16 @@ fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_affine_benchmark(c, &device, DType::F32, "affine_f32");
- run_affine_benchmark(c, &device, DType::F16, "affine_f16");
- run_affine_benchmark(c, &device, DType::BF16, "affine_bf16");
+ if device.is_dtype_available(DType::F16) {
+ run_affine_benchmark(c, &device, DType::F16, "affine_f16");
+ }
+ if device.is_dtype_available(DType::BF16) {
+ run_affine_benchmark(c, &device, DType::BF16, "affine_bf16");
+ }
+ if device.is_dtype_available(DType::F8E4M3) {
+ #[cfg(not(feature = "metal"))]
+ run_affine_benchmark(c, &device, DType::F8E4M3, "affine_fp8");
+ }
}
}
diff --git a/candle-core/benches/benchmarks/binary.rs b/candle-core/benches/benchmarks/binary.rs
new file mode 100644
index 0000000000..38f44ffcba
--- /dev/null
+++ b/candle-core/benches/benchmarks/binary.rs
@@ -0,0 +1,59 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle_core::{DType, Device, Tensor};
+use criterion::{criterion_group, Criterion, Throughput};
+use std::hint::black_box;
+use std::time::Instant;
+
+fn run(lhs: &Tensor, rhs: &Tensor) -> Tensor {
+ lhs.mul(rhs).unwrap()
+}
+
+fn run_binary_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
+ let b = 1;
+ let m = 1024;
+ let k = 1024;
+
+ let lhs = Tensor::arange(0.0f32, (b * m * k) as f32, device)
+ .unwrap()
+ .to_dtype(dtype)
+ .unwrap()
+ .reshape((b, m, k))
+ .unwrap();
+
+ let rhs = Tensor::arange(0.0f32, (b * m * k) as f32, device)
+ .unwrap()
+ .to_dtype(dtype)
+ .unwrap()
+ .reshape((b, m, k))
+ .unwrap();
+
+ let flops = 2 * b * m * k * dtype.size_in_bytes();
+
+ let mut group = c.benchmark_group(device.bench_name(name));
+ group.throughput(Throughput::Bytes(flops as u64));
+ group.bench_function("iter", move |b| {
+ b.iter_custom(|iters| {
+ let start = Instant::now();
+ for _i in 0..iters {
+ run(black_box(&lhs), black_box(&rhs));
+ }
+ device.sync().unwrap();
+ start.elapsed()
+ })
+ });
+ group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let handler = BenchDeviceHandler::new().unwrap();
+ for device in handler.devices {
+ for dtype in [DType::F32, DType::BF16, DType::F16] {
+ let name = format!("binary_mul_{dtype:?}");
+ if device.is_dtype_available(dtype) {
+                run_binary_benchmark(c, &device, dtype, &name);
+ }
+ }
+ }
+}
+
+criterion_group!(benches, criterion_benchmark);
diff --git a/candle-core/benches/benchmarks/broadcast.rs b/candle-core/benches/benchmarks/broadcast.rs
new file mode 100644
index 0000000000..99077ff9fa
--- /dev/null
+++ b/candle-core/benches/benchmarks/broadcast.rs
@@ -0,0 +1,51 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle_core::{DType, Device, Tensor};
+use criterion::{criterion_group, Criterion, Throughput};
+use std::hint::black_box;
+use std::time::Instant;
+
+fn run(w: &Tensor, bias: &Tensor) {
+ w.broadcast_add(bias).unwrap();
+}
+
+fn run_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
+ // We simulate a candle-nn style conv2d + bias forward pass.
+ let batch_size = 1;
+ let ch = 1;
+ let m = 126;
+ let bias_size = 128;
+
+ let x = Tensor::ones((batch_size, ch, m, m), dtype, device).unwrap();
+ let bias = Tensor::ones((1, bias_size, 1, 1), dtype, device).unwrap();
+
+ let flops = batch_size * ch * m * bias_size * dtype.size_in_bytes();
+
+ let mut group = c.benchmark_group(device.bench_name(name));
+ group.throughput(Throughput::Bytes(flops as u64));
+ group.bench_function("iter", move |b| {
+ b.iter_custom(|iters| {
+ let start = Instant::now();
+ for _i in 0..iters {
+ run(black_box(&x), black_box(&bias));
+ }
+ device.sync().unwrap();
+ start.elapsed()
+ })
+ });
+ group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let handler = BenchDeviceHandler::new().unwrap();
+ for device in handler.devices {
+ run_benchmark(c, &device, DType::F32, "broadcast_add_f32");
+ if device.is_dtype_available(DType::F16){
+ run_benchmark(c, &device, DType::F16, "broadcast_add_f16");
+ }
+ if device.is_dtype_available(DType::BF16){
+ run_benchmark(c, &device, DType::BF16, "broadcast_add_bf16");
+ }
+ }
+}
+
+criterion_group!(benches, criterion_benchmark);
diff --git a/candle-core/benches/benchmarks/conv2d.rs b/candle-core/benches/benchmarks/conv2d.rs
new file mode 100644
index 0000000000..f27cf5cd55
--- /dev/null
+++ b/candle-core/benches/benchmarks/conv2d.rs
@@ -0,0 +1,68 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle_core::{DType, Device, Tensor};
+use criterion::{criterion_group, Criterion};
+use std::hint::black_box;
+use std::time::Instant;
+
+fn run(input: Tensor, weight: Tensor) {
+ input.conv2d(&weight, 1, 1, 1, 1).unwrap();
+}
+
+fn run_conv2d_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
+ // const B: usize = 1;
+ // const C: usize = 320;
+ // const C_OUT : usize = 640;
+ // const M: usize = 128;
+ // const K: usize = 128;
+ // const K_SIZE: usize = 2;
+
+ const B: usize = 2;
+ const C: usize = 1;
+ const C_OUT: usize = 1;
+ const M: usize = 24;
+ const K: usize = 24;
+ const K_SIZE: usize = 3;
+
+ let weight = Tensor::ones((C_OUT, C, K_SIZE, K_SIZE), dtype, device)
+ .unwrap()
+ .to_dtype(dtype)
+ .unwrap();
+
+ // let weight = Tensor::ones((C_OUT, K_SIZE, K_SIZE, C), dtype, device)
+ // .unwrap()
+ // .to_dtype(dtype)
+ // .unwrap()
+ // .transpose(3, 1).unwrap();
+
+ let input = Tensor::ones((B, C, M, K), dtype, device).unwrap();
+
+ println!("weight: {:?}", weight.layout());
+ println!("input: {:?}", input.layout());
+
+ let flops = B * C * M * K * K_SIZE * K_SIZE;
+
+ let mut group = c.benchmark_group(device.bench_name(name));
+ group.throughput(criterion::Throughput::Bytes(flops as u64 * 4));
+ group.bench_function("iter", move |b| {
+ b.iter_custom(|iters| {
+ let start = Instant::now();
+ for _i in 0..iters {
+ run(black_box(input.clone()), black_box(weight.clone()));
+ }
+ device.sync().unwrap();
+ start.elapsed()
+ })
+ });
+ group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let device = BenchDeviceHandler::new().unwrap();
+ for d in device.devices {
+ run_conv2d_benchmark(c, &d, DType::F32, "conv2d_f32");
+ if d.is_dtype_available(DType::F16) {
+ run_conv2d_benchmark(c, &d, DType::F16, "conv2d_f16");
+ }
+ }
+}
+
+criterion_group!(benches, criterion_benchmark);
diff --git a/candle-core/benches/benchmarks/conv_transpose2d.rs b/candle-core/benches/benchmarks/conv_transpose2d.rs
index 7b252ec6f9..484dd9942c 100644
--- a/candle-core/benches/benchmarks/conv_transpose2d.rs
+++ b/candle-core/benches/benchmarks/conv_transpose2d.rs
@@ -1,6 +1,7 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
-use criterion::{black_box, criterion_group, Criterion, Throughput};
+use criterion::{criterion_group, Criterion, Throughput};
+use std::hint::black_box;
use std::time::Instant;
fn run(
@@ -51,8 +52,12 @@ fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_benchmark(c, &device, DType::F32, "conv_transpose2d_f32");
- run_benchmark(c, &device, DType::F16, "conv_transpose2d_f16");
- run_benchmark(c, &device, DType::BF16, "conv_transpose2d_bf16");
+ if device.is_dtype_available(DType::F16) {
+ run_benchmark(c, &device, DType::F16, "conv_transpose2d_f16");
+ }
+ if device.is_dtype_available(DType::BF16) {
+ run_benchmark(c, &device, DType::BF16, "conv_transpose2d_bf16");
+ }
}
}
diff --git a/candle-core/benches/benchmarks/copy.rs b/candle-core/benches/benchmarks/copy.rs
new file mode 100644
index 0000000000..00eff3dca6
--- /dev/null
+++ b/candle-core/benches/benchmarks/copy.rs
@@ -0,0 +1,39 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle_core::{Device, Tensor, WithDType};
+use criterion::{criterion_group, Criterion, Throughput};
+use std::hint::black_box;
+use std::time::Instant;
+
+fn run_copy_mask_benchmark<D: WithDType>(c: &mut Criterion, device: &Device, name: &str) {
+ let batch_size = 128;
+ let in_seq_len = 1;
+ let kv_seq_len = 1024;
+
+ let attn_mask = vec![vec![vec![D::zero(); kv_seq_len]; in_seq_len]; batch_size];
+ let size_in_bytes = batch_size * in_seq_len * kv_seq_len * D::DTYPE.size_in_bytes();
+
+ let mut group = c.benchmark_group(device.bench_name(name));
+ group.throughput(Throughput::Bytes(size_in_bytes as u64));
+ group.bench_function("iter", move |b| {
+ b.iter_custom(|iters| {
+ let attn_masks = vec![attn_mask.clone(); iters as usize];
+ let start = Instant::now();
+ for attn_mask in attn_masks.into_iter() {
+ let tensor = Tensor::new(black_box(attn_mask), device).unwrap();
+ black_box(tensor);
+ }
+ device.sync().unwrap();
+ start.elapsed()
+ })
+ });
+ group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let handler = BenchDeviceHandler::new().unwrap();
+ for device in handler.devices {
+        run_copy_mask_benchmark::<f32>(c, &device, "copy_mask");
+ }
+}
+
+criterion_group!(benches, criterion_benchmark);
diff --git a/candle-core/benches/benchmarks/matmul.rs b/candle-core/benches/benchmarks/matmul.rs
index 9d67e642cd..08732d2b32 100644
--- a/candle-core/benches/benchmarks/matmul.rs
+++ b/candle-core/benches/benchmarks/matmul.rs
@@ -1,25 +1,49 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
-use criterion::{black_box, criterion_group, Criterion, Throughput};
+use criterion::{criterion_group, Criterion, Throughput};
+use std::hint::black_box;
use std::time::Instant;
+/// Matmul benchmark shapes covering common GEMM scenarios
+const MATMUL_SHAPES: &[(&str, &[usize], &[usize])] = &[
+ // Original GEMV test
+ ("gemv", &[1, 1, 2048], &[1, 2048, 2048]),
+ // 4D Attention scenarios (multi-head attention)
+ ("attn_4d_small", &[484, 6, 144, 32], &[484, 6, 32, 144]),
+ ("attn_4d_large", &[121, 24, 144, 32], &[121, 24, 32, 144]),
+ // Square matrix tests
+ ("square_512", &[512, 512], &[512, 512]),
+ ("square_1024", &[1024, 1024], &[1024, 1024]),
+ // 3D Batch matmul (attention patterns)
+ ("batch_1000", &[1000, 144, 32], &[1000, 32, 144]),
+ // 2D Linear layer scenarios (transformer FFN)
+ ("linear_large", &[17424, 768], &[768, 3072]),
+];
+
fn run(a: &Tensor, b: &Tensor) {
- a.matmul(&b.t().unwrap()).unwrap();
+ a.broadcast_matmul(b).unwrap();
}
-fn run_bench(c: &mut Criterion, device: &Device) {
- let b = 1;
- let m = 1;
- let n = 2048;
- let k = 2048;
+fn calculate_flops(shape_a: &[usize], shape_b: &[usize]) -> usize {
+ let batch: usize = shape_a
+ .iter()
+ .take(shape_a.len().saturating_sub(2))
+ .product();
+ let batch = if batch == 0 { 1 } else { batch };
+ let m = shape_a[shape_a.len() - 2];
+ let k = shape_a[shape_a.len() - 1];
+ let n = shape_b[shape_b.len() - 1];
+ 2 * batch * m * k * n
+}
+fn run_bench(c: &mut Criterion, device: &Device, name: &str, shape_a: &[usize], shape_b: &[usize]) {
let dtype = DType::F32;
- let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap();
- let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
+ let lhs = Tensor::zeros(shape_a, dtype, device).unwrap();
+ let rhs = Tensor::zeros(shape_b, dtype, device).unwrap();
- let flops = b * m * n * k;
+ let flops = calculate_flops(shape_a, shape_b);
- let mut group = c.benchmark_group(device.bench_name("matmul"));
+ let mut group = c.benchmark_group(device.bench_name(format!("matmul_{name}")));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
@@ -37,7 +61,9 @@ fn run_bench(c: &mut Criterion, device: &Device) {
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
- run_bench(c, &device);
+ for (name, shape_a, shape_b) in MATMUL_SHAPES {
+ run_bench(c, &device, name, shape_a, shape_b);
+ }
}
}
diff --git a/candle-core/benches/benchmarks/matmul_quantized.rs b/candle-core/benches/benchmarks/matmul_quantized.rs
new file mode 100644
index 0000000000..de3eaa6162
--- /dev/null
+++ b/candle-core/benches/benchmarks/matmul_quantized.rs
@@ -0,0 +1,423 @@
+use std::time::{Duration, Instant};
+
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle_core::{D, DType, Device, Module, Tensor, quantized};
+use candle_core::quantized::{GgmlDType, QMatMul};
+
+use criterion::{criterion_group, BenchmarkId, Criterion, Throughput};
+use std::hint::black_box;
+
+const GGML_TYPE: candle_core::quantized::GgmlDType = candle_core::quantized::GgmlDType::Q8_1;
+
+fn run(a: &Tensor, b: &QMatMul) {
+ b.forward(a).unwrap();
+}
+
+fn bench_impl(b : &mut criterion::Bencher<'_, criterion::measurement::WallTime>, lhs : &Tensor, rhs : &candle_core::quantized::QMatMul, device : &Device){
+ b.iter_custom(|iters| {
+ let start = Instant::now();
+ for _ in 0..iters {
+ run(black_box(lhs), black_box(rhs));
+ }
+ device.sync().unwrap();
+ start.elapsed()
+ })
+}
+
+fn test_matmul(
+ device: &Device,
+    group: &mut criterion::BenchmarkGroup<'_, criterion::measurement::WallTime>,
+ bmnk: (usize, usize, usize, usize),
+ size: usize,
+ multiple_sizes: bool,
+ tp: (bool, bool),
+ typ : quantized::GgmlDType
+) {
+ let (_, m, n, k) = bmnk;
+ let (tpa, tpb) = tp;
+ let dtype = DType::F32;
+
+ let b = 1;
+ let lhs = if tpa {
+ Tensor::zeros((k, m), dtype, device)
+ .unwrap()
+ .transpose(D::Minus1, D::Minus2)
+ .unwrap()
+ } else {
+ Tensor::zeros((m, k), dtype, device).unwrap()
+ };
+
+ let rhs = if tpb {
+ Tensor::zeros((n, k), dtype, device)
+ .unwrap()
+ .transpose(D::Minus1, D::Minus2)
+ .unwrap()
+ } else {
+ Tensor::zeros((k, n), dtype, device).unwrap()
+ };
+
+ let rhs = quantized::QTensor::quantize(&rhs, typ).unwrap();
+ let rhs = quantized::QMatMul::from_qtensor(rhs).unwrap();
+
+ let flops = b * m * n * k;
+ group.throughput(Throughput::Bytes(flops as u64));
+ group.measurement_time(Duration::from_millis(250));
+ group.sample_size(32);
+ group.warm_up_time(Duration::from_secs_f32(0.25));
+ if device.is_wgpu() {
+ #[cfg(feature = "wgpu")]
+ if let Device::Wgpu(wgpu) = device {
+ {
+ // const TILE_SIZES: &[u32] = &[32, 64, 128];
+ // const WPT_VALUES: &[u32] = &[2, 4, 8, 16];
+
+ const TILE_SIZES: &[u32] = &[32];
+ const WPT_VALUES: &[u32] = &[16];
+
+ let tid_size = match typ{
+ GgmlDType::Q4_0 | GgmlDType::Q4_1 => 4,
+ GgmlDType::Q5_0 | GgmlDType::Q5_1 => 4,
+ GgmlDType::Q8_0 | GgmlDType::Q8_1 => 8,
+ _ => panic!()
+ };
+
+ use candle_core::wgpu::wgpu_functions::matmul::sgemm::GenericDynamicMatmulShaderSettings;
+
+ let mut run_bench = |func_name : String|{
+ if multiple_sizes {
+ group.bench_with_input(
+ BenchmarkId::new(func_name.clone(), size),
+ &size,
+ |b, _| bench_impl(b, &lhs, &rhs, device),
+ );
+ } else {
+ group.bench_function(func_name, |b| bench_impl(b, &lhs, &rhs, device),);
+ }
+ };
+
+ //test naive
+ {
+ wgpu.inner_device().set_extension(candle_core::wgpu::QuantizedMatmulAlgorithm::Naive);
+
+ let func_name = device.bench_name(format!(
+ "matmul_naive_{:?}_{}{}",
+ typ,
+ if tpa { "_tA" } else { "" },
+ if tpb { "_tB" } else { "" },
+ ));
+
+ run_bench(func_name);
+ }
+
+
+ //test tiled
+ {
+ // const TILE_SIZES: &[u32] = &[32, 64, 128, 256, 512];
+ // const WPT_VALUES: &[u32] = &[2, 4, 8, 16, 32, 64, 128];
+ const TILE_SIZES: &[u32] = &[32];
+ const WPT_VALUES: &[u32] = &[32];
+ let tile_m = 1;
+ let wptm = 1;
+ for &tile_n in TILE_SIZES {
+ for &tile_k in TILE_SIZES {
+ for &wptn in WPT_VALUES {
+ use candle_core::wgpu::wgpu_functions::matmul::sgemm::{GenericMatmulSettings, StrideOptimization};
+ let widthb = match typ{
+ GgmlDType::Q4_0 | GgmlDType::Q4_1 => 8,
+ GgmlDType::Q5_0 | GgmlDType::Q5_1 => 8,
+ GgmlDType::Q8_0 | GgmlDType::Q8_1 => 4,
+ _ => panic!()
+ };
+ let wptk = widthb;
+ let threads_per_k = tile_k / wptk;
+ let threads = (tile_m / wptm) * (tile_n / wptn) * threads_per_k;
+ if threads == 0{
+ continue;
+ }
+ if tile_n % threads != 0 {
+ continue;
+ }
+ if tile_n + tile_k >= 512 + 256{
+ continue;
+ }
+ let size_areg = match typ{
+ GgmlDType::Q4_0 => 8,
+ GgmlDType::Q5_0 => 12,
+ GgmlDType::Q8_0 => 6,
+ GgmlDType::Q4_1 => 8,
+ GgmlDType::Q5_1 => 8,
+ GgmlDType::Q8_1 => 4,
+ _ => panic!()
+ };
+
+ if (4 == widthb && size_areg == 4) || (8 == widthb && size_areg == 8){
+
+
+ }
+ else if tile_k % (threads * 4) != 0{
+ continue;
+ }
+
+ if threads >= 128{
+ continue;
+ }
+
+ wgpu.inner_device().set_extension(candle_core::wgpu::QuantizedMatmulAlgorithm::Some(GenericDynamicMatmulShaderSettings::new_tiled_small(
+ GenericMatmulSettings::new(
+ tile_m,
+ tile_n,
+ tile_k,
+ StrideOptimization::None,
+ StrideOptimization::None,
+ ),
+ wptm,
+ wptn,
+ false,
+ )));
+
+ let func_name = device.bench_name(format!(
+ "matmul_tiled_small_{:?}({},{},{})_wptm{}_wptn{}{}{}",
+ typ,
+ tile_m, tile_n, tile_k,
+ wptm,
+ wptn,
+ if tpa { "_tA" } else { "" },
+ if tpb { "_tB" } else { "" },
+ ));
+
+ run_bench(func_name);
+
+ }
+ }
+ }
+ }
+
+            //sgemm quantized
+ {
+ let tile_m = 1;
+ let wptm = 1;
+ for &tile_n in TILE_SIZES {
+ for &tile_k in TILE_SIZES {
+ for &wptn in WPT_VALUES {
+ for &wont_use_load_a in &[false, true] {
+ use candle_core::wgpu::wgpu_functions::matmul::sgemm::{GenericMatmulSettings, StrideOptimization};
+
+ let threads = (tile_m / wptm) * (tile_n / wptn);
+
+ if threads % 8 != 0{
+ //continue;
+ }
+
+ let lpta = (tile_k*tile_m)/threads;
+ if lpta % 4 != 0{
+ continue;
+ }
+
+ if threads % tid_size != 0{
+ continue;
+ }
+
+ if (tile_k * tile_m) % threads != 0{
+ continue;
+ }
+
+ if (tile_n) % threads != 0{
+ continue;
+ }
+
+ let lptb = (tile_k*tile_n)/threads;
+ if lptb % 4 != 0{
+ //continue;
+ }
+ if threads > 256{
+ //continue;
+ }
+ if tile_k == 128 && (tile_m == 128 || tile_n == 128){
+ //continue;
+ }
+
+ if tile_k + tile_m + tile_n >= 256 {
+ continue;
+ }
+
+ //skip small threads as there are prob not the most performant solution
+ if threads <= 16{ //at least 32 threads
+ //continue;
+ }
+
+ wgpu.inner_device().set_extension(candle_core::wgpu::QuantizedMatmulAlgorithm::Some(GenericDynamicMatmulShaderSettings::new_with_a(
+ GenericMatmulSettings::new(
+ tile_m,
+ tile_n,
+ tile_k,
+ StrideOptimization::None,
+ StrideOptimization::None,
+ ),
+ wptm,
+ wptn,
+ false,
+ wont_use_load_a
+ )));
+
+ let func_name = device.bench_name(format!(
+ "matmul_sgemm_{:?}({},{},{})_wptm{}_wptn{}{}{}_wont_load_a{wont_use_load_a}",
+ typ,
+ tile_m, tile_n, tile_k,
+ wptm,
+ wptn,
+ if tpa { "_tA" } else { "" },
+ if tpb { "_tB" } else { "" },
+ ));
+
+ run_bench(func_name);
+ }
+ }
+ }
+ }
+ }
+
+            // //sgemm quantized 2048x2048 * 2048x2048
+ // {
+ // let tile_m = 1;
+ // let tile_n = 1;
+ // let tile_k = 1;
+ // let wptm = 1;
+ // let wptn = 1;
+ // // for &tile_m in TILE_SIZES {
+ // for &tile_n in TILE_SIZES {
+ // for &tile_k in TILE_SIZES {
+ // //for &wptm in WPT_VALUES {
+ // for &wptn in WPT_VALUES {
+ // use candle_core::wgpu::wgpu_functions::matmul::sgemm::{GenericMatmulSettings, StrideOptimization};
+
+ // let threads = (tile_m / wptm) * (tile_n / wptn);
+
+ // if threads % 8 != 0{
+ // //continue;
+ // }
+
+ // let lpta = (tile_k*tile_m)/threads;
+ // if lpta % 4 != 0{
+ // //continue;
+ // }
+ // let lptb = (tile_k*tile_n)/threads;
+ // if lptb % 4 != 0{
+ // //continue;
+ // }
+ // if threads > 256{
+ // //continue;
+ // }
+ // if(tile_k == 128 && (tile_m == 128 || tile_n == 128)){
+ // //continue;
+ // }
+
+ // if tile_k + tile_m + tile_n >= 256 {
+ // //continue;
+ // }
+
+ // //skip small threads as there are prob not the most performant solution
+ // if threads <= 16{ //at least 32 threads
+ // //continue;
+ // }
+
+ // let mut matmul_alg = wgpu.quantized_matmul_alg.lock().unwrap();
+ // *matmul_alg = candle_core::wgpu::QuantizedMatmulAlgorithm::Some(GenericDynamicMatmulShaderSettings::new_tiled_small(
+ // GenericMatmulSettings::new(
+ // tile_m,
+ // tile_n,
+ // tile_k,
+ // StrideOptimization::None,
+ // StrideOptimization::None,
+ // ),
+ // wptm,
+ // wptn,
+ // false,
+ // ));
+ // drop(matmul_alg); // release lock early
+
+ // let func_name = device.bench_name(format!(
+ // "matmul_tile_{:?}({},{},{})_wptm{}_wptn{}{}{}",
+ // typ,
+ // tile_m, tile_n, tile_k,
+ // wptm,
+ // wptn,
+ // if tpa { "_tA" } else { "" },
+ // if tpb { "_tB" } else { "" },
+ // ));
+
+ // run_bench(func_name);
+ // }
+ // //}
+ // }
+ // }
+ // // }
+ // }
+ }
+ }
+ } else {
+ let func_name = device.bench_name(format!(
+ "matmul{}{}",
+            if tpa { "_transposedA" } else { "" },
+            if tpb { "_transposedB" } else { "" }
+ ));
+ if multiple_sizes {
+ group.bench_with_input(BenchmarkId::new(func_name, size), &size, |b, _| bench_impl(b, &lhs, &rhs, device));
+ } else {
+ group.bench_function(func_name, |b| bench_impl(b, &lhs, &rhs, device));
+ }
+ }
+}
+
+#[allow(dead_code)]
+fn test_functions(
+ device: &Device,
+    group: &mut criterion::BenchmarkGroup<'_, criterion::measurement::WallTime>,
+ fm: impl Fn(usize) -> usize,
+) {
+ let sizes = vec![2050usize, 2048, 1032, 1024, 528, 512, 128, 120, 32, 16].into_iter();
+ for size in sizes {
+ test_matmul(
+ device,
+ group,
+ (1, fm(size), size, size),
+ size,
+ true,
+ (false, false),
+            GGML_TYPE
+ );
+ }
+}
+
+fn test_matmul_group(c: &mut Criterion, bmnk: (usize, usize, usize, usize), dtype : GgmlDType) {
+ let (b, m, n, k) = bmnk;
+ let handler = BenchDeviceHandler::new().unwrap();
+
+ let mut group = c.benchmark_group(format!("matmul_{b}x({m}x{k} * {k}x{n})"));
+ for device in handler.devices.iter() {
+ test_matmul(
+ device,
+ &mut group,
+ bmnk,
+ 1,
+ false,
+ (false, false),
+ dtype
+ );
+ }
+ group.finish();
+
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+ for dtype in [
+ GgmlDType::Q4_0,
+ GgmlDType::Q4_1, GgmlDType::Q5_0, GgmlDType::Q5_1,
+ GgmlDType::Q8_0, GgmlDType::Q8_1
+ ]{
+ test_matmul_group(c, (1, 1, 4096, 4096), dtype);
+ }
+
+    //test_matmul_group(c, (1, 2048, 2048, 2048), GGML_TYPE);
+}
+
+criterion_group!(benches, criterion_benchmark);
diff --git a/candle-core/benches/benchmarks/matmul_wgpu.rs b/candle-core/benches/benchmarks/matmul_wgpu.rs
new file mode 100644
index 0000000000..d35f184209
--- /dev/null
+++ b/candle-core/benches/benchmarks/matmul_wgpu.rs
@@ -0,0 +1,302 @@
+use std::time::{Duration, Instant};
+
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle_core::{DType, Device, Tensor, D};
+
+#[cfg(feature = "wgpu")]
+use candle_core::wgpu::MatmulAlgorithm;
+
+use criterion::{criterion_group, BenchmarkId, Criterion, Throughput};
+use std::hint::black_box;
+
+fn run(a: &Tensor, b: &Tensor) {
+ a.matmul(b).unwrap();
+}
+
+fn test_matmul(
+ device: &Device,
+    group: &mut criterion::BenchmarkGroup<'_, criterion::measurement::WallTime>,
+ bmnk: (usize, usize, usize, usize),
+ _is_small_line: bool,
+ size: usize,
+ multiple_sizes: bool,
+ tp: (bool, bool),
+) {
+ let (b, m, n, k) = bmnk;
+ let (tpa, tpb) = tp;
+ let dtype = DType::F32;
+
+ let lhs = if tpa {
+ Tensor::zeros((b, k, m), dtype, device)
+ .unwrap()
+ .transpose(D::Minus1, D::Minus2)
+ .unwrap()
+ } else {
+ Tensor::zeros((b, m, k), dtype, device).unwrap()
+ };
+
+ let rhs = if tpb {
+ Tensor::zeros((b, n, k), dtype, device)
+ .unwrap()
+ .transpose(D::Minus1, D::Minus2)
+ .unwrap()
+ } else {
+ Tensor::zeros((b, k, n), dtype, device).unwrap()
+ };
+
+ let flops = b * m * n * k;
+ group.throughput(Throughput::Bytes(flops as u64));
+ group.measurement_time(Duration::from_secs(1));
+ group.sample_size(32);
+ group.warm_up_time(Duration::from_secs_f32(0.25));
+ if device.is_wgpu() {
+ #[cfg(feature = "wgpu")]
+ if let Device::Wgpu(wgpu) = device {
+ {
+            let mut algs = vec![
+ //MatmulAlgorithm::Matmul64_64_8_8,
+ //MatmulAlgorithm::Matmul64_64_4_8,
+ //MatmulAlgorithm::Matmul64_64,
+ //MatmulAlgorithm::Matmul32_64,
+ //MatmulAlgorithm::Matmul32_64B,
+ //MatmulAlgorithm::Matmul1_64B,
+ //MatmulAlgorithm::Matmul32_32,
+ //MatmulAlgorithm::Matmul16_16,
+ //MatmulAlgorithm::Matmul24_24,
+ //MatmulAlgorithm::Matmul24_48,
+ //MatmulAlgorithm::MatmulX,
+ MatmulAlgorithm::Matmul7,
+ MatmulAlgorithm::Matmul1,
+ MatmulAlgorithm::Matmul1M1,
+ MatmulAlgorithm::MatmulX,
+ ];
+
+ if _is_small_line {
+ algs.push(MatmulAlgorithm::Matmul1_64);
+ algs.push(MatmulAlgorithm::Matmul1_64B);
+ }
+
+ for alg in algs {
+ wgpu.inner_device().set_extension(alg.clone());
+
+ let func_name = device.bench_name(format!(
+ "matmul_{:?}{}{}",
+ alg,
+                        if tpa { "_transposedA" } else { "" },
+                        if tpb { "_transposedB" } else { "" }
+ ));
+ if multiple_sizes {
+ group.bench_with_input(
+ BenchmarkId::new(func_name.clone(), size),
+ &size,
+ |b, _| {
+ b.iter_custom(|iters| {
+ let start = Instant::now();
+ for _ in 0..iters {
+ run(black_box(&lhs), black_box(&rhs));
+ }
+ device.sync().unwrap();
+ start.elapsed()
+ })
+ },
+ );
+ } else {
+ group.bench_function(func_name, |b| {
+ b.iter_custom(|iters| {
+ let start = Instant::now();
+ for _ in 0..iters {
+ run(black_box(&lhs), black_box(&rhs));
+ }
+ device.sync().unwrap();
+ start.elapsed()
+ })
+ });
+ }
+ }
+ }
+ }
+ } else {
+ let func_name = device.bench_name(format!(
+ "matmul{}{}",
+ if tpa { "_tranposedA" } else { "" },
+ if tpb { "_tranposedB" } else { "" }
+ ));
+ if multiple_sizes {
+ group.bench_with_input(BenchmarkId::new(func_name, size), &size, |b, _| {
+ b.iter_custom(|iters| {
+ let start = Instant::now();
+ for _i in 0..iters {
+ run(black_box(&lhs), black_box(&rhs));
+ }
+ device.sync().unwrap();
+ start.elapsed()
+ })
+ });
+ } else {
+ group.bench_function(func_name, |b| {
+ b.iter_custom(|iters| {
+ let start = Instant::now();
+ for _ in 0..iters {
+ run(black_box(&lhs), black_box(&rhs));
+ }
+ device.sync().unwrap();
+ start.elapsed()
+ })
+ });
+ }
+ }
+}
+
+#[allow(dead_code)]
+fn test_functions(
+ device: &Device,
+    group: &mut criterion::BenchmarkGroup<'_, criterion::measurement::WallTime>,
+ fm: impl Fn(usize) -> usize,
+) {
+ let sizes = vec![2050usize, 2048, 1032, 1024, 528, 512, 128, 120, 32, 16].into_iter();
+ for size in sizes {
+ test_matmul(
+ device,
+ group,
+ (1, fm(size), size, size),
+ fm(2) == 1,
+ size,
+ true,
+ (false, false),
+ );
+ }
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let handler = BenchDeviceHandler::new().unwrap();
+
+ // let mut group = c.benchmark_group("matmul_m_1");
+ // for device in handler.devices.iter() {
+ // test_functions(device, &mut group, |_| 1);
+ // }
+ // group.finish();
+
+ // let mut group = c.benchmark_group("matmul_full");
+ // for device in handler.devices.iter() {
+ // test_functions(device, &mut group, |size| size);
+ // }
+ // group.finish();
+
+ // let mut group = c.benchmark_group("matmul_(2048x2048 * 2048x2048)");
+ // for device in handler.devices.iter() {
+ // test_matmul(
+ // device,
+ // &mut group,
+ // (1, 2048, 2048, 2048),
+ // false,
+ // 1,
+ // false,
+ // (false, false),
+ // );
+ // test_matmul(
+ // device,
+ // &mut group,
+ // (1, 2048, 2048, 2048),
+ // false,
+ // 1,
+ // false,
+ // (true, false),
+ // );
+ // // test_matmul(device, &mut group, (1, 2048, 2048, 2048), false, 1, false, (false, true));
+ // // test_matmul(device, &mut group, (1, 2048, 2048, 2048), false, 1, false, (true, true));
+ // }
+ // group.finish();
+
+ // let mut group = c.benchmark_group("matmul_(2050x2050 * 2050x2050)");
+ // for device in handler.devices.iter() {
+ // test_matmul(device, &mut group, (1, 2050, 2050, 2050), false, 1, false, (false, false));
+ // // test_matmul(device, &mut group, (1, 2048, 2048, 2048), false, 1, false, (true, false));
+ // // test_matmul(device, &mut group, (1, 2048, 2048, 2048), false, 1, false, (false, true));
+ // // test_matmul(device, &mut group, (1, 2048, 2048, 2048), false, 1, false, (true, true));
+ // }
+ // group.finish();
+
+ // let mut group = c.benchmark_group("matmul_2*(1x9 * 9x576)");
+ // for device in handler.devices.iter() {
+ // test_matmul(device, &mut group, (2, 1, 576, 9), true, 1, false, (false, false));
+ // }
+ // group.finish();
+
+ // let mut group = c.benchmark_group("matmul_(32x2304 * 2304x5120)");
+ // for device in handler.devices.iter() {
+ // test_matmul(device, &mut group, (1, 32, 5120, 2304), false, 1, false, (false, false));
+ // // test_matmul(device, &mut group, (1, 32, 5120, 2304), false, 1, false, (true, false));
+ // // test_matmul(device, &mut group, (1, 32, 5120, 2304), false, 1, false, (false, true));
+ // // test_matmul(device, &mut group, 81, 32, 5120, 2304), false, 1, false, (true, true));
+ // }
+ // group.finish();
+
+ // let mut group = c.benchmark_group("matmul_2*(1x2048 * 2048x5632)");
+ // for device in handler.devices.iter() {
+ // test_matmul(device, &mut group, (2, 1, 5632, 2048), true, 1, false, (false, false));
+ // test_matmul(device, &mut group, (2, 1, 5632, 2048), true, 1, false, (true, false));
+ // test_matmul(device, &mut group, (2, 1, 5632, 2048), true, 1, false, (false, true));
+ // test_matmul(device, &mut group, (2, 1, 5632, 2048), true, 1, false, (true, true));
+ // }
+ //group.finish();
+
+ let mut group = c.benchmark_group("matmul_32*(1x767 * 767x128)");
+ for device in handler.devices.iter() {
+ test_matmul(device, &mut group, (32, 1, 128, 767), true, 1, false, (false, false));
+ }
+ group.finish();
+
+ // let mut group = c.benchmark_group("matmul_(64x2304 * 2304x5120)");
+ // for device in handler.devices.iter() {
+ // test_matmul(device, &mut group, (1, 64, 5120, 2304), false, 1, false, (false, false));
+ // // test_matmul(device, &mut group, (1, 64, 5120, 2304), false, 1, false, (true, false));
+ // // test_matmul(device, &mut group, (1, 64, 5120, 2304), false, 1, false, (false, true));
+ // // test_matmul(device, &mut group, (1, 64, 5120, 2304), false, 1, false, (true, true));
+ // }
+ // group.finish();
+
+ // let mut group = c.benchmark_group("matmul_(24x1536 * 1536x6144)");
+ // for device in handler.devices.iter() {
+ // test_matmul(device, &mut group, (1, 24, 6144, 1536), false, 1, false, (false, false));
+ // }
+ // group.finish();
+
+ // let mut group = c.benchmark_group("matmul_2*(653x1536 * 1536x1536)");
+ // for device in handler.devices.iter() {
+ // test_matmul(device, &mut group, (2, 653, 1536, 1536), false, 1, false, (false, false));
+ // }
+ // group.finish();
+
+ // let mut group = c.benchmark_group("matmul_32*(32x2304 * 2304x5120)");
+ // for device in handler.devices.iter() {
+ // test_matmul(device, &mut group, (32, 32, 5120, 2304), false, 1, false, (false, false));
+ // }
+ // group.finish();
+
+ // let mut group = c.benchmark_group("matmul_(1101x1280 * 1280x1280)");
+ // for device in handler.devices.iter() {
+ // test_matmul(device, &mut group, (1, 1101, 1280, 1280), false, 1, false, (false, false));
+ // }
+ // group.finish();
+
+ // let mut group = c.benchmark_group("matmul_10*(4096x64 * 64x4173)");
+ // for device in handler.devices.iter() {
+ // test_matmul(device, &mut group, (10, 4096, 4173, 64), false, 1, false, (false, false));
+ // }
+ // group.finish();
+
+ // let mut group = c.benchmark_group("matmul_20*(1024x64 * 64x1101)");
+ // for device in handler.devices.iter() {
+ // test_matmul(device, &mut group, (20, 1024, 1101, 64), false, 1, false, (false, false));
+ // }
+ // group.finish();
+
+ // let mut group = c.benchmark_group("matmul_64*(64x1664 * 1664x2560)");
+ // for device in handler.devices.iter() {
+    //     test_matmul(device, &mut group, (64, 64, 2560, 1664), false, 1, false, (false, false));
+ // }
+ // group.finish();
+}
+
+criterion_group!(benches, criterion_benchmark);
diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs
index 579c5f3f0b..6e9d0f0926 100644
--- a/candle-core/benches/benchmarks/mod.rs
+++ b/candle-core/benches/benchmarks/mod.rs
@@ -1,12 +1,19 @@
pub(crate) mod affine;
+pub(crate) mod broadcast;
+pub(crate) mod binary;
+pub(crate) mod conv2d;
pub(crate) mod conv_transpose2d;
+pub(crate) mod copy;
pub(crate) mod matmul;
+pub(crate) mod matmul_wgpu;
+pub(crate) mod matmul_quantized;
pub(crate) mod qmatmul;
pub(crate) mod random;
+pub(crate) mod reduce;
pub(crate) mod unary;
pub(crate) mod where_cond;
-use candle_core::{Device, Result};
+use candle_core::{backend::BackendDevice, Device, Result};
pub(crate) trait BenchDevice {
fn sync(&self) -> Result<()>;
@@ -20,16 +27,20 @@ impl BenchDevice for Device {
Device::Cpu => Ok(()),
Device::Cuda(device) => {
#[cfg(feature = "cuda")]
- return Ok(device.synchronize()?);
+ {
+ use candle_core::backend::BackendDevice;
+ return Ok(device.synchronize()?);
+ }
#[cfg(not(feature = "cuda"))]
- panic!("Cuda device without cuda feature enabled: {:?}", device)
+ panic!("Cuda device without cuda feature enabled: {device:?}")
}
Device::Metal(device) => {
#[cfg(feature = "metal")]
- return Ok(device.wait_until_completed()?);
+ return device.wait_until_completed();
#[cfg(not(feature = "metal"))]
- panic!("Metal device without metal feature enabled: {:?}", device)
+ panic!("Metal device without metal feature enabled: {device:?}")
}
+ Device::Wgpu(device) => Ok(device.synchronize()?),
}
}
@@ -47,6 +58,7 @@ impl BenchDevice for Device {
}
Device::Cuda(_) => format!("cuda_{}", name.into()),
Device::Metal(_) => format!("metal_{}", name.into()),
+ Device::Wgpu(_) => format!("wgpu_{}", name.into()),
}
}
}
@@ -62,8 +74,11 @@ impl BenchDeviceHandler {
devices.push(Device::new_metal(0)?);
} else if cfg!(feature = "cuda") {
devices.push(Device::new_cuda(0)?);
+ } else if cfg!(feature = "wgpu") {
+ devices.push(Device::new_wgpu(0)?);
+ } else {
+ devices.push(Device::Cpu);
}
- devices.push(Device::Cpu);
Ok(Self { devices })
}
}
diff --git a/candle-core/benches/benchmarks/qmatmul.rs b/candle-core/benches/benchmarks/qmatmul.rs
index 4d34588b36..97e10c8c15 100644
--- a/candle-core/benches/benchmarks/qmatmul.rs
+++ b/candle-core/benches/benchmarks/qmatmul.rs
@@ -3,7 +3,8 @@ use candle_core::{
quantized::{self, GgmlDType, QMatMul},
Device, Module, Tensor,
};
-use criterion::{black_box, criterion_group, Criterion, Throughput};
+use criterion::{criterion_group, Criterion, Throughput};
+use std::hint::black_box;
use std::time::Instant;
fn run(matmul: &QMatMul, x: &Tensor) {
@@ -31,7 +32,7 @@ fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
let flops = b * m * n * k;
- let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{:?}", dtype)));
+ let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{dtype:?}")));
group.sample_size(200);
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
@@ -64,6 +65,9 @@ fn criterion_benchmark(c: &mut Criterion) {
GgmlDType::Q5K,
GgmlDType::Q6K,
] {
+ if device.is_wgpu() && dtype == GgmlDType::F16{
+ continue;
+ }
run_bench(c, &device, dtype);
}
}
diff --git a/candle-core/benches/benchmarks/random.rs b/candle-core/benches/benchmarks/random.rs
index 22c60ef18c..365e051c0e 100644
--- a/candle-core/benches/benchmarks/random.rs
+++ b/candle-core/benches/benchmarks/random.rs
@@ -1,6 +1,7 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
-use criterion::{black_box, criterion_group, Criterion, Throughput};
+use criterion::{criterion_group, Criterion, Throughput};
+use std::hint::black_box;
use std::time::Instant;
fn rand_uniform(a: &Tensor) {
diff --git a/candle-core/benches/benchmarks/reduce.rs b/candle-core/benches/benchmarks/reduce.rs
new file mode 100644
index 0000000000..30af7cd2a1
--- /dev/null
+++ b/candle-core/benches/benchmarks/reduce.rs
@@ -0,0 +1,175 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle_core::{DType, Device, Tensor};
+use criterion::{criterion_group, Criterion, Throughput};
+use half::{bf16, f16};
+use std::hint::black_box;
+use std::time::Instant;
+
+fn run_sum(a: &Tensor) {
+ a.sum_keepdim(2).unwrap();
+}
+fn run_arg_min(a: &Tensor) {
+ a.argmin_keepdim(2).unwrap();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let handler = BenchDeviceHandler::new().unwrap();
+ let (lo, up) = (-1000.0f32, 1000.0f32);
+ for device in handler.devices {
+ run_reduce(c, &device, (lo, up), false);
+ if device.is_dtype_available(DType::F16){
+ run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
+ }
+ if device.is_dtype_available(DType::BF16){
+ run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
+ }
+
+ run_arg_reduce(c, &device, (lo, up), false);
+ if device.is_dtype_available(DType::F16){
+ run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
+ }
+ if device.is_dtype_available(DType::BF16){
+ run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
+ }
+
+ run_reduce(c, &device, (lo, up), true);
+ if device.is_dtype_available(DType::F16){
+ run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
+ }
+ if device.is_dtype_available(DType::BF16){
+ run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
+ }
+
+ run_arg_reduce(c, &device, (lo, up), true);
+ if device.is_dtype_available(DType::F16){
+ run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
+ }
+ if device.is_dtype_available(DType::BF16){
+ run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
+ }
+ }
+}
+
+fn run_reduce<T: candle_core::FloatDType>(
+ c: &mut Criterion,
+ device: &Device,
+ (lo, up): (T, T),
+ strided: bool,
+) {
+ let b = 1;
+ let m = 1024;
+ let k = 1024;
+
+ let a = if strided {
+ Tensor::rand(lo, up, (b, m, k), device)
+ .unwrap()
+ .transpose(0, 2)
+ .unwrap()
+ } else {
+ Tensor::rand(lo, up, (b, m, k), device).unwrap()
+ };
+
+ let flops = b * m * k * T::DTYPE.size_in_bytes();
+
+ let name = match T::DTYPE {
+ DType::F32 => {
+ if strided {
+ "reduce_f32_strided"
+ } else {
+ "reduce_f32"
+ }
+ }
+ DType::F16 => {
+ if strided {
+ "reduce_f16_strided"
+ } else {
+ "reduce_f16"
+ }
+ }
+ DType::BF16 => {
+ if strided {
+ "reduce_bf16_strided"
+ } else {
+ "reduce_bf16"
+ }
+ }
+ _ => "unknown",
+ };
+
+ let mut group = c.benchmark_group(device.bench_name(name));
+ group.throughput(Throughput::Bytes(flops as u64));
+ group.bench_function("iter", move |b| {
+ b.iter_custom(|iters| {
+ let start = Instant::now();
+ for _i in 0..iters {
+ run_sum(black_box(&a));
+ }
+ device.sync().unwrap();
+ start.elapsed()
+ })
+ });
+ group.finish();
+}
+
+fn run_arg_reduce<T: candle_core::FloatDType>(
+ c: &mut Criterion,
+ device: &Device,
+ (lo, up): (T, T),
+ strided: bool,
+) {
+ let b = 1;
+ let m = 1024;
+ let k = 1024;
+
+ let a = if strided {
+ Tensor::rand(lo, up, (b, m, k), device)
+ .unwrap()
+ .transpose(0, 2)
+ .unwrap()
+ } else {
+ Tensor::rand(lo, up, (b, m, k), device).unwrap()
+ };
+
+ let flops = b * m * k * T::DTYPE.size_in_bytes();
+
+ let name = match T::DTYPE {
+ DType::F32 => {
+ if strided {
+ "arg_reduce_f32_strided"
+ } else {
+ "arg_reduce_f32"
+ }
+ }
+ DType::F16 => {
+ if strided {
+ "arg_reduce_f16_strided"
+ } else {
+ "arg_reduce_f16"
+ }
+ }
+ DType::BF16 => {
+ if strided {
+ "arg_reduce_bf16_strided"
+ } else {
+ "arg_reduce_bf16"
+ }
+ }
+ _ => "unknown",
+ };
+
+ let mut group = c.benchmark_group(device.bench_name(name));
+ group.throughput(Throughput::Bytes(flops as u64));
+ group.bench_function("iter", move |b| {
+ b.iter_custom(|iters| {
+ let start = Instant::now();
+ for _i in 0..iters {
+ run_arg_min(black_box(&a));
+ }
+ device.sync().unwrap();
+ start.elapsed()
+ })
+ });
+ group.finish();
+}
+
+criterion_group!(benches, criterion_benchmark);
diff --git a/candle-core/benches/benchmarks/unary.rs b/candle-core/benches/benchmarks/unary.rs
index 9efd75093d..425433ea7d 100644
--- a/candle-core/benches/benchmarks/unary.rs
+++ b/candle-core/benches/benchmarks/unary.rs
@@ -1,9 +1,10 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
-use criterion::{black_box, criterion_group, Criterion, Throughput};
+use criterion::{criterion_group, Criterion, Throughput};
+use std::hint::black_box;
use std::time::Instant;
-fn run(a: &Tensor) {
+fn run_sqrt(a: &Tensor) {
a.sqrt().unwrap();
}
@@ -27,7 +28,46 @@ fn run_unary_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
- run(black_box(&tensor));
+ run_sqrt(black_box(&tensor));
+ }
+ device.sync().unwrap();
+ start.elapsed()
+ })
+ });
+ group.finish();
+}
+
+fn run_cast(a: &Tensor, dtype: DType) {
+ a.to_dtype(dtype).unwrap();
+}
+
+fn run_cast_benchmark(
+ c: &mut Criterion,
+ device: &Device,
+ dtype: DType,
+ to_dtype: DType,
+ name: &str,
+) {
+ let b = 1;
+ let m = 1024;
+ let k = 1024;
+
+ let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, device)
+ .unwrap()
+ .to_dtype(dtype)
+ .unwrap()
+ .reshape((b, m, k))
+ .unwrap();
+
+ let flops = b * m * k * dtype.size_in_bytes();
+
+ let mut group = c.benchmark_group(device.bench_name(name));
+ group.throughput(Throughput::Bytes(flops as u64));
+ group.bench_function("iter", move |b| {
+ b.iter_custom(|iters| {
+ let start = Instant::now();
+ for _i in 0..iters {
+ run_cast(black_box(&tensor), black_box(to_dtype));
}
device.sync().unwrap();
start.elapsed()
@@ -40,8 +80,21 @@ fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
for dtype in [DType::F32, DType::BF16, DType::F16] {
- let name = format!("sqrt_{:?}", dtype);
- run_unary_benchmark(c, &device, dtype, &name);
+ let to_dtype = if matches!(dtype, DType::F32) {
+ DType::F16
+ } else {
+ DType::F32
+ };
+ let name = format!("cast_{}_{}", dtype.as_str(), to_dtype.as_str());
+ if device.is_dtype_available(dtype) && device.is_dtype_available(to_dtype) {
+ run_cast_benchmark(c, &device, dtype, to_dtype, &name);
+ }
+ }
+ for dtype in [DType::F32, DType::BF16, DType::F16] {
+ let name = format!("sqrt_{dtype:?}");
+ if device.is_dtype_available(dtype) {
+ run_unary_benchmark(c, &device, dtype, &name);
+ }
}
}
}
diff --git a/candle-core/benches/benchmarks/where_cond.rs b/candle-core/benches/benchmarks/where_cond.rs
index 0e91f656fc..1392fee107 100644
--- a/candle-core/benches/benchmarks/where_cond.rs
+++ b/candle-core/benches/benchmarks/where_cond.rs
@@ -1,6 +1,7 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
-use criterion::{black_box, criterion_group, Criterion, Throughput};
+use criterion::{criterion_group, Criterion, Throughput};
+use std::hint::black_box;
use std::time::Instant;
fn run(a: &Tensor, b: &Tensor, c: &Tensor) {
@@ -17,15 +18,31 @@ const fn create_cond_arr() -> [u8; N] {
arr
}
+const fn create_cond_arr_u32<const N: usize>() -> [u32; N] {
+ let mut arr = [0u32; N];
+ let mut i = 0;
+ while i < N {
+ arr[i] = (i % 2) as u32;
+ i += 1;
+ }
+ arr
+}
+
const B: usize = 1;
const M: usize = 1024;
const K: usize = 1024;
const SIZE: usize = B * M * K;
-const DATA: [u8; SIZE] = create_cond_arr::();
+static DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
+static DATA_U32: [u32; SIZE] = create_cond_arr_u32::<SIZE>();
fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
- let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), device).unwrap();
+ let tensor: Tensor = if device.is_wgpu() {
+ Tensor::from_slice(DATA_U32.as_slice(), (B, M, K), device).unwrap()
+ } else {
+ Tensor::from_slice(DATA.as_slice(), (B, M, K), device).unwrap()
+ };
+
let on_true = Tensor::ones((B, M, K), dtype, device).unwrap();
let on_false = Tensor::zeros((B, M, K), dtype, device).unwrap();
@@ -56,8 +73,12 @@ fn criterion_benchmark(c: &mut Criterion) {
let device = BenchDeviceHandler::new().unwrap();
for d in device.devices {
run_where_cond_benchmark(c, &d, DType::F32, "where_cond_f32");
- run_where_cond_benchmark(c, &d, DType::BF16, "where_cond_bf16");
- run_where_cond_benchmark(c, &d, DType::F16, "where_cond_f16");
+ if d.is_dtype_available(DType::BF16) {
+ run_where_cond_benchmark(c, &d, DType::BF16, "where_cond_bf16");
+ }
+ if d.is_dtype_available(DType::F16) {
+ run_where_cond_benchmark(c, &d, DType::F16, "where_cond_f16");
+ }
}
}
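
The hunk above feeds the wgpu benchmark a `u32` condition mask instead of the `u8` one used on the other backends, presumably because tightly packed 8-bit storage buffers are not generally usable there. The semantics of `where_cond` are the same either way; a small CPU sketch of the operation being benchmarked (shapes chosen purely for illustration):

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    // Non-zero mask entries select from `on_true`, zero entries from `on_false`.
    let mask = Tensor::from_slice(&[1u32, 0, 1, 0], (2, 2), &device)?;
    let on_true = Tensor::ones((2, 2), DType::F32, &device)?;
    let on_false = Tensor::zeros((2, 2), DType::F32, &device)?;
    let out = mask.where_cond(&on_true, &on_false)?;
    assert_eq!(out.to_vec2::<f32>()?, [[1.0, 0.0], [1.0, 0.0]]);
    Ok(())
}
```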
diff --git a/candle-core/examples/basics.rs b/candle-core/examples/basics.rs
index fe15187b5a..e991403441 100644
--- a/candle-core/examples/basics.rs
+++ b/candle-core/examples/basics.rs
@@ -9,9 +9,12 @@ use candle_core::{Device, Tensor};
fn main() -> Result<()> {
let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
- let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?;
+ let b = Tensor::new(&[[88.0f32], [99.0]], &Device::Cpu)?;
let new_a = a.slice_scatter(&b, 1, 2)?;
assert_eq!(a.to_vec2::()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
- assert_eq!(new_a.to_vec2::()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
+ assert_eq!(
+        new_a.to_vec2::<f32>()?,
+ [[0.0, 1.0, 88.0], [3.0, 4.0, 99.0]]
+ );
Ok(())
}
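
The corrected example relies on `slice_scatter(src, dim, start)` writing `src` into a copy of `self` along `dim` starting at offset `start`, which is why a `(2, 1)` column placed at `dim = 1`, `start = 2` replaces the last column. A short sketch restating the fixed assertion and adding a dim-0 variant (the dim-0 shapes are illustrative assumptions):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;

    // A (2, 1) column written at dim=1, start=2 replaces the last column.
    let col = Tensor::new(&[[88.0f32], [99.0]], &Device::Cpu)?;
    let scattered = a.slice_scatter(&col, 1, 2)?;
    assert_eq!(scattered.to_vec2::<f32>()?, [[0.0, 1.0, 88.0], [3.0, 4.0, 99.0]]);

    // A (1, 3) row written at dim=0, start=1 replaces the second row.
    let row = Tensor::new(&[[7.0f32, 7.0, 7.0]], &Device::Cpu)?;
    let scattered = a.slice_scatter(&row, 0, 1)?;
    assert_eq!(scattered.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [7.0, 7.0, 7.0]]);
    Ok(())
}
```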
diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs
index 9af1b006e3..4eadcdeb82 100644
--- a/candle-core/examples/cuda_basics.rs
+++ b/candle-core/examples/cuda_basics.rs
@@ -6,28 +6,18 @@ extern crate intel_mkl_src;
use anyhow::Result;
use candle_core::{Device, Tensor};
-
+// Benchmarked case: xs [1024, 64, 1924], kernel [128, 64, 8] (f32, cuda:0), Conv1dConfig { padding: 0, stride: 4, dilation: 1, groups: 1 }
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
- let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
- .to_dtype(candle_core::DType::BF16)?;
- candle_core::cuda::set_gemm_reduced_precision_f32(false);
- candle_core::cuda::set_gemm_reduced_precision_bf16(false);
- let _x1 = x.matmul(&x)?;
- drop(_x1);
- let start_time = std::time::Instant::now();
- let _x1 = x.matmul(&x)?;
- device.synchronize()?;
- println!("fp32: {:?}", start_time.elapsed());
- drop(_x1);
- candle_core::cuda::set_gemm_reduced_precision_f32(true);
- candle_core::cuda::set_gemm_reduced_precision_bf16(true);
- let _x1 = x.matmul(&x)?;
- drop(_x1);
- let start_time = std::time::Instant::now();
- let _x1 = x.matmul(&x)?;
- device.synchronize()?;
- println!("tf32: {:?}", start_time.elapsed());
+ let x = Tensor::randn(0f32, 1.0, (1024, 64, 1924), &device)?;
+ let c = Tensor::randn(0f32, 1.0, (128, 64, 8), &device)?;
+ let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
drop(_x1);
+ for _ in 0..20 {
+ let start_time = std::time::Instant::now();
+ let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
+ device.synchronize()?;
+ println!("conv1d: {:?}", start_time.elapsed());
+ }
Ok(())
}
diff --git a/candle-core/examples/cuda_sum_benchmark.rs b/candle-core/examples/cuda_sum_benchmark.rs
index d6d182e8fc..5bd4b4eefe 100644
--- a/candle-core/examples/cuda_sum_benchmark.rs
+++ b/candle-core/examples/cuda_sum_benchmark.rs
@@ -10,7 +10,7 @@ use anyhow::Result;
use candle_core::{Device, Tensor};
fn cos_sin(n: usize, device: &Device) -> Result {
- let thetas: Vec<_> = (0..n).map(|i| (i as f32 / n as f32)).collect();
+ let thetas: Vec<_> = (0..n).map(|i| i as f32 / n as f32).collect();
let xs: Vec<_> = thetas.iter().map(|t| t.cos().abs()).collect();
let ys: Vec<_> = thetas.iter().map(|t| t.sin().abs()).collect();
let xs = Tensor::from_vec(xs, (n, 1), device)?;
diff --git a/candle-core/examples/metal_basics.rs b/candle-core/examples/metal_basics.rs
index f9ff81adc4..3f433d9c26 100644
--- a/candle-core/examples/metal_basics.rs
+++ b/candle-core/examples/metal_basics.rs
@@ -21,7 +21,7 @@ fn main() -> Result<()> {
let x = Tensor::randn(0f32, 1.0, (128, 128), &device)?;
let x1 = x.add(&x)?;
println!("{x1:?}");
- // This second synchronize ensures that the command buffer gets commited before the end of the
+ // This second synchronize ensures that the command buffer gets committed before the end of the
// capture scope.
device.synchronize()?;
Ok(())
diff --git a/candle-core/examples/wgpu_basics.rs b/candle-core/examples/wgpu_basics.rs
new file mode 100644
index 0000000000..1006da7f80
--- /dev/null
+++ b/candle-core/examples/wgpu_basics.rs
@@ -0,0 +1,148 @@
+use std::borrow::Cow;
+
+use anyhow::Result;
+use candle_core::{backend::BackendStorage, CustomOp1, Device, Tensor};
+use wgpu_compute_layer::{PipelineIndex, ShaderIndex};
+
+// This example demonstrates how a custom wgpu kernel can be used:
+#[derive(Debug)]
+struct MyCustomLoader{}
+
+wgpu_compute_layer::create_loader!(MyCustomLoader);
+
+impl wgpu_compute_layer::ShaderLoader for MyCustomLoader {
+ //define the shader:
+ fn load(&self, _: wgpu_compute_layer::ShaderIndex, _ : &[(&str, String)]) -> Cow<'_, str> {
+ "
+//Binding Order: Dest, Meta, Input1, Input2, Input3
+@group(0) @binding(0)
+var<storage, read_write> v_dest: array<u32>;
+
+@group(0) @binding(1)
+var<storage> op_meta : array<u32>;
+
+@group(0) @binding(2)
+var<storage> v_input1: array<u32>;
+
+@compute @workgroup_size(1)
+fn main1() {
+ v_dest[0] = 2 * op_meta[0];
+}
+@compute @workgroup_size(1)
+fn main2() {
+ v_dest[0] = v_input1[0] * op_meta[0];
+}
+ ".into()
+ }
+
+ //define the entry point:
+ fn get_entry_point(&self, index: wgpu_compute_layer::PipelineIndex) -> &str {
+ match index.get_index() {
+ 0 => "main1",
+ 1 => "main2",
+ _ => {
+ todo!()
+ }
+ }
+ }
+}
+
+#[cfg(feature = "wgpu")]
+fn main() -> Result<()> {
+ let device = &Device::new_wgpu(0)?;
+ let wgpu_device= device.as_wgpu_device()?;
+
+    //0. add the custom loader to the device (this must be done only once)
+ wgpu_device.add_wgpu_shader_loader(MyCustomLoader::LOADER_INDEX, || MyCustomLoader {});
+
+ let mut queue = wgpu_device.get_queue();
+ let output_buffer = wgpu_device.alloc_uninit_size(candle_core::DType::U32, 1);
+
+ //1. add optional data for the next shader call
+ queue.add(42);
+
+ //2. define the pipeline to use:
+ let pipeline = queue.get_pipeline(PipelineIndex::new(
+ ShaderIndex::new(MyCustomLoader::LOADER_INDEX, 0),
+ 0,
+ ));
+
+ //3. define the bindgroup to use (defines dest, input buffer and the alignment)
+ let bind_group = wgpu_device.create_bind_group_input0(
+ output_buffer.buffer(),
+ candle_core::DType::U32.into(),
+ );
+
+ //4. add the command to the queue:
+ queue.enqueue_workgroups(pipeline, bind_group, 1, 1, 1, 1);
+
+ let cpu_storage_data = output_buffer.to_cpu_storage()?;
+
+ match cpu_storage_data {
+ candle_core::CpuStorage::U32(vec) => {
+ assert_eq!(vec[0], 42 * 2);
+ }
+ _ => todo!(),
+ }
+
+ let input = Tensor::from_slice(&[17u32], (), device)?;
+ let output = input.apply_op1(CustomExampleOp {})?;
+
+    assert_eq!(output.to_vec0::<u32>()?, 17 * 42u32);
+
+ Ok(())
+}
+
+struct CustomExampleOp {}
+
+impl CustomOp1 for CustomExampleOp {
+ fn name(&self) -> &'static str {
+ "CustomExampleOp"
+ }
+
+ fn cpu_fwd(
+ &self,
+ _storage: &candle_core::CpuStorage,
+ _layout: &candle_core::Layout,
+ ) -> candle_core::Result<(candle_core::CpuStorage, candle_core::Shape)> {
+ todo!()
+ }
+
+ fn wgpu_fwd(
+ &self,
+ storage: &candle_core::WgpuStorage,
+ _layout: &candle_core::Layout,
+ ) -> candle_core::Result<(candle_core::WgpuStorage, candle_core::Shape)> {
+ //1. add the custom loader to the device
+ storage
+ .device()
+ .add_wgpu_shader_loader(MyCustomLoader::LOADER_INDEX, || MyCustomLoader {});
+
+        //2. add optional data to the meta structure
+ let mut queue = storage.device().get_queue();
+ queue.add(42);
+
+ //3. define the pipeline to use:
+ let pipeline = queue.get_pipeline(PipelineIndex::new(
+ ShaderIndex::new(MyCustomLoader::LOADER_INDEX, 0),
+ 1,
+ ));
+
+ let output_buffer = storage.device().alloc_uninit_size(
+ candle_core::DType::U32,
+ 1,
+ );
+
+ //4. define the bindgroup to use (defines dest, input buffer and the alignment)
+ let bind_group = storage.device().create_bind_group_input1(
+ output_buffer.buffer(),
+ storage.buffer(),
+ candle_core::DType::U32.into(),
+ );
+
+        //5. add the command to the queue:
+ queue.enqueue_workgroups(pipeline, bind_group, 1, 1, 1, 1);
+
+ Ok((output_buffer, ().into()))
+ }
+}
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index afe3e40754..d8ab2b5629 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -1,3 +1,5 @@
+//! Traits to Define Backend Behavior
+//!
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
@@ -67,17 +69,38 @@ pub trait BackendStorage: Sized {
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result;
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result;
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result;
+ fn upsample_bilinear2d(
+ &self,
+ _: &Layout,
+ _: usize,
+ _: usize,
+ _: bool,
+        _: Option<f64>,
+        _: Option<f64>,
+    ) -> Result<Self>;
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result;
- fn scatter_add(
- &self,
+
+ fn scatter_set(
+ &mut self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
- ) -> Result;
+ ) -> Result<()>;
+
+ fn scatter_add_set(
+ &mut self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: usize,
+ ) -> Result<()>;
+
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result;
fn index_add(
&self,
@@ -111,6 +134,8 @@ pub trait BackendStorage: Sized {
_src_offset: usize,
_dst_offset: usize,
) -> Result<()>;
+
+ fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()>;
}
pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
@@ -125,8 +150,6 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result;
- fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result;
-
/// # Safety
/// This function is unsafe as it doesn't initialize the underlying data store.
/// The caller should ensure that the data is properly initialized as early as possible
@@ -144,6 +167,7 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result;
fn set_seed(&self, _: u64) -> Result<()>;
+    fn get_current_seed(&self) -> Result<u64>;
/// Synchronize should block until all the operations on the device are completed.
fn synchronize(&self) -> Result<()>;
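
With the out-of-place `scatter_add` replaced by the in-place `scatter_set` / `scatter_add_set` (plus the new `const_set`), callers that still want a functional version have to copy first and then mutate. A rough sketch of such a wrapper, not taken from the patch: `clone_storage` is a hypothetical stand-in for whatever copy mechanism a backend provides, and the positional arguments (self layout, indexes, source, dim) are read off the signature above.

```rust
use candle_core::backend::BackendStorage;
use candle_core::{Layout, Result};

// Hypothetical helper: clone `src`, then apply the in-place scatter-add to the clone.
#[allow(clippy::too_many_arguments)]
fn scatter_add_out_of_place<S: BackendStorage>(
    src: &S,
    src_l: &Layout,
    ids: &S,
    ids_l: &Layout,
    vals: &S,
    vals_l: &Layout,
    dim: usize,
    clone_storage: impl Fn(&S) -> Result<S>, // assumption: copy mechanism supplied by the caller
) -> Result<S> {
    let mut out = clone_storage(src)?;
    out.scatter_add_set(src_l, ids, ids_l, vals, vals_l, dim)?;
    Ok(out)
}
```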
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index a556677478..d2310cbe28 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -1,4 +1,4 @@
-/// Methods for backpropagation of gradients.
+//! Methods for backpropagation of gradients.
use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
use crate::{Error, Result, Tensor, TensorId};
use std::collections::HashMap;
@@ -32,7 +32,7 @@ impl Tensor {
/// elements having dependencies on the latter ones, e.g. the first element if any is the
/// argument.
/// This assumes that the op graph is a DAG.
- fn sorted_nodes(&self) -> Vec<&Tensor> {
+ pub fn sorted_nodes(&self) -> Vec<&Tensor> {
// The vec of sorted nodes is passed as an owned value rather than a mutable reference
// to get around some lifetime limitations.
fn walk<'a>(
@@ -53,6 +53,7 @@ impl Tensor {
} else if let Some(op) = node.op() {
match op {
Op::IndexAdd(t1, t2, t3, _)
+ | Op::Scatter(t1, t2, t3, _)
| Op::ScatterAdd(t1, t2, t3, _)
| Op::CustomOp3(t1, t2, t3, _)
| Op::WhereCond(t1, t2, t3) => {
@@ -117,6 +118,7 @@ impl Tensor {
Op::Reshape(node)
| Op::UpsampleNearest1D { arg: node, .. }
| Op::UpsampleNearest2D { arg: node, .. }
+ | Op::UpsampleBilinear2D { arg: node, .. }
| Op::AvgPool2D { arg: node, .. }
| Op::MaxPool2D { arg: node, .. }
| Op::Copy(node)
@@ -406,6 +408,9 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = conv_sum;
}
+ Op::UpsampleBilinear2D { .. } => {
+ crate::bail!("backward not supported for upsample_bilinear2d")
+ }
Op::SliceScatter0(lhs, rhs, start_rhs) => {
let rhs_sum_grad = grads.or_insert(rhs)?;
let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
@@ -419,7 +424,7 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
}
- Op::ScatterAdd(init, indexes, src, dim) => {
+ Op::Scatter(init, indexes, src, dim) => {
let init_sum_grad = grads.or_insert(init)?;
*init_sum_grad = init_sum_grad.add(&grad)?;
@@ -427,6 +432,16 @@ impl Tensor {
let src_sum_grad = grads.or_insert(src)?;
*src_sum_grad = src_sum_grad.add(&src_grad)?;
}
+ Op::ScatterAdd(init, indexes, src, dim) => {
+ let init_sum_grad = grads.or_insert(init)?;
+ let mask = init.ones_like()?;
+ let mask = mask.scatter(indexes, &mask.zeros_like()?, *dim)?;
+ *init_sum_grad = init_sum_grad.add(&grad.mul(&mask)?)?;
+
+ let src_grad = grad.gather(indexes, *dim)?;
+ let src_sum_grad = grads.or_insert(src)?;
+ *src_sum_grad = src_sum_grad.add(&src_grad)?;
+ }
Op::IndexAdd(init, indexes, src, dim) => {
let init_sum_grad = grads.or_insert(init)?;
*init_sum_grad = init_sum_grad.add(&grad)?;
@@ -743,6 +758,11 @@ impl GradStore {
self.0.insert(tensor.id(), grad)
}
+ /// Insert a gradient tensor associated with the given tensor id, returning the previous gradient tensor if it existed
+    pub fn insert_id(&mut self, id: TensorId, grad: Tensor) -> Option<Tensor> {
+ self.0.insert(id, grad)
+ }
+
/// Get the gradient tensor associated with the given tensor, or, if it does not exist,
/// insert a tensor of zeroes, with the same shape and type as the given tensors and return it
fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
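
The new `Op::ScatterAdd` backward pass above is assembled from two tensor ops: `scatter`, which overwrites the indexed slots (used to zero out a mask of ones), and `gather`, which reads the upstream gradient back out at those slots for the source term. A small CPU sketch of those two building blocks with illustrative shapes:

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;

    // Positions 1 and 3 of the row get overwritten with 0; the rest keep their 1.
    let ones = Tensor::ones((1, 4), DType::F32, &dev)?;
    let idx = Tensor::new(&[[1u32, 3]], &dev)?;
    let zeros = Tensor::zeros((1, 2), DType::F32, &dev)?;
    let mask = ones.scatter(&idx, &zeros, 1)?;
    assert_eq!(mask.to_vec2::<f32>()?, [[1.0, 0.0, 1.0, 0.0]]);

    // The source term is the upstream gradient gathered at the scattered positions.
    let grad = Tensor::new(&[[10.0f32, 20.0, 30.0, 40.0]], &dev)?;
    let src_grad = grad.gather(&idx, 1)?;
    assert_eq!(src_grad.to_vec2::<f32>()?, [[20.0, 40.0]]);
    Ok(())
}
```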
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index 7b3922dd73..115035ef1c 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -1,3 +1,5 @@
+//! 1D and 2D Convolutions
+//!
use crate::{op::BackpropOp, op::Op, Error, Result, Tensor};
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -12,6 +14,7 @@ pub struct ParamsConv1D {
pub(crate) padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
+    pub(crate) cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
impl ParamsConv1D {
@@ -52,7 +55,7 @@ impl ParamsConvTranspose1D {
}
}
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CudnnFwdAlgo {
ImplicitGemm,
ImplicitPrecompGemm,
@@ -149,6 +152,19 @@ impl Tensor {
stride: usize,
dilation: usize,
groups: usize,
+    ) -> Result<Self> {
+ self.conv1d_with_algo(kernel, padding, stride, dilation, groups, None)
+ }
+
+ /// Applies a 1D convolution over the input tensor.
+ pub fn conv1d_with_algo(
+ &self,
+ kernel: &Self,
+ padding: usize,
+ stride: usize,
+ dilation: usize,
+ groups: usize,
+        cudnn_fwd_algo: Option<CudnnFwdAlgo>,
) -> Result {
let (c_out, c_in_k, k_size) = kernel.dims3()?;
let (b_size, c_in, l_in) = self.dims3()?;
@@ -172,6 +188,7 @@ impl Tensor {
padding,
stride,
dilation,
+ cudnn_fwd_algo,
};
if groups == 1 {
             self.conv1d_single_group(kernel, &params)
@@ -276,6 +293,18 @@ impl Tensor {
stride: usize,
dilation: usize,
groups: usize,
+    ) -> Result<Self> {
+ self.conv2d_with_algo(kernel, padding, stride, dilation, groups, None)
+ }
+
+ pub fn conv2d_with_algo(
+ &self,
+ kernel: &Self,
+ padding: usize,
+ stride: usize,
+ dilation: usize,
+ groups: usize,
+        cudnn_fwd_algo: Option<CudnnFwdAlgo>,
) -> Result {
let (b_size, c_in, i_h, i_w) = self.dims4()?;
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
@@ -295,7 +324,7 @@ impl Tensor {
padding,
stride,
dilation,
- cudnn_fwd_algo: None,
+ cudnn_fwd_algo,
};
if groups == 1 {
             self.conv2d_single_group(kernel, &params)
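
The new `conv1d_with_algo` / `conv2d_with_algo` entry points only thread an optional `CudnnFwdAlgo` into the convolution parameters; passing `None` keeps the previous behavior. A sketch of the call site (run on the CPU device here, where the hint is simply carried along; whether it is honored is up to the cuDNN path):

```rust
use candle_core::{conv::CudnnFwdAlgo, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::randn(0f32, 1.0, (1, 3, 16, 16), &dev)?;
    let kernel = Tensor::randn(0f32, 1.0, (8, 3, 3, 3), &dev)?;

    // Equivalent to conv2d(&kernel, 1, 1, 1, 1): no algorithm is pinned.
    let y0 = xs.conv2d_with_algo(&kernel, 1, 1, 1, 1, None)?;
    // Pin a specific cuDNN forward algorithm; carried but unused outside the cuDNN path.
    let y1 = xs.conv2d_with_algo(&kernel, 1, 1, 1, 1, Some(CudnnFwdAlgo::ImplicitGemm))?;

    assert_eq!(y0.dims(), y1.dims()); // (1, 8, 16, 16) in both cases
    Ok(())
}
```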
diff --git a/candle-core/src/convert.rs b/candle-core/src/convert.rs
index 5ea5612a7c..38e7a7c9a6 100644
--- a/candle-core/src/convert.rs
+++ b/candle-core/src/convert.rs
@@ -93,6 +93,8 @@ from_tensor!(f32);
from_tensor!(f16);
from_tensor!(bf16);
from_tensor!(i64);
+from_tensor!(i32);
+from_tensor!(i16);
from_tensor!(u32);
from_tensor!(u8);
@@ -130,6 +132,16 @@ impl Tensor {
f.write_u32::(v)?
}
}
+            DType::I16 => {
+                for v in vs.to_vec1::<i16>()? {
+                    f.write_i16::<LittleEndian>(v)?
+                }
+            }
+            DType::I32 => {
+                for v in vs.to_vec1::<i32>()? {
+                    f.write_i32::<LittleEndian>(v)?
+                }
+            }
DType::I64 => {
for v in vs.to_vec1::()? {
f.write_i64::(v)?
@@ -139,6 +151,15 @@ impl Tensor {
let vs = vs.to_vec1::()?;
f.write_all(&vs)?;
}
+ DType::F8E4M3 => {
+ let vs = vs.to_vec1::()?;
+ for v in vs {
+ f.write_u8(v.to_bits())?
+ }
+ }
+ DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
+ return Err(crate::Error::UnsupportedDTypeForOp(self.dtype(), "write_bytes").bt())
+ }
}
Ok(())
}
diff --git a/candle-core/src/cpu/avx.rs b/candle-core/src/cpu/avx.rs
index 9398a3460a..113fc14ced 100644
--- a/candle-core/src/cpu/avx.rs
+++ b/candle-core/src/cpu/avx.rs
@@ -1,10 +1,10 @@
-use super::{Cpu, CpuF16};
+use super::{Cpu, CpuBF16, CpuF16};
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
-use half::f16;
+use half::{bf16, f16};
pub struct CurrentCpu {}
@@ -146,3 +146,82 @@ impl CpuF16 for CurrentCpuF16 {
*y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
}
}
+
+pub struct CurrentCpuBF16 {}
+impl CpuBF16<ARR> for CurrentCpuBF16 {
+ type Unit = __m256;
+ type Array = [__m256; ARR];
+
+ const STEP: usize = STEP;
+ const EPR: usize = EPR;
+
+ fn n() -> usize {
+ ARR
+ }
+
+ unsafe fn zero() -> Self::Unit {
+ _mm256_setzero_ps()
+ }
+
+ unsafe fn zero_array() -> Self::Array {
+ [Self::zero(); ARR]
+ }
+
+ unsafe fn from_f32(v: f32) -> Self::Unit {
+ _mm256_set1_ps(v)
+ }
+
+ #[cfg(target_feature = "f16c")]
+ unsafe fn load(mem_addr: *const bf16) -> Self::Unit {
+ _mm256_cvtph_ps(_mm_loadu_si128(mem_addr as *const __m128i))
+ }
+
+ #[cfg(not(target_feature = "f16c"))]
+ unsafe fn load(mem_addr: *const bf16) -> Self::Unit {
+ let mut tmp = [0.0f32; 8];
+ for i in 0..8 {
+ tmp[i] = (*mem_addr.add(i)).to_f32();
+ }
+ _mm256_loadu_ps(tmp.as_ptr())
+ }
+
+ unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
+ _mm256_add_ps(a, b)
+ }
+
+ unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
+ _mm256_add_ps(_mm256_mul_ps(b, c), a)
+ }
+
+ #[cfg(target_feature = "f16c")]
+ unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) {
+ _mm_storeu_si128(mem_addr as *mut __m128i, _mm256_cvtps_ph(a, 0))
+ }
+
+ #[cfg(not(target_feature = "f16c"))]
+ unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) {
+ let mut tmp = [0.0f32; 8];
+ _mm256_storeu_ps(tmp.as_mut_ptr(), a);
+ for i in 0..8 {
+ *mem_addr.add(i) = bf16::from_f32(tmp[i]);
+ }
+ }
+
+ unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
+ let mut offset = ARR >> 1;
+ for i in 0..offset {
+ x[i] = _mm256_add_ps(x[i], x[offset + i]);
+ }
+ offset >>= 1;
+ for i in 0..offset {
+ x[i] = _mm256_add_ps(x[i], x[offset + i]);
+ }
+ offset >>= 1;
+ for i in 0..offset {
+ x[i] = _mm256_add_ps(x[i], x[offset + i]);
+ }
+ let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
+ let t1 = _mm_hadd_ps(t0, t0);
+ *y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
+ }
+}
diff --git a/candle-core/src/cpu/erf.rs b/candle-core/src/cpu/erf.rs
index ca6be53fd4..4d05728b0d 100644
--- a/candle-core/src/cpu/erf.rs
+++ b/candle-core/src/cpu/erf.rs
@@ -7,7 +7,7 @@ mod evaluate {
//! Provides functions that don't have a numerical solution and must
//! be solved computationally (e.g. evaluation of a polynomial)
- /// evaluates a polynomial at `z` where `coeff` are the coeffecients
+ /// evaluates a polynomial at `z` where `coeff` are the coefficients
/// to a polynomial of order `k` where `k` is the length of `coeff` and the
/// coeffecient
/// to the `k`th power is the `k`th element in coeff. E.g. [3,-1,2] equates to
@@ -32,18 +32,12 @@ mod evaluate {
use std::f64;
/// `erf` calculates the error function at `x`.
-pub fn erf(x: f64) -> f64 {
- if x.is_nan() {
- f64::NAN
- } else if x >= 0.0 && x.is_infinite() {
- 1.0
- } else if x <= 0.0 && x.is_infinite() {
- -1.0
- } else if x == 0. {
- 0.0
- } else {
- erf_impl(x, false)
- }
+pub fn erf_f64(x: f64) -> f64 {
+ libm::erf(x)
+}
+
+pub fn erf_f32(x: f32) -> f32 {
+ libm::erff(x)
}
/// `erf_inv` calculates the inverse error function
@@ -64,16 +58,12 @@ pub fn erf_inv(x: f64) -> f64 {
/// `erfc` calculates the complementary error function
/// at `x`.
-pub fn erfc(x: f64) -> f64 {
- if x.is_nan() {
- f64::NAN
- } else if x == f64::INFINITY {
- 0.0
- } else if x == f64::NEG_INFINITY {
- 2.0
- } else {
- erf_impl(x, true)
- }
+pub fn erfc_f64(x: f64) -> f64 {
+ libm::erfc(x)
+}
+
+pub fn erfc_f32(x: f32) -> f32 {
+ libm::erfcf(x)
}
/// `erfc_inv` calculates the complementary inverse
@@ -90,319 +80,6 @@ pub fn erfc_inv(x: f64) -> f64 {
}
}
-// **********************************************************
-// ********** Coefficients for erf_impl polynomial **********
-// **********************************************************
-
-/// Polynomial coefficients for a numerator of `erf_impl`
-/// in the interval [1e-10, 0.5].
-const ERF_IMPL_AN: &[f64] = &[
- 0.00337916709551257388990745,
- -0.00073695653048167948530905,
- -0.374732337392919607868241,
- 0.0817442448733587196071743,
- -0.0421089319936548595203468,
- 0.0070165709512095756344528,
- -0.00495091255982435110337458,
- 0.000871646599037922480317225,
-];
-
-/// Polynomial coefficients for a denominator of `erf_impl`
-/// in the interval [1e-10, 0.5]
-const ERF_IMPL_AD: &[f64] = &[
- 1.0,
- -0.218088218087924645390535,
- 0.412542972725442099083918,
- -0.0841891147873106755410271,
- 0.0655338856400241519690695,
- -0.0120019604454941768171266,
- 0.00408165558926174048329689,
- -0.000615900721557769691924509,
-];
-
-/// Polynomial coefficients for a numerator in `erf_impl`
-/// in the interval [0.5, 0.75].
-const ERF_IMPL_BN: &[f64] = &[
- -0.0361790390718262471360258,
- 0.292251883444882683221149,
- 0.281447041797604512774415,
- 0.125610208862766947294894,
- 0.0274135028268930549240776,
- 0.00250839672168065762786937,
-];
-
-/// Polynomial coefficients for a denominator in `erf_impl`
-/// in the interval [0.5, 0.75].
-const ERF_IMPL_BD: &[f64] = &[
- 1.0,
- 1.8545005897903486499845,
- 1.43575803037831418074962,
- 0.582827658753036572454135,
- 0.124810476932949746447682,
- 0.0113724176546353285778481,
-];
-
-/// Polynomial coefficients for a numerator in `erf_impl`
-/// in the interval [0.75, 1.25].
-const ERF_IMPL_CN: &[f64] = &[
- -0.0397876892611136856954425,
- 0.153165212467878293257683,
- 0.191260295600936245503129,
- 0.10276327061989304213645,
- 0.029637090615738836726027,
- 0.0046093486780275489468812,
- 0.000307607820348680180548455,
-];
-
-/// Polynomial coefficients for a denominator in `erf_impl`
-/// in the interval [0.75, 1.25].
-const ERF_IMPL_CD: &[f64] = &[
- 1.0,
- 1.95520072987627704987886,
- 1.64762317199384860109595,
- 0.768238607022126250082483,
- 0.209793185936509782784315,
- 0.0319569316899913392596356,
- 0.00213363160895785378615014,
-];
-
-/// Polynomial coefficients for a numerator in `erf_impl`
-/// in the interval [1.25, 2.25].
-const ERF_IMPL_DN: &[f64] = &[
- -0.0300838560557949717328341,
- 0.0538578829844454508530552,
- 0.0726211541651914182692959,
- 0.0367628469888049348429018,
- 0.00964629015572527529605267,
- 0.00133453480075291076745275,
- 0.778087599782504251917881e-4,
-];
-
-/// Polynomial coefficients for a denominator in `erf_impl`
-/// in the interval [1.25, 2.25].
-const ERF_IMPL_DD: &[f64] = &[
- 1.0,
- 1.75967098147167528287343,
- 1.32883571437961120556307,
- 0.552528596508757581287907,
- 0.133793056941332861912279,
- 0.0179509645176280768640766,
- 0.00104712440019937356634038,
- -0.106640381820357337177643e-7,
-];
-
-/// Polynomial coefficients for a numerator in `erf_impl`
-/// in the interval [2.25, 3.5].
-const ERF_IMPL_EN: &[f64] = &[
- -0.0117907570137227847827732,
- 0.014262132090538809896674,
- 0.0202234435902960820020765,
- 0.00930668299990432009042239,
- 0.00213357802422065994322516,
- 0.00025022987386460102395382,
- 0.120534912219588189822126e-4,
-];
-
-/// Polynomial coefficients for a denominator in `erf_impl`
-/// in the interval [2.25, 3.5].
-const ERF_IMPL_ED: &[f64] = &[
- 1.0,
- 1.50376225203620482047419,
- 0.965397786204462896346934,
- 0.339265230476796681555511,
- 0.0689740649541569716897427,
- 0.00771060262491768307365526,
- 0.000371421101531069302990367,
-];
-
-/// Polynomial coefficients for a numerator in `erf_impl`
-/// in the interval [3.5, 5.25].
-const ERF_IMPL_FN: &[f64] = &[
- -0.00546954795538729307482955,
- 0.00404190278731707110245394,
- 0.0054963369553161170521356,
- 0.00212616472603945399437862,
- 0.000394984014495083900689956,
- 0.365565477064442377259271e-4,
- 0.135485897109932323253786e-5,
-];
-
-/// Polynomial coefficients for a denominator in `erf_impl`
-/// in the interval [3.5, 5.25].
-const ERF_IMPL_FD: &[f64] = &[
- 1.0,
- 1.21019697773630784832251,
- 0.620914668221143886601045,
- 0.173038430661142762569515,
- 0.0276550813773432047594539,
- 0.00240625974424309709745382,
- 0.891811817251336577241006e-4,
- -0.465528836283382684461025e-11,
-];
-
-/// Polynomial coefficients for a numerator in `erf_impl`
-/// in the interval [5.25, 8].
-const ERF_IMPL_GN: &[f64] = &[
- -0.00270722535905778347999196,
- 0.0013187563425029400461378,
- 0.00119925933261002333923989,
- 0.00027849619811344664248235,
- 0.267822988218331849989363e-4,
- 0.923043672315028197865066e-6,
-];
-
-/// Polynomial coefficients for a denominator in `erf_impl`
-/// in the interval [5.25, 8].
-const ERF_IMPL_GD: &[f64] = &[
- 1.0,
- 0.814632808543141591118279,
- 0.268901665856299542168425,
- 0.0449877216103041118694989,
- 0.00381759663320248459168994,
- 0.000131571897888596914350697,
- 0.404815359675764138445257e-11,
-];
-
-/// Polynomial coefficients for a numerator in `erf_impl`
-/// in the interval [8, 11.5].
-const ERF_IMPL_HN: &[f64] = &[
- -0.00109946720691742196814323,
- 0.000406425442750422675169153,
- 0.000274499489416900707787024,
- 0.465293770646659383436343e-4,
- 0.320955425395767463401993e-5,
- 0.778286018145020892261936e-7,
-];
-
-/// Polynomial coefficients for a denominator in `erf_impl`
-/// in the interval [8, 11.5].
-const ERF_IMPL_HD: &[f64] = &[
- 1.0,
- 0.588173710611846046373373,
- 0.139363331289409746077541,
- 0.0166329340417083678763028,
- 0.00100023921310234908642639,
- 0.24254837521587225125068e-4,
-];
-
-/// Polynomial coefficients for a numerator in `erf_impl`
-/// in the interval [11.5, 17].
-const ERF_IMPL_IN: &[f64] = &[
- -0.00056907993601094962855594,
- 0.000169498540373762264416984,
- 0.518472354581100890120501e-4,
- 0.382819312231928859704678e-5,
- 0.824989931281894431781794e-7,
-];
-
-/// Polynomial coefficients for a denominator in `erf_impl`
-/// in the interval [11.5, 17].
-const ERF_IMPL_ID: &[f64] = &[
- 1.0,
- 0.339637250051139347430323,
- 0.043472647870310663055044,
- 0.00248549335224637114641629,
- 0.535633305337152900549536e-4,
- -0.117490944405459578783846e-12,
-];
-
-/// Polynomial coefficients for a numerator in `erf_impl`
-/// in the interval [17, 24].
-const ERF_IMPL_JN: &[f64] = &[
- -0.000241313599483991337479091,
- 0.574224975202501512365975e-4,
- 0.115998962927383778460557e-4,
- 0.581762134402593739370875e-6,
- 0.853971555085673614607418e-8,
-];
-
-/// Polynomial coefficients for a denominator in `erf_impl`
-/// in the interval [17, 24].
-const ERF_IMPL_JD: &[f64] = &[
- 1.0,
- 0.233044138299687841018015,
- 0.0204186940546440312625597,
- 0.000797185647564398289151125,
- 0.117019281670172327758019e-4,
-];
-
-/// Polynomial coefficients for a numerator in `erf_impl`
-/// in the interval [24, 38].
-const ERF_IMPL_KN: &[f64] = &[
- -0.000146674699277760365803642,
- 0.162666552112280519955647e-4,
- 0.269116248509165239294897e-5,
- 0.979584479468091935086972e-7,
- 0.101994647625723465722285e-8,
-];
-
-/// Polynomial coefficients for a denominator in `erf_impl`
-/// in the interval [24, 38].
-const ERF_IMPL_KD: &[f64] = &[
- 1.0,
- 0.165907812944847226546036,
- 0.0103361716191505884359634,
- 0.000286593026373868366935721,
- 0.298401570840900340874568e-5,
-];
-
-/// Polynomial coefficients for a numerator in `erf_impl`
-/// in the interval [38, 60].
-const ERF_IMPL_LN: &[f64] = &[
- -0.583905797629771786720406e-4,
- 0.412510325105496173512992e-5,
- 0.431790922420250949096906e-6,
- 0.993365155590013193345569e-8,
- 0.653480510020104699270084e-10,
-];
-
-/// Polynomial coefficients for a denominator in `erf_impl`
-/// in the interval [38, 60].
-const ERF_IMPL_LD: &[f64] = &[
- 1.0,
- 0.105077086072039915406159,
- 0.00414278428675475620830226,
- 0.726338754644523769144108e-4,
- 0.477818471047398785369849e-6,
-];
-
-/// Polynomial coefficients for a numerator in `erf_impl`
-/// in the interval [60, 85].
-const ERF_IMPL_MN: &[f64] = &[
- -0.196457797609229579459841e-4,
- 0.157243887666800692441195e-5,
- 0.543902511192700878690335e-7,
- 0.317472492369117710852685e-9,
-];
-
-/// Polynomial coefficients for a denominator in `erf_impl`
-/// in the interval [60, 85].
-const ERF_IMPL_MD: &[f64] = &[
- 1.0,
- 0.052803989240957632204885,
- 0.000926876069151753290378112,
- 0.541011723226630257077328e-5,
- 0.535093845803642394908747e-15,
-];
-
-/// Polynomial coefficients for a numerator in `erf_impl`
-/// in the interval [85, 110].
-const ERF_IMPL_NN: &[f64] = &[
- -0.789224703978722689089794e-5,
- 0.622088451660986955124162e-6,
- 0.145728445676882396797184e-7,
- 0.603715505542715364529243e-10,
-];
-
-/// Polynomial coefficients for a denominator in `erf_impl`
-/// in the interval [85, 110].
-const ERF_IMPL_ND: &[f64] = &[
- 1.0,
- 0.0375328846356293715248719,
- 0.000467919535974625308126054,
- 0.193847039275845656900547e-5,
-];
-
// **********************************************************
// ********** Coefficients for erf_inv_impl polynomial ******
// **********************************************************
@@ -594,121 +271,6 @@ const ERF_INV_IMPL_GD: &[f64] = &[
0.231558608310259605225e-11,
];
-/// `erf_impl` computes the error function at `z`.
-/// If `inv` is true, `1 - erf` is calculated as opposed to `erf`
-fn erf_impl(z: f64, inv: bool) -> f64 {
- if z < 0.0 {
- if !inv {
- return -erf_impl(-z, false);
- }
- if z < -0.5 {
- return 2.0 - erf_impl(-z, true);
- }
- return 1.0 + erf_impl(-z, false);
- }
-
- let result = if z < 0.5 {
- if z < 1e-10 {
- z * 1.125 + z * 0.003379167095512573896158903121545171688
- } else {
- z * 1.125
- + z * evaluate::polynomial(z, ERF_IMPL_AN) / evaluate::polynomial(z, ERF_IMPL_AD)
- }
- } else if z < 110.0 {
- let (r, b) = if z < 0.75 {
- (
- evaluate::polynomial(z - 0.5, ERF_IMPL_BN)
- / evaluate::polynomial(z - 0.5, ERF_IMPL_BD),
- 0.3440242112,
- )
- } else if z < 1.25 {
- (
- evaluate::polynomial(z - 0.75, ERF_IMPL_CN)
- / evaluate::polynomial(z - 0.75, ERF_IMPL_CD),
- 0.419990927,
- )
- } else if z < 2.25 {
- (
- evaluate::polynomial(z - 1.25, ERF_IMPL_DN)
- / evaluate::polynomial(z - 1.25, ERF_IMPL_DD),
- 0.4898625016,
- )
- } else if z < 3.5 {
- (
- evaluate::polynomial(z - 2.25, ERF_IMPL_EN)
- / evaluate::polynomial(z - 2.25, ERF_IMPL_ED),
- 0.5317370892,
- )
- } else if z < 5.25 {
- (
- evaluate::polynomial(z - 3.5, ERF_IMPL_FN)
- / evaluate::polynomial(z - 3.5, ERF_IMPL_FD),
- 0.5489973426,
- )
- } else if z < 8.0 {
- (
- evaluate::polynomial(z - 5.25, ERF_IMPL_GN)
- / evaluate::polynomial(z - 5.25, ERF_IMPL_GD),
- 0.5571740866,
- )
- } else if z < 11.5 {
- (
- evaluate::polynomial(z - 8.0, ERF_IMPL_HN)
- / evaluate::polynomial(z - 8.0, ERF_IMPL_HD),
- 0.5609807968,
- )
- } else if z < 17.0 {
- (
- evaluate::polynomial(z - 11.5, ERF_IMPL_IN)
- / evaluate::polynomial(z - 11.5, ERF_IMPL_ID),
- 0.5626493692,
- )
- } else if z < 24.0 {
- (
- evaluate::polynomial(z - 17.0, ERF_IMPL_JN)
- / evaluate::polynomial(z - 17.0, ERF_IMPL_JD),
- 0.5634598136,
- )
- } else if z < 38.0 {
- (
- evaluate::polynomial(z - 24.0, ERF_IMPL_KN)
- / evaluate::polynomial(z - 24.0, ERF_IMPL_KD),
- 0.5638477802,
- )
- } else if z < 60.0 {
- (
- evaluate::polynomial(z - 38.0, ERF_IMPL_LN)
- / evaluate::polynomial(z - 38.0, ERF_IMPL_LD),
- 0.5640528202,
- )
- } else if z < 85.0 {
- (
- evaluate::polynomial(z - 60.0, ERF_IMPL_MN)
- / evaluate::polynomial(z - 60.0, ERF_IMPL_MD),
- 0.5641309023,
- )
- } else {
- (
- evaluate::polynomial(z - 85.0, ERF_IMPL_NN)
- / evaluate::polynomial(z - 85.0, ERF_IMPL_ND),
- 0.5641584396,
- )
- };
- let g = (-z * z).exp() / z;
- g * b + g * r
- } else {
- 0.0
- };
-
- if inv && z >= 0.5 {
- result
- } else if z >= 0.5 || inv {
- 1.0 - result
- } else {
- result
- }
-}
-
// `erf_inv_impl` computes the inverse error function where
// `p`,`q`, and `s` are the first, second, and third intermediate
// parameters respectively
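
With the hand-rolled `erf_impl` polynomial tables gone, `erf_f64`/`erf_f32`/`erfc_f64`/`erfc_f32` are thin wrappers over libm. A few sanity checks on the replacements (the reference value for erf(1.0) is approximate; the tolerances are arbitrary):

```rust
fn main() {
    // erf(0) == 0 and erf(1) ~= 0.84270079.
    assert!(libm::erf(0.0).abs() < 1e-12);
    assert!((libm::erf(1.0) - 0.842_700_79).abs() < 1e-6);

    // erf and erfc are complements: erf(x) + erfc(x) == 1.
    let x = 0.7_f64;
    assert!((libm::erf(x) + libm::erfc(x) - 1.0).abs() < 1e-12);

    // The f32 variants agree with the f64 ones to single precision.
    assert!((libm::erff(1.0_f32) - libm::erf(1.0) as f32).abs() < 1e-6);
}
```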
diff --git a/candle-core/src/cpu/kernels.rs b/candle-core/src/cpu/kernels.rs
index 527646d62b..bca76adcc8 100644
--- a/candle-core/src/cpu/kernels.rs
+++ b/candle-core/src/cpu/kernels.rs
@@ -121,6 +121,13 @@ impl VecOps for half::bf16 {
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
+
+ #[inline(always)]
+ unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
+ let mut res_f32 = 0f32;
+ super::vec_dot_bf16(lhs, rhs, &mut res_f32, len);
+ *res = half::bf16::from_f32(res_f32);
+ }
}
impl VecOps for u8 {
#[inline(always)]
@@ -144,6 +151,28 @@ impl VecOps for u32 {
::max(self, other)
}
}
+impl VecOps for i16 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+        <Self as Ord>::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+        <Self as Ord>::max(self, other)
+ }
+}
+impl VecOps for i32 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+        <Self as Ord>::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+        <Self as Ord>::max(self, other)
+ }
+}
impl VecOps for i64 {
#[inline(always)]
fn min(self, other: Self) -> Self {
@@ -156,6 +185,18 @@ impl VecOps for i64 {
}
}
+impl VecOps for float8::F8E4M3 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ Self::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ Self::max(self, other)
+ }
+}
+
#[inline(always)]
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
if n_threads == 1 {
diff --git a/candle-core/src/cpu/mod.rs b/candle-core/src/cpu/mod.rs
index e7d8b6906f..c4864b7a81 100644
--- a/candle-core/src/cpu/mod.rs
+++ b/candle-core/src/cpu/mod.rs
@@ -1,3 +1,5 @@
+//! Traits and methods for CPU-backed Tensors
+
pub mod erf;
pub mod kernels;
@@ -36,14 +38,33 @@ trait CpuF16 {
unsafe fn from_f32(v: f32) -> Self::Unit;
unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit);
}
-use half::f16;
+
+#[allow(unused)]
+trait CpuBF16<const ARR: usize> {
+ type Unit;
+ type Array;
+ const STEP: usize;
+ const EPR: usize;
+
+ fn n() -> usize;
+ unsafe fn zero() -> Self::Unit;
+ unsafe fn zero_array() -> Self::Array;
+ unsafe fn load(mem_addr: *const bf16) -> Self::Unit;
+ unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
+ unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
+ unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
+ unsafe fn from_f32(v: f32) -> Self::Unit;
+ unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit);
+}
+
+use half::{bf16, f16};
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
-#[cfg(target_feature = "avx")]
+#[cfg(target_feature = "avx2")]
pub mod avx;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
-#[cfg(target_feature = "avx")]
-pub use avx::{CurrentCpu, CurrentCpuF16};
+#[cfg(target_feature = "avx2")]
+pub use avx::{CurrentCpu, CurrentCpuBF16, CurrentCpuF16};
#[cfg(target_arch = "wasm32")]
#[cfg(target_feature = "simd128")]
@@ -61,7 +82,7 @@ pub use neon::CurrentCpu;
#[cfg(any(
target_feature = "neon",
- target_feature = "avx",
+ target_feature = "avx2",
target_feature = "simd128"
))]
#[inline(always)]
@@ -91,7 +112,7 @@ pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f
#[cfg(not(any(
target_feature = "neon",
- target_feature = "avx",
+ target_feature = "avx2",
target_feature = "simd128"
)))]
#[inline(always)]
@@ -104,7 +125,7 @@ pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f
#[cfg(any(
target_feature = "neon",
- target_feature = "avx",
+ target_feature = "avx2",
target_feature = "simd128"
))]
#[inline(always)]
@@ -131,7 +152,7 @@ pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) {
#[cfg(not(any(
target_feature = "neon",
- target_feature = "avx",
+ target_feature = "avx2",
target_feature = "simd128"
)))]
#[inline(always)]
@@ -142,7 +163,7 @@ pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) {
}
}
-#[cfg(target_feature = "avx")]
+#[cfg(target_feature = "avx2")]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
let mut sumf = 0.0f32;
@@ -170,7 +191,35 @@ pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f
*c = sumf;
}
-#[cfg(not(target_feature = "avx"))]
+#[cfg(target_feature = "avx2")]
+#[inline(always)]
+pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) {
+ let mut sumf = 0.0f32;
+ let np = k & !(CurrentCpuBF16::STEP - 1);
+
+ let mut sum = CurrentCpuBF16::zero_array();
+ let mut ax = CurrentCpuBF16::zero_array();
+ let mut ay = CurrentCpuBF16::zero_array();
+
+ for i in (0..np).step_by(CurrentCpuBF16::STEP) {
+ for j in 0..CurrentCpuBF16::n() {
+ ax[j] = CurrentCpuBF16::load(a_row.add(i + j * CurrentCpuBF16::EPR));
+ ay[j] = CurrentCpuBF16::load(b_row.add(i + j * CurrentCpuBF16::EPR));
+
+ sum[j] = CurrentCpuBF16::vec_fma(sum[j], ax[j], ay[j]);
+ }
+ }
+
+ CurrentCpuBF16::vec_reduce(sum, &mut sumf);
+
+ // leftovers
+ for i in np..k {
+ sumf += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
+ }
+ *c = sumf;
+}
+
+#[cfg(not(target_feature = "avx2"))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
// leftovers
@@ -180,3 +229,14 @@ pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f
}
*c = sum;
}
+
+#[cfg(not(target_feature = "avx2"))]
+#[inline(always)]
+pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) {
+ // leftovers
+ let mut sum = 0.0;
+ for i in 0..k {
+ sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
+ }
+ *c = sum;
+}
diff --git a/candle-core/src/cpu_backend/conv2d.rs b/candle-core/src/cpu_backend/conv2d.rs
new file mode 100644
index 0000000000..70c5ea75fa
--- /dev/null
+++ b/candle-core/src/cpu_backend/conv2d.rs
@@ -0,0 +1,432 @@
+use std::borrow::Cow;
+
+use rayon::iter::{IntoParallelIterator, ParallelIterator};
+
+use crate::{
+ conv::ParamsConv2D,
+ cpu_backend::{copy_strided_src_, Im2Col, Map1, Map2, MatMul},
+ shape::dims4,
+ Layout, Result, WithDType,
+};
+
+pub(super) struct Conv2D<'a>(pub(super) &'a crate::conv::ParamsConv2D);
+
+#[allow(dead_code)]
+enum Conv2dImpl {
+ TiledIm2Col,
+ FullIm2Col,
+ Direct,
+}
+
+const DEFAULT_CONV2D_IMPL: Conv2dImpl = Conv2dImpl::TiledIm2Col;
+
+impl Map2 for Conv2D<'_> {
+ const OP: &'static str = "conv2d";
+    fn f<T: WithDType>(
+ &self,
+ inp: &[T],
+ inp_l: &Layout,
+ k: &[T],
+ k_l: &Layout,
+    ) -> Result<Vec<T>> {
+ let p = self.0;
+
+ // Specialization: pick the best algorithm based on parameters.
+ // 1x1 convolutions with stride=1, padding=0, dilation=1
+ if p.k_h == 1 && p.k_w == 1 && p.stride == 1 && p.padding == 0 && p.dilation == 1 {
+ return conv2d_1x1(p, inp, inp_l, k, k_l);
+ } else if p.k_h == 1 && p.k_w == 1 {
+            // Other 1x1 convolutions are assumed to be faster with full im2col for now,
+            // although with a large enough input size the tiled variant starts to win.
+ return conv2d_im2col_gemm(p, inp, inp_l, k, k_l);
+ }
+ // TODO other cases
+
+ // No fast path, fallback to default general impl.
+ match DEFAULT_CONV2D_IMPL {
+ Conv2dImpl::TiledIm2Col => conv2d_tiled(p, inp, inp_l, k, k_l),
+ Conv2dImpl::Direct => conv2d_direct(p, inp, inp_l, k, k_l),
+ Conv2dImpl::FullIm2Col => conv2d_im2col_gemm(p, inp, inp_l, k, k_l),
+ }
+ }
+}
+
+/// Fast kernel for 1x1 convolutions with stride=1, padding=0, dilation=1
+/// These are just matrix multiplications: [c_out, c_in] @ [c_in, b*h*w] -> [c_out, b*h*w].
+fn conv2d_1x1<T: WithDType>(
+ p: &ParamsConv2D,
+ inp: &[T],
+ inp_l: &Layout,
+ k: &[T],
+ k_l: &Layout,
+) -> Result<Vec<T>> {
+ let inp = &inp[inp_l.start_offset()..];
+ let inp_stride = inp_l.stride();
+ let (inp_s0, inp_s1, inp_s2, inp_s3) =
+ (inp_stride[0], inp_stride[1], inp_stride[2], inp_stride[3]);
+ let k = &k[k_l.start_offset()..];
+ let k_stride = k_l.stride();
+ let (k_s0, k_s1) = (k_stride[0], k_stride[1]);
+ let (out_h, out_w) = (p.out_h(), p.out_w());
+
+ let spatial_size = out_h * out_w;
+ let dst = vec![T::zero(); p.b_size * p.c_out * spatial_size];
+ let k_reshaped: Cow<[T]> = if k_s0 == p.c_in && k_s1 == 1 {
+ // Already contiguous, use slice directly
+ Cow::Borrowed(&k[..p.c_out * p.c_in])
+ } else {
+ // Reshape kernel to [c_out, c_in]
+ let mut k_reshaped = Vec::with_capacity(p.c_out * p.c_in);
+ (0..p.c_out).for_each(|c_out_idx| {
+ (0..p.c_in).for_each(|c_in_idx| {
+ let k_idx = c_out_idx * k_s0 + c_in_idx * k_s1;
+ k_reshaped.push(k[k_idx]);
+ });
+ });
+ Cow::Owned(k_reshaped)
+ };
+ let k_layout = Layout::contiguous((p.c_out, p.c_in));
+
+ // Process each batch
+ (0..p.b_size).into_par_iter().try_for_each(|b_idx| {
+ // Reshape input to [c_in, h*w] for this batch
+ let mut inp_reshaped = Vec::with_capacity(p.c_in * spatial_size);
+ for c_in_idx in 0..p.c_in {
+ for h_idx in 0..p.i_h {
+ for w_idx in 0..p.i_w {
+ let inp_idx =
+ b_idx * inp_s0 + c_in_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
+ inp_reshaped.push(inp[inp_idx]);
+ }
+ }
+ }
+ let inp_layout = Layout::contiguous((p.c_in, spatial_size));
+
+ // Perform matmul: [c_out, c_in] @ [c_in, spatial_size] -> [c_out, spatial_size]
+ let matmul = MatMul((1, p.c_out, spatial_size, p.c_in));
+ let result = matmul.f(&k_reshaped, &k_layout, &inp_reshaped, &inp_layout)?;
+
+ // Copy result to output
+ let out_offset = b_idx * p.c_out * spatial_size;
+ for (i, r) in result.iter().enumerate() {
+ unsafe {
+ let ptr = dst.as_ptr().add(out_offset + i) as *mut T;
+ *ptr = *r;
+ }
+ }
+ Ok::<(), crate::Error>(())
+ })?;
+
+ Ok(dst)
+}
+
+/// General tiled convolution implementation using gemm.
+///
+/// Similar to full im2col, but instead of materializing the full matrix, we process input/output in tiles, in parallel.
+fn conv2d_tiled<T: WithDType>(
+ p: &ParamsConv2D,
+ inp: &[T],
+ inp_l: &Layout,
+ k: &[T],
+ k_l: &Layout,
+) -> Result<Vec<T>> {
+ let inp = &inp[inp_l.start_offset()..];
+ let (inp_s0, inp_s1, inp_s2, inp_s3) = dims4(inp_l.stride())?;
+ let k = &k[k_l.start_offset()..];
+ let (k_s0, k_s1, k_s2, k_s3) = dims4(k_l.stride())?;
+ let (out_h, out_w) = (p.out_h(), p.out_w());
+
+ // Output shape: [b_size, c_out, out_h, out_w].
+ let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
+
+ // Make contiguous input copy if needed.
+ let cont_s0 = p.i_h * p.i_w * p.c_in;
+ let cont_s1 = p.i_w * p.c_in;
+ let cont_s2 = p.c_in;
+ let layout_is_valid = inp_l.stride() == [cont_s0, cont_s1, cont_s2, 1];
+ let inp_cont: Cow<[T]> = if layout_is_valid {
+ Cow::Borrowed(inp)
+ } else {
+ let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
+ for b_idx in 0..p.b_size {
+ for h_idx in 0..p.i_h {
+ for w_idx in 0..p.i_w {
+ for c_idx in 0..p.c_in {
+ let src_idx =
+ b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
+ let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
+ inp_cont[dst_idx] = inp[src_idx]
+ }
+ }
+ }
+ }
+ Cow::Owned(inp_cont)
+ };
+
+ // shape of k: [c_out, c_in, k_h, k_w]
+ // strides of k: [k_s0, k_s1, k_s2, k_s3]
+ // For matmul, we need flattened k in shape [c_out, k_h * k_w * c_in]
+ // with stride [k_h * k_w * c_in, 1]
+ let k_size = p.c_in * p.k_h * p.k_w;
+ let mut k_flat = Vec::with_capacity(p.c_out * k_size);
+ for dst_c_idx in 0..p.c_out {
+ for kh in 0..p.k_h {
+ for kw in 0..p.k_w {
+ for c_in_idx in 0..p.c_in {
+ let k_idx = dst_c_idx * k_s0 + c_in_idx * k_s1 + kh * k_s2 + kw * k_s3;
+ k_flat.push(k[k_idx]);
+ }
+ }
+ }
+ }
+ // k_layout: [c_out, k_size] with stride [k_size, 1]
+ let k_layout = Layout::contiguous((p.c_out, k_size));
+
+ // TILE_SIZE is number of output pixels (out_h * out_w) per tile.
+ // Higher tile size can be faster due to better usage of gemm,
+ // but lower tile sizes enable bigger parallelism across tiles.
+ // This parameter is impactful and may be dynamic or even runtime tunable in the future.
+ const TILE_SIZE: usize = 512;
+
+ let total_out_pixels = out_h * out_w;
+
+ // Process batches and tiles in parallel using rayon.
+ (0..p.b_size).into_par_iter().try_for_each(|b_idx| {
+ let inp_offset = b_idx * cont_s0;
+ let out_batch_offset = b_idx * (p.c_out * out_h * out_w);
+
+ let num_tiles = total_out_pixels.div_ceil(TILE_SIZE);
+ (0..num_tiles).into_par_iter().try_for_each(|tile_idx| {
+            // Determine the actual tile size (it may be smaller for the last tile).
+ let tile_start = tile_idx * TILE_SIZE;
+ let tile_end = (tile_start + TILE_SIZE).min(total_out_pixels);
+ let tile_size = tile_end - tile_start;
+
+ // Precompute output coordinates.
+ // Used in both im2col extraction and writing output.
+ let out_coords: Vec<_> = (tile_start..tile_end)
+ .map(|idx| (idx / out_w, idx % out_w))
+ .collect();
+
+ // Build im2col tile: [k_size, tile_size]
+ // This represents the input patches needed for this tile of outputs
+ let mut col_tile = vec![T::zero(); k_size * tile_size];
+
+ for (tile_idx, (out_y, out_x)) in out_coords.iter().enumerate() {
+ // Extract the im2col patch for this output position
+ for c_in in 0..p.c_in {
+ let mut patch_offset = c_in;
+ for kh in 0..p.k_h {
+ let in_y =
+ (out_y * p.stride + kh * p.dilation) as isize - p.padding as isize;
+ if in_y < 0 || in_y >= p.i_h as isize {
+ // Padding: already zero
+ patch_offset += p.c_in * p.k_w;
+ continue;
+ }
+ for kw in 0..p.k_w {
+ let in_x =
+ (out_x * p.stride + kw * p.dilation) as isize - p.padding as isize;
+
+ if in_x >= 0 && in_x < p.i_w as isize {
+ let in_y = in_y as usize;
+ let in_x = in_x as usize;
+ let inp_idx = inp_offset + in_y * cont_s1 + in_x * cont_s2 + c_in;
+ let col_idx = patch_offset * tile_size + tile_idx;
+ col_tile[col_idx] = inp_cont[inp_idx];
+ }
+ // Move to next position (skip c_in channels)
+ patch_offset += p.c_in;
+ }
+ }
+ }
+ }
+
+ // Now perform matmul: k_cache [c_out, k_size] @ col_tile [k_size, tile_size]
+ let matmul = MatMul((1, p.c_out, tile_size, k_size));
+
+ // Layouts for matmul
+ // k_flat layout: [c_out, k_size] with stride [k_size, 1]
+ // col_tile layout: [k_size, tile_size] with stride [tile_size, 1]
+ let col_layout = Layout::contiguous((k_size, tile_size));
+
+ // Perform matmul
+ let result = matmul.f(&k_flat, &k_layout, &col_tile, &col_layout)?;
+
+ // Copy results to output: result is [c_out, tile_size]
+ for (tile_idx, (out_y, out_x)) in out_coords.iter().enumerate() {
+ let dst_base = out_batch_offset + out_y * out_w + out_x;
+
+ for c_out_idx in 0..p.c_out {
+ let dst_idx = dst_base + c_out_idx * (out_h * out_w);
+ let result_idx = c_out_idx * tile_size + tile_idx;
+ // SAFETY: Each batch processes a distinct region of the output buffer.
+ // Within each batch, tiles process non-overlapping output positions.
+ // Therefore, no two threads will write to the same dst_idx.
+ unsafe {
+ let ptr = dst.as_ptr().add(dst_idx) as *mut T;
+ *ptr = result[result_idx];
+ }
+ }
+ }
+ Ok::<(), crate::Error>(())
+ })
+ })?;
+
+ Ok(dst)
+}
+
+/// General direct convolution impl. Decently fast for small inputs and kernels, but loses to full/tiled gemm.
+fn conv2d_direct<T: WithDType>(
+ p: &ParamsConv2D,
+ inp: &[T],
+ inp_l: &Layout,
+ k: &[T],
+ k_l: &Layout,
+) -> Result<Vec<T>> {
+ let inp = &inp[inp_l.start_offset()..];
+ let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
+ let k = &k[k_l.start_offset()..];
+ let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
+ let (out_h, out_w) = (p.out_h(), p.out_w());
+
+ // Output shape: [b_size, c_out, out_h, out_w].
+ let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
+
+ // Make contiguous input copy if needed.
+ let cont_s0 = p.i_h * p.i_w * p.c_in;
+ let cont_s1 = p.i_w * p.c_in;
+ let cont_s2 = p.c_in;
+ let layout_is_valid = inp_l.stride() == [cont_s0, cont_s1, cont_s2, 1];
+ let inp_cont: Cow<[T]> = if layout_is_valid {
+ Cow::Borrowed(inp)
+ } else {
+ let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
+ for b_idx in 0..p.b_size {
+ for h_idx in 0..p.i_h {
+ for w_idx in 0..p.i_w {
+ for c_idx in 0..p.c_in {
+ let src_idx =
+ b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
+ let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
+ inp_cont[dst_idx] = inp[src_idx]
+ }
+ }
+ }
+ }
+ Cow::Owned(inp_cont)
+ };
+ let inp_cont_len = inp_cont.len();
+
+    let k_cache: Vec<Vec<T>> = (0..p.c_out)
+ .map(|dst_c_idx| {
+ (0..p.k_h * p.k_w)
+ .flat_map(|kw_kh| {
+ let offset_h = kw_kh / p.k_w;
+ let offset_w = kw_kh % p.k_w;
+ (0..p.c_in).map(move |c_in_idx| {
+ k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset_h * k_s2 + offset_w * k_s3]
+ })
+ })
+ .collect()
+ })
+ .collect();
+
+ for b_idx in 0..p.b_size {
+ for offset_h in 0..p.k_h {
+ for offset_w in 0..p.k_w {
+ let k_offset = offset_h * p.k_w + offset_w;
+
+ (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
+ let k_cont = &k_cache[dst_c_idx][k_offset * p.c_in..(k_offset + 1) * p.c_in];
+ let base_dst_idx = dst_c_idx * out_w * out_h;
+ let batch_dst_idx = base_dst_idx + b_idx * p.c_out * out_h * out_w;
+ let batch_src_idx = b_idx * cont_s0;
+
+ for dst_h in 0..out_h {
+ let src_h = p.stride * dst_h + offset_h * p.dilation;
+ if src_h < p.padding || src_h >= p.i_h + p.padding {
+ continue;
+ }
+ let src_h = src_h - p.padding;
+ let h_dst_idx = batch_dst_idx + dst_h * out_w;
+ let h_src_idx = batch_src_idx + src_h * cont_s1;
+
+ for dst_w in 0..out_w {
+ let src_w = p.stride * dst_w + offset_w * p.dilation;
+ if src_w < p.padding || src_w >= p.i_w + p.padding {
+ continue;
+ }
+ let src_w = src_w - p.padding;
+ let dst_idx = h_dst_idx + dst_w;
+ let inp_idx_1 = h_src_idx + src_w * cont_s2;
+ let inp_idx_2 = (inp_idx_1 + p.c_in).min(inp_cont_len);
+ let inp_cont = &inp_cont[inp_idx_1..inp_idx_2];
+ let mut d = T::zero();
+ unsafe {
+ T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in);
+ let ptr = dst.as_ptr().add(dst_idx) as *mut T;
+ *ptr += d;
+ }
+ }
+ }
+ });
+ }
+ }
+ }
+
+ Ok(dst)
+}
+
+#[allow(clippy::uninit_vec)]
+fn alloc_uninit_vec<T>(size: usize) -> Vec<T> {
+ let mut v = Vec::with_capacity(size);
+ unsafe { v.set_len(size) };
+ v
+}
+
+/// Full im2col + gemm convolution implementation.
+///
+/// For large inputs, building the full im2col matrix and the copy_strided_src of the output get expensive.
+fn conv2d_im2col_gemm<T: WithDType>(
+ p: &ParamsConv2D,
+ inp: &[T],
+ inp_l: &Layout,
+ kernel: &[T],
+ kernel_l: &Layout,
+) -> Result<Vec<T>> {
+ let op = Im2Col {
+ h_k: p.k_h,
+ w_k: p.k_w,
+ padding: p.padding,
+ stride: p.stride,
+ dilation: p.dilation,
+ };
+ let col = op.f(inp, inp_l)?;
+ let b = p.b_size;
+ let n = p.c_out;
+ let (h_out, w_out) = (p.out_h(), p.out_w());
+ let k = op.h_k * op.w_k * p.c_in;
+ let m = h_out * w_out;
+ let col_l = Layout::contiguous((b, m, k));
+    let res: Vec<T> = if kernel_l.is_contiguous() {
+ let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ MatMul((b, m, n, k)).f(&col, &col_l, kernel, &kernel_l)?
+ } else {
+ // Make the kernel contiguous if not already the case.
+ let mut kernel_c = alloc_uninit_vec(kernel_l.shape().elem_count());
+ copy_strided_src_(kernel, &mut kernel_c, 0, kernel_l);
+ let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ MatMul((b, m, n, k)).f(&col, &col_l, &kernel_c, &kernel_l)?
+ };
+ let res_l = Layout::contiguous((b, h_out, w_out, p.c_out))
+ .transpose(1, 2)?
+ .transpose(1, 3)?;
+ let mut res_t = alloc_uninit_vec(res_l.shape().elem_count());
+ copy_strided_src_(&res, &mut res_t, 0, &res_l);
+ Ok(res_t)
+}
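
For reference, the bookkeeping behind the tiled path above: the standard convolution output-size formula and the per-tile GEMM shape. A small sketch with illustrative numbers (only `TILE_SIZE` matches a constant from the file):

```rust
// Standard output size for a convolution along one dimension.
fn conv_out_dim(i: usize, k: usize, padding: usize, stride: usize, dilation: usize) -> usize {
    (i + 2 * padding - dilation * (k - 1) - 1) / stride + 1
}

fn main() {
    const TILE_SIZE: usize = 512;
    let (i_h, i_w, k_h, k_w, c_in, c_out) = (64, 64, 3, 3, 16, 32);
    let (padding, stride, dilation) = (1, 1, 1);

    let out_h = conv_out_dim(i_h, k_h, padding, stride, dilation); // 64
    let out_w = conv_out_dim(i_w, k_w, padding, stride, dilation); // 64
    let total_out_pixels = out_h * out_w; // 4096 output positions per image
    let num_tiles = total_out_pixels.div_ceil(TILE_SIZE); // 8 tiles per image

    // Each tile runs one [c_out, k_size] x [k_size, tile_size] matmul.
    let k_size = c_in * k_h * k_w; // 144
    assert_eq!((out_h, out_w, num_tiles, k_size), (64, 64, 8, 144));
    println!("{c_out} x {k_size} @ {k_size} x {TILE_SIZE} per tile, {num_tiles} tiles");
}
```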
diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs
index 58773c8020..afb93024ac 100644
--- a/candle-core/src/cpu_backend/mod.rs
+++ b/candle-core/src/cpu_backend/mod.rs
@@ -1,17 +1,20 @@
+//! Implementation of Backend Fns for CPU
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
+use float8::F8E4M3;
use half::{bf16, f16};
use rayon::prelude::*;
mod utils;
pub use utils::{
- binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8,
+ binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2InPlace, Map2U8,
};
+mod conv2d;
+use conv2d::Conv2D;
const USE_IM2COL_CONV1D: bool = true;
const USE_COL2IM_CONV1D_TR: bool = true;
-const USE_IM2COL_CONV2D: bool = true;
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
// intercept the oom errors to avoid panicking and provide a proper error.
@@ -19,22 +22,38 @@ const USE_IM2COL_CONV2D: bool = true;
pub enum CpuStorage {
U8(Vec),
U32(Vec),
+    I16(Vec<i16>),
+    I32(Vec<i32>),
I64(Vec),
BF16(Vec),
F16(Vec),
F32(Vec),
F64(Vec),
+    F8E4M3(Vec<F8E4M3>),
+    // Dummy types that store raw bytes
+    F6E2M3(Vec<u8>),
+    F6E3M2(Vec<u8>),
+    F4(Vec<u8>),
+    F8E8M0(Vec<u8>),
}
#[derive(Debug, Clone)]
pub enum CpuStorageRef<'a> {
U8(&'a [u8]),
U32(&'a [u32]),
+ I16(&'a [i16]),
+ I32(&'a [i32]),
I64(&'a [i64]),
BF16(&'a [bf16]),
F16(&'a [f16]),
F32(&'a [f32]),
F64(&'a [f64]),
+ F8E4M3(&'a [F8E4M3]),
+ // Dummy types that store raw bytes
+ F6E2M3(&'a [u8]),
+ F6E3M2(&'a [u8]),
+ F4(&'a [u8]),
+ F8E8M0(&'a [u8]),
}
#[derive(Debug, Clone)]
@@ -65,7 +84,7 @@ impl Map2U8 for Cmp {
struct WCond<'a, T: IntDType>(&'a [T], &'a Layout);
-impl<'a, I: IntDType> Map2 for WCond<'a, I> {
+impl<I: IntDType> Map2 for WCond<'_, I> {
const OP: &'static str = "where";
#[inline(always)]
fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
@@ -215,7 +234,7 @@ struct ReduceSum<'a> {
reduce_dims_and_stride: Vec<(usize, usize)>,
}
-impl<'a> ReduceSum<'a> {
+impl ReduceSum<'_> {
#[inline(always)]
fn fold_impl<T>(&self, src: &[T], src_l: &Layout, start_elt: T) -> Result<Vec<T>>
where
@@ -280,7 +299,7 @@ impl<'a> ReduceSum<'a> {
}
}
-impl<'a> Map1 for ReduceSum<'a> {
+impl Map1 for ReduceSum<'_> {
#[inline(always)]
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
self.fold_impl(src, src_l, T::zero())
@@ -447,13 +466,132 @@ impl Map1 for UpsampleNearest2D {
}
}
+struct UpsampleBilinear2D {
+ target_h: usize,
+ target_w: usize,
+ align_corners: bool,
+ scale_h_factor: Option<f64>,
+ scale_w_factor: Option<f64>,
+}
+
+impl Map1 for UpsampleBilinear2D {
+ fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+ let (batch, channels, height_in, width_in) = layout.shape().dims4()?;
+ let height_out = self.target_h;
+ let width_out = self.target_w;
+
+ // Early return for identity case
+ if height_in == height_out && width_in == width_out {
+ return Ok(src.to_vec());
+ }
+
+ let stride = layout.stride();
+ let src_offset = layout.start_offset();
+
+ // Calculate scale factors following PyTorch's area_pixel_compute_scale logic
+ let scale_h = if self.align_corners {
+ if height_out > 1 {
+ (height_in - 1) as f64 / (height_out - 1) as f64
+ } else {
+ 0.0
+ }
+ } else {
+ // PyTorch's compute_scales_value logic:
+ // If scale_factor was provided, use 1.0 / scale_factor
+ // Otherwise, use input_size / output_size
+ if let Some(scale_factor) = self.scale_h_factor {
+ 1.0 / scale_factor
+ } else {
+ height_in as f64 / height_out as f64
+ }
+ };
+
+ let scale_w = if self.align_corners {
+ if width_out > 1 {
+ (width_in - 1) as f64 / (width_out - 1) as f64
+ } else {
+ 0.0
+ }
+ } else if let Some(scale_factor) = self.scale_w_factor {
+ 1.0 / scale_factor
+ } else {
+ width_in as f64 / width_out as f64
+ };
+
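+ // With align_corners=false the source coordinate is src = scale * (dst + 0.5) - 0.5, clamped to 0;
+ // with align_corners=true it is simply scale * dst.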
+ // Precompute indices and weights for height
+ let mut h_indices = Vec::with_capacity(height_out);
+ for h_out in 0..height_out {
+ let src_h = if self.align_corners {
+ scale_h * h_out as f64
+ } else {
+ scale_h * (h_out as f64 + 0.5) - 0.5
+ };
+ let src_h_clamped = src_h.max(0.0);
+ let h0 = src_h_clamped.floor() as usize;
+ let h1 = (h0 + 1).min(height_in - 1);
+ let weight_h = (src_h_clamped - h0 as f64).clamp(0.0, 1.0);
+ h_indices.push((h0, h1, weight_h));
+ }
+
+ // Precompute indices and weights for width
+ let mut w_indices = Vec::with_capacity(width_out);
+ for w_out in 0..width_out {
+ let src_w = if self.align_corners {
+ scale_w * w_out as f64
+ } else {
+ scale_w * (w_out as f64 + 0.5) - 0.5
+ };
+ let src_w_clamped = src_w.max(0.0);
+ let w0 = src_w_clamped.floor() as usize;
+ let w1 = (w0 + 1).min(width_in - 1);
+ let weight_w = (src_w_clamped - w0 as f64).clamp(0.0, 1.0);
+ w_indices.push((w0, w1, weight_w));
+ }
+
+ // Allocate output
+ let mut dst = vec![T::zero(); batch * channels * height_out * width_out];
+
+ // Perform bilinear interpolation
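+ // out = (1 - wh) * ((1 - ww) * v00 + ww * v10) + wh * ((1 - ww) * v01 + ww * v11)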
+ for b in 0..batch {
+ for c in 0..channels {
+ let base_idx = src_offset + b * stride[0] + c * stride[1];
+ let dst_base = (b * channels + c) * height_out * width_out;
+
+ for (h_out, &(h0, h1, weight_h)) in h_indices.iter().enumerate() {
+ for (w_out, &(w0, w1, weight_w)) in w_indices.iter().enumerate() {
+ // Get four neighboring pixels
+ let idx_00 = base_idx + h0 * stride[2] + w0 * stride[3];
+ let idx_10 = base_idx + h0 * stride[2] + w1 * stride[3];
+ let idx_01 = base_idx + h1 * stride[2] + w0 * stride[3];
+ let idx_11 = base_idx + h1 * stride[2] + w1 * stride[3];
+
+ let v00 = src[idx_00].to_f64();
+ let v10 = src[idx_10].to_f64();
+ let v01 = src[idx_01].to_f64();
+ let v11 = src[idx_11].to_f64();
+
+ // Bilinear interpolation
+ let v_top = v00 * (1.0 - weight_w) + v10 * weight_w;
+ let v_bottom = v01 * (1.0 - weight_w) + v11 * weight_w;
+ let value = v_top * (1.0 - weight_h) + v_bottom * weight_h;
+
+ dst[dst_base + h_out * width_out + w_out] = T::from_f64(value);
+ }
+ }
+ }
+ }
+
+ Ok(dst)
+ }
+}
+
struct Gather<'a, I: IntDType> {
ids: &'a [I],
ids_l: &'a Layout,
dim: usize,
}
-impl<'a, I: IntDType> Map1 for Gather<'a, I> {
+impl<I: IntDType> Map1 for Gather<'_, I> {
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
let ids = match self.ids_l.contiguous_offsets() {
Some((a, b)) => &self.ids[a..b],
@@ -482,17 +620,22 @@ impl<'a, I: IntDType> Map1 for Gather<'a, I> {
let start_dst_idx = start_dst_idx + i * dst_right_len;
for right_i in 0..dst_right_len {
let dst_idx = start_dst_idx + right_i;
- let index = ids[dst_idx].as_usize();
- if index >= src_dim_len {
- Err(Error::InvalidIndex {
- index,
- size: src_dim_len,
- op: "gather",
+ let index = ids[dst_idx];
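+ // The max value of the index dtype is used as a sentinel for "no index": write zero instead of gathering.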
+ if index == I::max_value() {
+ dst[dst_idx] = T::zero();
+ } else {
+ let index = index.as_usize();
+ if index >= src_dim_len {
+ Err(Error::InvalidIndex {
+ index,
+ size: src_dim_len,
+ op: "gather",
+ }
+ .bt())?
}
- .bt())?
+ let src_idx = start_src_idx + index * src_right_len + right_i;
+ dst[dst_idx] = src[src_idx]
}
- let src_idx = start_src_idx + index * src_right_len + right_i;
- dst[dst_idx] = src[src_idx]
}
}
}
@@ -506,7 +649,7 @@ struct IndexSelect<'a, T: IntDType> {
dim: usize,
}
-impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
+impl<I: IntDType> Map1 for IndexSelect<'_, I> {
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
let src = match layout.contiguous_offsets() {
Some((a, b)) => &src[a..b],
@@ -534,45 +677,89 @@ impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
let start_src_idx = left_i * right_len * src_dim;
let start_dst_idx = left_i * right_len * n_ids;
for i in 0..n_ids {
- let index = self.ids[self.ids_l.start_offset() + stride_ids * i].as_usize();
- if index >= src_dim {
- Err(Error::InvalidIndex {
- index,
- size: src_dim,
- op: "index-select",
+ let start_dst_idx = start_dst_idx + i * right_len;
+ let index = self.ids[self.ids_l.start_offset() + stride_ids * i];
+ if index == I::max_value() {
+ dst[start_dst_idx..start_dst_idx + right_len].fill(T::zero());
+ } else {
+ let index = index.as_usize();
+ if index >= src_dim {
+ Err(Error::InvalidIndex {
+ index,
+ size: src_dim,
+ op: "index-select",
+ }
+ .bt())?
}
- .bt())?
+ let start_src_idx = start_src_idx + index * right_len;
+ dst[start_dst_idx..start_dst_idx + right_len]
+ .copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
}
- let start_src_idx = start_src_idx + index * right_len;
- let start_dst_idx = start_dst_idx + i * right_len;
- dst[start_dst_idx..start_dst_idx + right_len]
- .copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
}
}
Ok(dst)
}
}
-struct ScatterAdd<'a, I: IntDType> {
+trait ElemUpdate {
+ fn f<T: WithDType>(dst: &mut T, src: T);
+}
+
+struct Set;
+struct Add;
+
+impl ElemUpdate for Set {
+ fn f<T: WithDType>(dst: &mut T, src: T) {
+ *dst = src
+ }
+}
+
+impl ElemUpdate for Add {
+ fn f<T: WithDType>(dst: &mut T, src: T) {
+ *dst += src
+ }
+}
+
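+// Generic scatter kernel: `Scatter<_, Set>` overwrites the destination (scatter) while `Scatter<_, Add>`
+// accumulates into it (scatter-add); see scatter_set / scatter_add_set below.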
+struct Scatter<'a, I: IntDType, M: ElemUpdate> {
ids: &'a [I],
ids_l: &'a Layout,
dim: usize,
+ _phantom: std::marker::PhantomData<M>,
}
-impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
- const OP: &'static str = "scatter-add";
- fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
- let dst_len = l1.shape().elem_count();
- let mut dst = vec![T::zero(); dst_len];
- copy_strided_src_(v1, &mut dst, 0, l1);
+impl<'a, I: IntDType, M: ElemUpdate> Scatter<'a, I, M> {
+ fn new(ids: &'a [I], ids_l: &'a Layout, dim: usize) -> Self {
+ Self {
+ ids,
+ ids_l,
+ dim,
+ _phantom: Default::default(),
+ }
+ }
+}
+
+impl<I: IntDType, M: ElemUpdate> Map2InPlace for Scatter<'_, I, M> {
+ const OP: &'static str = "scatter";
+ fn f<T: WithDType>(
+ &self,
+ dst: &mut [T],
+ dst_l: &Layout,
+ src: &[T],
+ src_l: &Layout,
+ ) -> Result<()> {
+ let dst = match dst_l.contiguous_offsets() {
+ None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
+ Some((o1, o2)) => &mut dst[o1..o2],
+ };
+
let src = match src_l.contiguous_offsets() {
- None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
+ None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
Some((o1, o2)) => &src[o1..o2],
};
let dim = self.dim;
let ids_dims = self.ids_l.dims();
- let dst_dims = l1.dims();
+ let dst_dims = dst_l.dims();
let dst_dim_len = dst_dims[dim];
let dst_right_len: usize = dst_dims[dim + 1..].iter().product();
@@ -591,7 +778,11 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
let start_ids_idx = start_ids_idx + i * ids_right_len;
for right_i in 0..dst_right_len {
let ids_idx = start_ids_idx + right_i;
- let index = ids[ids_idx].as_usize();
+ let index = ids[ids_idx];
+ if index == I::max_value() {
+ continue;
+ }
+ let index = index.as_usize();
if index >= dst_dim_len {
Err(Error::InvalidIndex {
index,
@@ -601,12 +792,12 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
.bt())?
}
let dst_idx = start_dst_idx + index * dst_right_len + right_i;
- dst[dst_idx] += src[ids_idx]
+ M::f(&mut dst[dst_idx], src[ids_idx])
}
}
}
- Ok(dst)
+ Ok(())
}
}
@@ -615,7 +806,7 @@ struct IndexAdd<'a, I: IntDType> {
dim: usize,
}
-impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
+impl<I: IntDType> Map2 for IndexAdd<'_, I> {
const OP: &'static str = "index-add";
// https://pytorch.org/docs/stable/generated/torch.Tensor.index_add_.html#torch.Tensor.index_add_
// v1, l1 -> self
@@ -634,6 +825,9 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
if dim == 0 {
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
+ if *dst_idx == I::max_value() {
+ continue;
+ }
let dst_idx = dst_idx.as_usize();
if dst_idx >= max_idx {
Err(Error::InvalidIndex {
@@ -652,6 +846,9 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
}
} else {
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
+ if *dst_idx == I::max_value() {
+ continue;
+ }
let dst_idx = dst_idx.as_usize();
if dst_idx >= max_idx {
Err(Error::InvalidIndex {
@@ -735,7 +932,7 @@ fn copy_strided_src_(src: &[T], dst: &mut [T], dst_offset: usize, src_l
struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
-impl<'a> Map2 for Conv1D<'a> {
+impl Map2 for Conv1D<'_> {
const OP: &'static str = "conv1d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
@@ -959,7 +1156,7 @@ impl Map1 for Col2Im1D {
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
-impl<'a> Map2 for ConvTranspose1D<'a> {
+impl Map2 for ConvTranspose1D<'_> {
const OP: &'static str = "conv_transpose1d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
@@ -1026,97 +1223,9 @@ impl<'a> Map2 for ConvTranspose1D<'a> {
}
}
-struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
-
-impl<'a> Map2 for Conv2D<'a> {
- const OP: &'static str = "conv2d";
- fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
- let p = self.0;
- let inp = &inp[inp_l.start_offset()..];
- let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
- let k = &k[k_l.start_offset()..];
- let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
- let (out_h, out_w) = (p.out_h(), p.out_w());
-
- // Output shape: [b_size, c_out, out_h, out_w].
- let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
-
- // TODO: Avoid making this copy if `inp` already has the appropriate layout.
- let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
- let cont_s0 = p.i_h * p.i_w * p.c_in;
- let cont_s1 = p.i_w * p.c_in;
- let cont_s2 = p.c_in;
- for b_idx in 0..p.b_size {
- for h_idx in 0..p.i_h {
- for w_idx in 0..p.i_w {
- for c_idx in 0..p.c_in {
- let src_idx =
- b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
- let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
- inp_cont[dst_idx] = inp[src_idx]
- }
- }
- }
- }
-
- for offset_h in 0..p.k_h {
- for offset_w in 0..p.k_w {
- (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
- let dst_idx = dst_c_idx * out_w * out_h;
- let k_cont = (0..p.c_in)
- .map(|c_in_idx| {
- k[dst_c_idx * k_s0
- + c_in_idx * k_s1
- + offset_h * k_s2
- + offset_w * k_s3]
- })
- .collect::<Vec<T>>();
- for b_idx in 0..p.b_size {
- let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
- for dst_h in 0..out_h {
- let dst_idx = dst_idx + dst_h * out_w;
- let src_h = p.stride * dst_h + offset_h * p.dilation;
- if src_h < p.padding || src_h >= p.i_h + p.padding {
- continue;
- }
- let src_h = src_h - p.padding;
- for dst_w in 0..out_w {
- let dst_idx = dst_idx + dst_w;
- let src_w = p.stride * dst_w + offset_w * p.dilation;
- if src_w < p.padding || src_w >= p.i_w + p.padding {
- continue;
- }
- let src_w = src_w - p.padding;
- let inp_cont = &inp_cont
- [b_idx * cont_s0 + src_h * cont_s1 + src_w * cont_s2..];
- assert!(inp_cont.len() >= p.c_in);
- assert!(k_cont.len() >= p.c_in);
- let mut d = T::zero();
- unsafe {
- T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
- }
- let dst_p = dst.as_ptr();
- // Safety: dst_idx are uniques per dst_c_idx which is used to parallelise
- // the different tasks so no two threads can try to write at the same
- // location.
- unsafe {
- let ptr = dst_p.add(dst_idx) as *mut T;
- *ptr += d
- }
- }
- }
- }
- });
- }
- }
-
- Ok(dst)
- }
-}
-
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
-impl<'a> Map2 for ConvTranspose2D<'a> {
+impl Map2 for ConvTranspose2D<'_> {
const OP: &'static str = "conv_transpose2d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
@@ -1288,6 +1397,15 @@ impl Map2 for MatMul {
} else {
Parallelism::None
};
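+ // When one operand is broadcast over the batch dimension (stride 0) and the other is
+ // batch-contiguous, fold the batch into a single larger matmul.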
+ let (b, m, n, k) = if b_skip == 0 && a_skip == m * k {
+ // a_skip and c_skip should also be updated, but the collapsed batch size is 1 so the
+ // loop below only runs with step == 0 and the skips are never used.
+ (1, b * m, n, k)
+ } else if a_skip == 0 && b_skip == n * k {
+ (1, m, b * n, k)
+ } else {
+ (b, m, n, k)
+ };
for step in 0..b {
let lhs_p = &lhs[step * a_skip..];
let rhs_p = &rhs[step * b_skip..];
@@ -1567,6 +1685,28 @@ impl CpuStorage {
.concat();
Self::U32(storages)
}
+ Self::I16(_) => {
+ let storages = storages
+ .iter()
+ .map(|s| match s {
+ Self::I16(s) => Ok(s.as_slice()),
+ _ => crate::bail!("dtype mismatch"),
+ })
+ .collect::<Result<Vec<_>>>()?
+ .concat();
+ Self::I16(storages)
+ }
+ Self::I32(_) => {
+ let storages = storages
+ .iter()
+ .map(|s| match s {
+ Self::I32(s) => Ok(s.as_slice()),
+ _ => crate::bail!("dtype mismatch"),
+ })
+ .collect::<Result<Vec<_>>>()?
+ .concat();
+ Self::I32(storages)
+ }
Self::I64(_) => {
let storages = storages
.iter()
@@ -1622,6 +1762,61 @@ impl CpuStorage {
.concat();
Self::F64(storages)
}
+ Self::F8E4M3(_) => {
+ let storages = storages
+ .iter()
+ .map(|s| match s {
+ Self::F8E4M3(s) => Ok(s.as_slice()),
+ _ => crate::bail!("dtype mismatch"),
+ })
+ .collect::<Result<Vec<_>>>()?
+ .concat();
+ Self::F8E4M3(storages)
+ }
+ Self::F6E2M3(_) => {
+ let storages = storages
+ .iter()
+ .map(|s| match s {
+ Self::F6E2M3(s) => Ok(s.as_slice()),
+ _ => crate::bail!("dtype mismatch"),
+ })
+ .collect::<Result<Vec<_>>>()?
+ .concat();
+ Self::F6E2M3(storages)
+ }
+ Self::F6E3M2(_) => {
+ let storages = storages
+ .iter()
+ .map(|s| match s {
+ Self::F6E3M2(s) => Ok(s.as_slice()),
+ _ => crate::bail!("dtype mismatch"),
+ })
+ .collect::<Result<Vec<_>>>()?
+ .concat();
+ Self::F6E3M2(storages)
+ }
+ Self::F4(_) => {
+ let storages = storages
+ .iter()
+ .map(|s| match s {
+ Self::F4(s) => Ok(s.as_slice()),
+ _ => crate::bail!("dtype mismatch"),
+ })
+ .collect::<Result<Vec<_>>>()?
+ .concat();
+ Self::F4(storages)
+ }
+ Self::F8E8M0(_) => {
+ let storages = storages
+ .iter()
+ .map(|s| match s {
+ Self::F8E8M0(s) => Ok(s.as_slice()),
+ _ => crate::bail!("dtype mismatch"),
+ })
+ .collect::<Result<Vec<_>>>()?
+ .concat();
+ Self::F8E8M0(storages)
+ }
};
Ok(s)
}
@@ -1634,11 +1829,18 @@ impl BackendStorage for CpuStorage {
match self {
Self::U8(_) => DType::U8,
Self::U32(_) => DType::U32,
+ Self::I16(_) => DType::I16,
+ Self::I32(_) => DType::I32,
Self::I64(_) => DType::I64,
Self::BF16(_) => DType::BF16,
Self::F16(_) => DType::F16,
Self::F32(_) => DType::F32,
Self::F64(_) => DType::F64,
+ Self::F8E4M3(_) => DType::F8E4M3,
+ Self::F6E2M3(_) => DType::F6E2M3,
+ Self::F6E3M2(_) => DType::F6E3M2,
+ Self::F4(_) => DType::F4,
+ Self::F8E8M0(_) => DType::F8E8M0,
}
}
@@ -1841,6 +2043,226 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v);
Ok(Self::F64(data))
}
+ // Conversions to F8E4M3
+ (Self::U8(storage), DType::F8E4M3) => {
+ let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
+ Ok(Self::F8E4M3(data))
+ }
+ (Self::U32(storage), DType::F8E4M3) => {
+ let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
+ Ok(Self::F8E4M3(data))
+ }
+ (Self::I64(storage), DType::F8E4M3) => {
+ let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
+ Ok(Self::F8E4M3(data))
+ }
+ (Self::BF16(storage), DType::F8E4M3) => {
+ let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32()));
+ Ok(Self::F8E4M3(data))
+ }
+ (Self::F16(storage), DType::F8E4M3) => {
+ let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32()));
+ Ok(Self::F8E4M3(data))
+ }
+ (Self::F32(storage), DType::F8E4M3) => {
+ let data = unary_map(storage, layout, F8E4M3::from_f32);
+ Ok(Self::F8E4M3(data))
+ }
+ (Self::F64(storage), DType::F8E4M3) => {
+ let data = unary_map(storage, layout, F8E4M3::from_f64);
+ Ok(Self::F8E4M3(data))
+ }
+ (Self::F8E4M3(storage), DType::F8E4M3) => {
+ let data = unary_map(storage, layout, |v| v);
+ Ok(Self::F8E4M3(data))
+ }
+ // Conversions from F8E4M3
+ (Self::F8E4M3(storage), DType::U8) => {
+ let data = unary_map(storage, layout, |v| v.to_f32() as u8);
+ Ok(Self::U8(data))
+ }
+ (Self::F8E4M3(storage), DType::U32) => {
+ let data = unary_map(storage, layout, |v| v.to_f32() as u32);
+ Ok(Self::U32(data))
+ }
+ (Self::F8E4M3(storage), DType::I64) => {
+ let data = unary_map(storage, layout, |v| v.to_f32() as i64);
+ Ok(Self::I64(data))
+ }
+ (Self::F8E4M3(storage), DType::BF16) => {
+ let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32()));
+ Ok(Self::BF16(data))
+ }
+ (Self::F8E4M3(storage), DType::F16) => {
+ let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
+ Ok(Self::F16(data))
+ }
+ (Self::F8E4M3(storage), DType::F32) => {
+ let data = unary_map(storage, layout, |v| v.to_f32());
+ Ok(Self::F32(data))
+ }
+ (Self::F8E4M3(storage), DType::F64) => {
+ let data = unary_map(storage, layout, |v| v.to_f64());
+ Ok(Self::F64(data))
+ }
+ // Conversions to I16
+ (Self::U8(storage), DType::I16) => {
+ let data = unary_map(storage, layout, |v| v as i16);
+ Ok(Self::I16(data))
+ }
+ (Self::U32(storage), DType::I16) => {
+ let data = unary_map(storage, layout, |v| v as i16);
+ Ok(Self::I16(data))
+ }
+ (Self::I16(storage), DType::I16) => {
+ let data = unary_map(storage, layout, |v| v);
+ Ok(Self::I16(data))
+ }
+ (Self::I32(storage), DType::I16) => {
+ let data = unary_map(storage, layout, |v| v as i16);
+ Ok(Self::I16(data))
+ }
+ (Self::I64(storage), DType::I16) => {
+ let data = unary_map(storage, layout, |v| v as i16);
+ Ok(Self::I16(data))
+ }
+ (Self::BF16(storage), DType::I16) => {
+ let data = unary_map(storage, layout, |v| v.to_f32() as i16);
+ Ok(Self::I16(data))
+ }
+ (Self::F16(storage), DType::I16) => {
+ let data = unary_map(storage, layout, |v| v.to_f32() as i16);
+ Ok(Self::I16(data))
+ }
+ (Self::F32(storage), DType::I16) => {
+ let data = unary_map(storage, layout, |v| v as i16);
+ Ok(Self::I16(data))
+ }
+ (Self::F64(storage), DType::I16) => {
+ let data = unary_map(storage, layout, |v| v as i16);
+ Ok(Self::I16(data))
+ }
+ (Self::F8E4M3(storage), DType::I16) => {
+ let data = unary_map(storage, layout, |v| v.to_f32() as i16);
+ Ok(Self::I16(data))
+ }
+ // Conversions to I32
+ (Self::U8(storage), DType::I32) => {
+ let data = unary_map(storage, layout, |v| v as i32);
+ Ok(Self::I32(data))
+ }
+ (Self::U32(storage), DType::I32) => {
+ let data = unary_map(storage, layout, |v| v as i32);
+ Ok(Self::I32(data))
+ }
+ (Self::I16(storage), DType::I32) => {
+ let data = unary_map(storage, layout, |v| v as i32);
+ Ok(Self::I32(data))
+ }
+ (Self::I32(storage), DType::I32) => {
+ let data = unary_map(storage, layout, |v| v);
+ Ok(Self::I32(data))
+ }
+ (Self::I64(storage), DType::I32) => {
+ let data = unary_map(storage, layout, |v| v as i32);
+ Ok(Self::I32(data))
+ }
+ (Self::BF16(storage), DType::I32) => {
+ let data = unary_map(storage, layout, |v| v.to_f32() as i32);
+ Ok(Self::I32(data))
+ }
+ (Self::F16(storage), DType::I32) => {
+ let data = unary_map(storage, layout, |v| v.to_f32() as i32);
+ Ok(Self::I32(data))
+ }
+ (Self::F32(storage), DType::I32) => {
+ let data = unary_map(storage, layout, |v| v as i32);
+ Ok(Self::I32(data))
+ }
+ (Self::F64(storage), DType::I32) => {
+ let data = unary_map(storage, layout, |v| v as i32);
+ Ok(Self::I32(data))
+ }
+ (Self::F8E4M3(storage), DType::I32) => {
+ let data = unary_map(storage, layout, |v| v.to_f32() as i32);
+ Ok(Self::I32(data))
+ }
+ // Conversions from I16
+ (Self::I16(storage), DType::U8) => {
+ let data = unary_map(storage, layout, |v| v as u8);
+ Ok(Self::U8(data))
+ }
+ (Self::I16(storage), DType::U32) => {
+ let data = unary_map(storage, layout, |v| v as u32);
+ Ok(Self::U32(data))
+ }
+ (Self::I16(storage), DType::I64) => {
+ let data = unary_map(storage, layout, |v| v as i64);
+ Ok(Self::I64(data))
+ }
+ (Self::I16(storage), DType::BF16) => {
+ let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
+ Ok(Self::BF16(data))
+ }
+ (Self::I16(storage), DType::F16) => {
+ let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
+ Ok(Self::F16(data))
+ }
+ (Self::I16(storage), DType::F32) => {
+ let data = unary_map(storage, layout, |v| v as f32);
+ Ok(Self::F32(data))
+ }
+ (Self::I16(storage), DType::F64) => {
+ let data = unary_map(storage, layout, |v| v as f64);
+ Ok(Self::F64(data))
+ }
+ (Self::I16(storage), DType::F8E4M3) => {
+ let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
+ Ok(Self::F8E4M3(data))
+ }
+ // Conversions from I32
+ (Self::I32(storage), DType::U8) => {
+ let data = unary_map(storage, layout, |v| v as u8);
+ Ok(Self::U8(data))
+ }
+ (Self::I32(storage), DType::U32) => {
+ let data = unary_map(storage, layout, |v| v as u32);
+ Ok(Self::U32(data))
+ }
+ (Self::I32(storage), DType::I64) => {
+ let data = unary_map(storage, layout, |v| v as i64);
+ Ok(Self::I64(data))
+ }
+ (Self::I32(storage), DType::BF16) => {
+ let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
+ Ok(Self::BF16(data))
+ }
+ (Self::I32(storage), DType::F16) => {
+ let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
+ Ok(Self::F16(data))
+ }
+ (Self::I32(storage), DType::F32) => {
+ let data = unary_map(storage, layout, |v| v as f32);
+ Ok(Self::F32(data))
+ }
+ (Self::I32(storage), DType::F64) => {
+ let data = unary_map(storage, layout, |v| v as f64);
+ Ok(Self::F64(data))
+ }
+ (Self::I32(storage), DType::F8E4M3) => {
+ let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
+ Ok(Self::F8E4M3(data))
+ }
+ // Dummy types - return error for all conversions to/from dummy types
+ (_, DType::F6E2M3) | (_, DType::F6E3M2) | (_, DType::F4) | (_, DType::F8E8M0) => {
+ Err(Error::UnsupportedDTypeForOp(dtype, "to_dtype").bt())
+ }
+ (Self::F6E2M3(_), _)
+ | (Self::F6E3M2(_), _)
+ | (Self::F4(_), _)
+ | (Self::F8E8M0(_), _) => {
+ Err(Error::UnsupportedDTypeForOp(self.dtype(), "to_dtype").bt())
+ }
}
}
@@ -1934,6 +2356,25 @@ impl BackendStorage for CpuStorage {
UpsampleNearest2D(h, w).map(self, layout)
}
+ fn upsample_bilinear2d(
+ &self,
+ layout: &Layout,
+ h: usize,
+ w: usize,
+ align_corners: bool,
+ scale_h: Option<f64>,
+ scale_w: Option<f64>,
+ ) -> Result<Self> {
+ UpsampleBilinear2D {
+ target_h: h,
+ target_w: w,
+ align_corners,
+ scale_h_factor: scale_h,
+ scale_w_factor: scale_w,
+ }
+ .map(self, layout)
+ }
+
fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
use num_traits::Float;
// TODO: Have some generic map for functions that apply on num_traits::Float elements.
@@ -1954,9 +2395,19 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v.powf(e));
Ok(Self::F64(data))
}
- Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
- Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
- Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
+ Self::F8E4M3(storage) => {
+ let data = unary_map(storage, layout, |v| v.powf(F8E4M3::from_f64(e)));
+ Ok(Self::F8E4M3(data))
+ }
+ Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "powf").bt()),
+ Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "powf").bt()),
+ Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "powf").bt()),
+ Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "powf").bt()),
+ Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "powf").bt()),
+ Self::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E2M3, "powf").bt()),
+ Self::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E3M2, "powf").bt()),
+ Self::F4(_) => Err(Error::UnsupportedDTypeForOp(DType::F4, "powf").bt()),
+ Self::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E8M0, "powf").bt()),
}
}
@@ -1979,9 +2430,19 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| elu(v, alpha));
Ok(Self::F64(data))
}
+ Self::F8E4M3(storage) => {
+ let data = unary_map(storage, layout, |v| elu(v, F8E4M3::from_f64(alpha)));
+ Ok(Self::F8E4M3(data))
+ }
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
+ Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "elu").bt()),
+ Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "elu").bt()),
Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
+ Self::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E2M3, "elu").bt()),
+ Self::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E3M2, "elu").bt()),
+ Self::F4(_) => Err(Error::UnsupportedDTypeForOp(DType::F4, "elu").bt()),
+ Self::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E8M0, "elu").bt()),
}
}
@@ -2031,10 +2492,26 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, B::u32);
Ok(Self::U32(data))
}
+ Self::I16(storage) => {
+ let data = unary_map(storage, layout, B::i16);
+ Ok(Self::I16(data))
+ }
+ Self::I32(storage) => {
+ let data = unary_map(storage, layout, B::i32);
+ Ok(Self::I32(data))
+ }
Self::I64(storage) => {
let data = unary_map(storage, layout, B::i64);
Ok(Self::I64(data))
}
+ Self::F8E4M3(storage) => {
+ let data = unary_map(storage, layout, B::f8e4m3);
+ Ok(Self::F8E4M3(data))
+ }
+ Self::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E2M3, "unary").bt()),
+ Self::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E3M2, "unary").bt()),
+ Self::F4(_) => Err(Error::UnsupportedDTypeForOp(DType::F4, "unary").bt()),
+ Self::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E8M0, "unary").bt()),
}
}
@@ -2085,6 +2562,14 @@ impl BackendStorage for CpuStorage {
};
Ok(Self::U32(data))
}
+ (Self::I16(lhs), Self::I16(rhs)) => {
+ let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::i16);
+ Ok(Self::I16(data))
+ }
+ (Self::I32(lhs), Self::I32(rhs)) => {
+ let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::i32);
+ Ok(Self::I32(data))
+ }
(Self::I64(lhs), Self::I64(rhs)) => {
let data = if B::I64_VEC {
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i64, B::i64_vec)
@@ -2101,6 +2586,10 @@ impl BackendStorage for CpuStorage {
};
Ok(Self::U8(data))
}
+ (Self::F8E4M3(lhs), Self::F8E4M3(rhs)) => {
+ let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f8e4m3);
+ Ok(Self::F8E4M3(data))
+ }
_ => {
// This should be covered by the dtype check above.
Err(Error::DTypeMismatchBinaryOp {
@@ -2128,6 +2617,12 @@ impl BackendStorage for CpuStorage {
(Self::U32(src), Self::U32(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
+ (Self::I16(src), Self::I16(dst)) => {
+ copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
+ }
+ (Self::I32(src), Self::I32(dst)) => {
+ copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
+ }
(Self::I64(src), Self::I64(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
@@ -2143,6 +2638,19 @@ impl BackendStorage for CpuStorage {
(Self::F64(src), Self::F64(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
+ (Self::F8E4M3(src), Self::F8E4M3(dst)) => {
+ copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
+ }
+ (Self::F6E2M3(src), Self::F6E2M3(dst)) => {
+ copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
+ }
+ (Self::F6E3M2(src), Self::F6E3M2(dst)) => {
+ copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
+ }
+ (Self::F4(src), Self::F4(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o),
+ (Self::F8E8M0(src), Self::F8E8M0(dst)) => {
+ copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
+ }
(_, dst) => {
return Err(Error::DTypeMismatchBinaryOp {
lhs: self.dtype(),
@@ -2159,11 +2667,26 @@ impl BackendStorage for CpuStorage {
match (self, dst) {
(Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
+ (Self::I16(src), Self::I16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
+ (Self::I32(src), Self::I32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::F64(src), Self::F64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
+ (Self::F8E4M3(src), Self::F8E4M3(dst)) => {
+ copy_strided_src_(src, dst, dst_offset, src_l)
+ }
+ (Self::F6E2M3(src), Self::F6E2M3(dst)) => {
+ copy_strided_src_(src, dst, dst_offset, src_l)
+ }
+ (Self::F6E3M2(src), Self::F6E3M2(dst)) => {
+ copy_strided_src_(src, dst, dst_offset, src_l)
+ }
+ (Self::F4(src), Self::F4(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
+ (Self::F8E8M0(src), Self::F8E8M0(dst)) => {
+ copy_strided_src_(src, dst, dst_offset, src_l)
+ }
(_, dst) => {
// This should be covered by the dtype check above.
return Err(Error::DTypeMismatchBinaryOp {
@@ -2188,6 +2711,8 @@ impl BackendStorage for CpuStorage {
match self {
Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
+ Self::I16(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
+ Self::I32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")),
}
@@ -2301,46 +2826,7 @@ impl BackendStorage for CpuStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv2D,
) -> Result {
- if !USE_IM2COL_CONV2D {
- return Conv2D(params).map(self, l, kernel, kernel_l);
- }
- let op = Im2Col {
- h_k: params.k_h,
- w_k: params.k_w,
- padding: params.padding,
- stride: params.stride,
- dilation: params.dilation,
- };
- let col = op.map(self, l)?;
- let b = params.b_size;
- let n = params.c_out;
- let (h_out, w_out) = (params.out_h(), params.out_w());
- let k = op.h_k * op.w_k * params.c_in;
- let m = h_out * w_out;
- let col_l = Layout::contiguous((b, m, k));
- let res = if kernel_l.is_contiguous() {
- let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
- .transpose(1, 2)?
- .broadcast_as((b, k, n))?;
- col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
- } else {
- // Make the kernel contiguous if not already the case.
- let mut kernel_c = unsafe {
- self.device()
- .alloc_uninit(kernel_l.shape(), kernel.dtype())?
- };
- kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
- let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
- .transpose(1, 2)?
- .broadcast_as((b, k, n))?;
- col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
- };
- let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
- .transpose(1, 2)?
- .transpose(1, 3)?;
- let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
- res.copy_strided_src(&mut res_t, 0, &res_l)?;
- Ok(res_t)
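+ // conv2d now dispatches to the dedicated cpu_backend::conv2d module, which contains both
+ // the direct and the im2col + gemm implementations.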
+ Conv2D(params).map(self, l, kernel, kernel_l)
}
fn conv_transpose2d(
@@ -2371,19 +2857,38 @@ impl BackendStorage for CpuStorage {
}
}
- fn scatter_add(
- &self,
+ fn scatter_set(
+ &mut self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
- ) -> Result {
+ ) -> Result<()> {
+ match ids {
+ Self::U8(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
+ Self::U32(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
+ Self::I64(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
+ _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter").bt()),
+ }
+ }
+
+ fn scatter_add_set(
+ &mut self,
+ l: &Layout,
+ ids: &Self,
+ ids_l: &Layout,
+ src: &Self,
+ src_l: &Layout,
+ dim: usize,
+ ) -> Result<()> {
match ids {
- Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
- Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
- Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
+ Self::U8(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
+ Self::U32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
+ Self::I16(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
+ Self::I32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
+ Self::I64(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
}
}
@@ -2412,6 +2917,20 @@ impl BackendStorage for CpuStorage {
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
+ Self::I16(ids) => {
+ let ids = match ids_l.contiguous_offsets() {
+ Some((a, b)) => &ids[a..b],
+ None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
+ };
+ IndexAdd { ids, dim }.map(self, l, src, src_l)
+ }
+ Self::I32(ids) => {
+ let ids = match ids_l.contiguous_offsets() {
+ Some((a, b)) => &ids[a..b],
+ None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
+ };
+ IndexAdd { ids, dim }.map(self, l, src, src_l)
+ }
Self::I64(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
@@ -2444,6 +2963,64 @@ impl BackendStorage for CpuStorage {
fn to_cpu_storage(&self) -> Result<CpuStorage> {
Ok(self.clone())
}
+
+ fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
+ use crate::scalar::Scalar;
+ fn set<T: Copy>(src: &mut [T], l: &Layout, s: T) {
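+ // Fill the layout's single contiguous block, or each strided block, with the scalar value.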
+ match l.strided_blocks() {
+ crate::StridedBlocks::SingleBlock { start_offset, len } => {
+ src[start_offset..start_offset + len].fill(s)
+ }
+ crate::StridedBlocks::MultipleBlocks {
+ block_start_index,
+ block_len: 1,
+ } => {
+ for src_index in block_start_index {
+ src[src_index] = s
+ }
+ }
+ crate::StridedBlocks::MultipleBlocks {
+ block_start_index,
+ block_len,
+ } => {
+ for src_index in block_start_index {
+ src[src_index..src_index + block_len].fill(s)
+ }
+ }
+ }
+ }
+ match (self, s) {
+ (Self::BF16(storage), Scalar::BF16(v)) => set(storage, l, v),
+ (Self::F16(storage), Scalar::F16(v)) => set(storage, l, v),
+ (Self::F32(storage), Scalar::F32(v)) => set(storage, l, v),
+ (Self::F64(storage), Scalar::F64(v)) => set(storage, l, v),
+ (Self::U8(storage), Scalar::U8(v)) => set(storage, l, v),
+ (Self::U32(storage), Scalar::U32(v)) => set(storage, l, v),
+ (Self::I16(storage), Scalar::I16(v)) => set(storage, l, v),
+ (Self::I32(storage), Scalar::I32(v)) => set(storage, l, v),
+ (Self::I64(storage), Scalar::I64(v)) => set(storage, l, v),
+ (Self::F8E4M3(storage), Scalar::F8E4M3(v)) => set(storage, l, v),
+ // Dummy types don't support scalar operations
+ (Self::F6E2M3(_), _) => {
+ crate::bail!("const_set not supported for dummy type F6E2M3")
+ }
+ (Self::F6E3M2(_), _) => {
+ crate::bail!("const_set not supported for dummy type F6E3M2")
+ }
+ (Self::F4(_), _) => {
+ crate::bail!("const_set not supported for dummy type F4")
+ }
+ (Self::F8E8M0(_), _) => {
+ crate::bail!("const_set not supported for dummy type F8E8M0")
+ }
+ (st, s) => crate::bail!(
+ "const_set dtype mismatch, expected {:?} but got {:?}",
+ st.dtype(),
+ s
+ ),
+ }
+ Ok(())
+ }
}
impl BackendDevice for CpuDevice {
@@ -2477,19 +3054,29 @@ impl BackendDevice for CpuDevice {
crate::bail!("cannot seed the CPU rng with set_seed")
}
+ fn get_current_seed(&self) -> Result<u64> {
+ crate::bail!("cannot get the CPU rng seed with get_current_seed")
+ }
+
fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
use rand::prelude::*;
let elem_count = shape.elem_count();
- let mut rng = rand::thread_rng();
+ let mut rng = rand::rng();
match dtype {
- DType::U8 | DType::U32 | DType::I64 => {
- Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt())
- }
+ DType::U8
+ | DType::U32
+ | DType::I16
+ | DType::I32
+ | DType::I64
+ | DType::F6E2M3
+ | DType::F6E3M2
+ | DType::F4
+ | DType::F8E8M0 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()),
DType::BF16 => {
let mut data = Vec::with_capacity(elem_count);
- let uniform =
- rand::distributions::Uniform::new(bf16::from_f64(min), bf16::from_f64(max));
+ let uniform = rand::distr::Uniform::new(bf16::from_f64(min), bf16::from_f64(max))
+ .map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(rng.sample::<bf16, _>(uniform))
}
@@ -2497,16 +3084,27 @@ impl BackendDevice for CpuDevice {
}
DType::F16 => {
let mut data = Vec::with_capacity(elem_count);
- let uniform =
- rand::distributions::Uniform::new(f16::from_f64(min), f16::from_f64(max));
+ let uniform = rand::distr::Uniform::new(f16::from_f64(min), f16::from_f64(max))
+ .map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(rng.sample::<f16, _>(uniform))
}
Ok(CpuStorage::F16(data))
}
+ DType::F8E4M3 => {
+ let mut data = Vec::with_capacity(elem_count);
+ let uniform =
+ rand::distr::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max))
+ .map_err(Error::wrap)?;
+ for _i in 0..elem_count {
+ data.push(rng.sample::<F8E4M3, _>(uniform))
+ }
+ Ok(CpuStorage::F8E4M3(data))
+ }
DType::F32 => {
let mut data = Vec::with_capacity(elem_count);
- let uniform = rand::distributions::Uniform::new(min as f32, max as f32);
+ let uniform =
+ rand::distr::Uniform::new(min as f32, max as f32).map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(rng.sample::<f32, _>(uniform))
}
@@ -2514,7 +3112,7 @@ impl BackendDevice for CpuDevice {
}
DType::F64 => {
let mut data = Vec::with_capacity(elem_count);
- let uniform = rand::distributions::Uniform::new(min, max);
+ let uniform = rand::distr::Uniform::new(min, max).map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(rng.sample::<f64, _>(uniform))
}
@@ -2527,11 +3125,17 @@ impl BackendDevice for CpuDevice {
use rand::prelude::*;
let elem_count = shape.elem_count();
- let mut rng = rand::thread_rng();
+ let mut rng = rand::rng();
match dtype {
- DType::U8 | DType::U32 | DType::I64 => {
- Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
- }
+ DType::U8
+ | DType::U32
+ | DType::I16
+ | DType::I32
+ | DType::I64
+ | DType::F6E2M3
+ | DType::F6E3M2
+ | DType::F4
+ | DType::F8E8M0 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
DType::BF16 => {
let mut data = Vec::with_capacity(elem_count);
let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
@@ -2550,6 +3154,15 @@ impl BackendDevice for CpuDevice {
}
Ok(CpuStorage::F16(data))
}
+ DType::F8E4M3 => {
+ let mut data = Vec::with_capacity(elem_count);
+ let normal = rand_distr::Normal::new(F8E4M3::from_f64(mean), F8E4M3::from_f64(std))
+ .map_err(Error::wrap)?;
+ for _i in 0..elem_count {
+ data.push(normal.sample(&mut rng))
+ }
+ Ok(CpuStorage::F8E4M3(data))
+ }
DType::F32 => {
let mut data = Vec::with_capacity(elem_count);
let normal =
@@ -2588,6 +3201,16 @@ impl BackendDevice for CpuDevice {
v.set_len(elem_count);
CpuStorage::U32(v)
}
+ DType::I16 => {
+ let mut v = Vec::with_capacity(elem_count);
+ v.set_len(elem_count);
+ CpuStorage::I16(v)
+ }
+ DType::I32 => {
+ let mut v = Vec::with_capacity(elem_count);
+ v.set_len(elem_count);
+ CpuStorage::I32(v)
+ }
DType::I64 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
@@ -2613,20 +3236,14 @@ impl BackendDevice for CpuDevice {
v.set_len(elem_count);
CpuStorage::F64(v)
}
- };
- Ok(storage)
- }
-
- fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
- let elem_count = shape.elem_count();
- let storage = match dtype {
- DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
- DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
- DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
- DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
- DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
- DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
- DType::F64 => CpuStorage::F64(vec![1f64; elem_count]),
+ DType::F8E4M3 => {
+ let mut v = Vec::with_capacity(elem_count);
+ v.set_len(elem_count);
+ CpuStorage::F8E4M3(v)
+ }
+ DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
+ return Err(Error::UnsupportedDTypeForOp(dtype, "alloc_uninit").bt())
+ }
};
Ok(storage)
}
@@ -2636,11 +3253,17 @@ impl BackendDevice for CpuDevice {
let storage = match dtype {
DType::U8 => CpuStorage::U8(vec![0u8; elem_count]),
DType::U32 => CpuStorage::U32(vec![0u32; elem_count]),
+ DType::I16 => CpuStorage::I16(vec![0i16; elem_count]),
+ DType::I32 => CpuStorage::I32(vec![0i32; elem_count]),
DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),
DType::F64 => CpuStorage::F64(vec![0f64; elem_count]),
+ DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ZERO; elem_count]),
+ DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
+ return Err(Error::UnsupportedDTypeForOp(dtype, "zeros").bt())
+ }
};
Ok(storage)
}
diff --git a/candle-core/src/cpu_backend/utils.rs b/candle-core/src/cpu_backend/utils.rs
index 3e0c69b4f7..1f800a928b 100644
--- a/candle-core/src/cpu_backend/utils.rs
+++ b/candle-core/src/cpu_backend/utils.rs
@@ -10,11 +10,19 @@ pub trait Map1 {
match vs {
C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)),
C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)),
+ C::I16(vs) => Ok(C::I16(self.f(vs, layout)?)),
+ C::I32(vs) => Ok(C::I32(self.f(vs, layout)?)),
C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)),
C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)),
C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)),
C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)),
+ C::F8E4M3(vs) => Ok(C::F8E4M3(self.f(vs, layout)?)),
+ // Dummy types don't support Map1 operations
+ C::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1").bt()),
+ C::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1").bt()),
+ C::F4(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1").bt()),
+ C::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1").bt()),
}
}
}
@@ -26,11 +34,19 @@ pub trait Map1Any {
match vs {
C::U8(vs) => Ok(self.f(vs, layout, C::U8)?),
C::U32(vs) => Ok(self.f(vs, layout, C::U32)?),
+ C::I16(vs) => Ok(self.f(vs, layout, C::I16)?),
+ C::I32(vs) => Ok(self.f(vs, layout, C::I32)?),
C::I64(vs) => Ok(self.f(vs, layout, C::I64)?),
C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?),
C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
C::F32(vs) => Ok(self.f(vs, layout, C::F32)?),
C::F64(vs) => Ok(self.f(vs, layout, C::F64)?),
+ C::F8E4M3(vs) => Ok(self.f(vs, layout, C::F8E4M3)?),
+ // Dummy types don't support Map1Any operations
+ C::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1any").bt()),
+ C::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1any").bt()),
+ C::F4(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1any").bt()),
+ C::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1any").bt()),
}
}
}
@@ -43,11 +59,14 @@ pub trait Map2 {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
+ (C::I16(v1), C::I16(v2)) => Ok(C::I16(self.f(v1, l1, v2, l2)?)),
+ (C::I32(v1), C::I32(v2)) => Ok(C::I32(self.f(v1, l1, v2, l2)?)),
(C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
+ (C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::F8E4M3(self.f(v1, l1, v2, l2)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
@@ -58,6 +77,33 @@ pub trait Map2 {
}
}
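+/// Like [`Map2`] but mutates the first storage in place instead of returning a new one.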
+pub trait Map2InPlace {
+ const OP: &'static str;
+ fn f<T: WithDType>(&self, v1: &mut [T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<()>;
+
+ fn map(&self, v1: &mut C, l1: &Layout, v2: &C, l2: &Layout) -> Result<()> {
+ match (v1, v2) {
+ (C::U8(v1), C::U8(v2)) => self.f(v1, l1, v2, l2)?,
+ (C::U32(v1), C::U32(v2)) => self.f(v1, l1, v2, l2)?,
+ (C::I16(v1), C::I16(v2)) => self.f(v1, l1, v2, l2)?,
+ (C::I32(v1), C::I32(v2)) => self.f(v1, l1, v2, l2)?,
+ (C::I64(v1), C::I64(v2)) => self.f(v1, l1, v2, l2)?,
+ (C::BF16(v1), C::BF16(v2)) => self.f(v1, l1, v2, l2)?,
+ (C::F16(v1), C::F16(v2)) => self.f(v1, l1, v2, l2)?,
+ (C::F32(v1), C::F32(v2)) => self.f(v1, l1, v2, l2)?,
+ (C::F64(v1), C::F64(v2)) => self.f(v1, l1, v2, l2)?,
+ (C::F8E4M3(v1), C::F8E4M3(v2)) => self.f(v1, l1, v2, l2)?,
+ (v1, v2) => Err(Error::DTypeMismatchBinaryOp {
+ lhs: v1.dtype(),
+ rhs: v2.dtype(),
+ op: Self::OP,
+ }
+ .bt())?,
+ };
+ Ok(())
+ }
+}
+
pub trait Map2U8 {
const OP: &'static str;
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
@@ -66,11 +112,14 @@ pub trait Map2U8 {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+ (C::I16(v1), C::I16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+ (C::I32(v1), C::I32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+ (C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
diff --git a/candle-core/src/cuda_backend/cudnn.rs b/candle-core/src/cuda_backend/cudnn.rs
index f5b4db9026..d7d8770587 100644
--- a/candle-core/src/cuda_backend/cudnn.rs
+++ b/candle-core/src/cuda_backend/cudnn.rs
@@ -43,7 +43,7 @@ pub(crate) fn launch_conv2d<
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
return Ok(cudnn.clone());
}
- let c = Cudnn::new(dev.cuda_device());
+ let c = Cudnn::new(dev.cuda_stream());
if let Ok(c) = &c {
cudnn.borrow_mut().insert(device_id, c.clone());
}
@@ -109,7 +109,7 @@ pub(crate) fn launch_conv2d<
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
};
let workspace_size = conv2d.get_workspace_size(alg)?;
- let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
+ let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
unsafe {
conv2d.launch::<CudaSlice<u8>, _, _, _>(
alg,
@@ -122,3 +122,104 @@ pub(crate) fn launch_conv2d<
}
Ok(())
}
+
+pub(crate) fn launch_conv1d<
+ T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
+ Y: cudarc::cudnn::CudnnDataType,
+>(
+ src: &CudaView<T>,
+ src_l: &crate::Layout,
+ filter: &CudaView<T>,
+ dst: &mut CudaSlice<T>,
+ params: &crate::conv::ParamsConv1D,
+ dev: &crate::cuda_backend::CudaDevice,
+) -> crate::Result<()> {
+ use crate::conv::CudnnFwdAlgo as CandleAlgo;
+ use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;
+
+ let device_id = dev.id();
+ let cudnn = CUDNN.with(|cudnn| {
+ if let Some(cudnn) = cudnn.borrow().get(&device_id) {
+ return Ok(cudnn.clone());
+ }
+ let c = Cudnn::new(dev.cuda_stream());
+ if let Ok(c) = &c {
+ cudnn.borrow_mut().insert(device_id, c.clone());
+ }
+ c
+ })?;
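+ // Run the 1d convolution through cudnn's 2d path by treating the data as (b, c_in, l_in, 1).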
+ let conv = cudnn.create_conv2d::<Y>(
+ /* pad */ [params.padding as i32, 0],
+ /* stride */ [params.stride as i32, 1],
+ /* dilation */ [params.dilation as i32, 1],
+ cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
+ )?;
+ // https://docs.nvidia.com/deeplearning/cudnn/backend/latest/api/cudnn-ops-library.html#cudnnsettensornddescriptor
+ // > Tensors are restricted to having at least 4 dimensions, and at most CUDNN_DIM_MAX
+ // > dimensions (defined in cudnn.h). When working with lower dimensional data, it is
+ // > recommended that the user create a 4D tensor, and set the size along unused dimensions
+ // > to 1.
+ let x_shape = [
+ params.b_size as i32,
+ params.c_in as i32,
+ params.l_in as i32,
+ 1,
+ ];
+ // Note that `src` already starts at the proper offset.
+ let x = if src_l.is_contiguous() {
+ cudnn.create_4d_tensor::<T>(
+ cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
+ x_shape,
+ )?
+ } else {
+ let s = src_l.stride();
+ cudnn.create_4d_tensor_ex::<T>(x_shape, [s[0] as i32, s[1] as i32, s[2] as i32, 1i32])?
+ };
+ let w = cudnn.create_4d_filter::<T>(
+ cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
+ [
+ params.c_out as i32,
+ params.c_in as i32,
+ params.k_size as i32,
+ 1,
+ ],
+ )?;
+ let l_out = params.l_out() as i32;
+ let y = cudnn.create_4d_tensor::<T>(
+ cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
+ [params.b_size as i32, params.c_out as i32, l_out, 1],
+ )?;
+ let conv1d = ConvForward {
+ conv: &conv,
+ x: &x,
+ w: &w,
+ y: &y,
+ };
+ let alg = match params.cudnn_fwd_algo {
+ None => conv1d.pick_algorithm()?,
+ Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
+ Some(CandleAlgo::ImplicitPrecompGemm) => {
+ A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
+ }
+ Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
+ Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
+ Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
+ Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
+ Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
+ Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
+ Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
+ };
+ let workspace_size = conv1d.get_workspace_size(alg)?;
+ let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
+ unsafe {
+ conv1d.launch::<CudaSlice<u8>, _, _, _>(
+ alg,
+ Some(&mut workspace),
+ (T::one(), T::zero()),
+ src,
+ filter,
+ dst,
+ )?;
+ }
+ Ok(())
+}
diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs
index d3bd29030e..425fd74f76 100644
--- a/candle-core/src/cuda_backend/device.rs
+++ b/candle-core/src/cuda_backend/device.rs
@@ -1,10 +1,12 @@
-use crate::backend::BackendDevice;
+use crate::backend::{BackendDevice, BackendStorage};
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
pub use candle_kernels as kernels;
pub use cudarc;
-use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
+use cudarc::driver::CudaFunction;
+use float8::F8E4M3;
use half::{bf16, f16};
-use std::sync::{Arc, Mutex};
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex, RwLock};
use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
@@ -24,12 +26,20 @@ impl DeviceId {
struct CudaRng(cudarc::curand::CudaRng);
unsafe impl Send for CudaRng {}
+pub struct ModuleStore {
+ mdls: [Option<Arc<cudarc::driver::CudaModule>>; kernels::ALL_IDS.len()],
+}
+
#[derive(Clone)]
pub struct CudaDevice {
id: DeviceId,
- device: Arc<cudarc::driver::CudaDevice>,
+ context: Arc<cudarc::driver::CudaContext>,
+ modules: Arc<RwLock<ModuleStore>>,
+ custom_modules: Arc<RwLock<HashMap<String, Arc<cudarc::driver::CudaModule>>>>,
+ stream: Arc<cudarc::driver::CudaStream>,
pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
curand: Arc<Mutex<CudaRng>>,
+ seed_value: Arc<RwLock<u64>>,
}
impl std::fmt::Debug for CudaDevice {
@@ -38,143 +48,225 @@ impl std::fmt::Debug for CudaDevice {
}
}
-impl std::ops::Deref for CudaDevice {
- type Target = Arc<cudarc::driver::CudaDevice>;
+impl CudaDevice {
+ #[allow(clippy::missing_safety_doc)]
+ pub unsafe fn alloc<T: cudarc::driver::DeviceRepr>(
+ &self,
+ len: usize,
+ ) -> Result<cudarc::driver::CudaSlice<T>> {
+ self.stream.alloc::<T>(len).w()
+ }
+
+ pub fn alloc_zeros<T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits>(
+ &self,
+ len: usize,
+ ) -> Result<cudarc::driver::CudaSlice<T>> {
+ self.stream.alloc_zeros::<T>(len).w()
+ }
+
+ pub fn memcpy_htod<
+ T: cudarc::driver::DeviceRepr,
+ Src: cudarc::driver::HostSlice<T> + ?Sized,
+ Dst: cudarc::driver::DevicePtrMut<T>,
+ >(
+ &self,
+ src: &Src,
+ dst: &mut Dst,
+ ) -> Result<()> {
+ self.stream.memcpy_htod(src, dst).w()
+ }
+
+ pub fn clone_dtoh<T: cudarc::driver::DeviceRepr, Src: cudarc::driver::DevicePtr<T>>(
+ &self,
+ src: &Src,
+ ) -> Result<Vec<T>> {
+ self.stream.clone_dtoh(src).w()
+ }
+
+ pub fn memcpy_dtod<
+ T,
+ Src: cudarc::driver::DevicePtr<T>,
+ Dst: cudarc::driver::DevicePtrMut<T>,
+ >(
+ &self,
+ src: &Src,
+ dst: &mut Dst,
+ ) -> Result<()> {
+ self.stream.memcpy_dtod(src, dst).w()
+ }
+
+ pub fn memcpy_dtoh<
+ T: cudarc::driver::DeviceRepr,
+ Src: cudarc::driver::DevicePtr<T>,
+ Dst: cudarc::driver::HostSlice<T>,
+ >(
+ &self,
+ src: &Src,
+ dst: &mut Dst,
+ ) -> Result<()> {
+ self.stream.memcpy_dtoh(src, dst).w()
+ }
+
+ pub fn clone_htod<T: cudarc::driver::DeviceRepr, Src: cudarc::driver::HostSlice<T> + ?Sized>(
+ &self,
+ src: &Src,
+ ) -> Result<cudarc::driver::CudaSlice<T>> {
+ self.stream.clone_htod(src).w()
+ }
+}
+
+pub struct CudaFunc {
+ func: CudaFunction,
+ stream: Arc<cudarc::driver::CudaStream>,
+}
+
+impl std::ops::Deref for CudaFunc {
+ type Target = CudaFunction;
fn deref(&self) -> &Self::Target {
- &self.device
+ &self.func
+ }
+}
+
+impl CudaFunc {
+ pub fn into_cuda_function(self) -> CudaFunction {
+ self.func
+ }
+}
+
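+/// Bind each expression to a local so a reference to it can be passed to the kernel launch builder.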
+#[macro_export]
+macro_rules! builder_arg {
+ ($b:ident, $($arg:expr),*) => {
+ $(
+ let __arg = $arg;
+ $b.arg(&__arg);
+ )*
+ };
+}
+
+impl CudaFunc {
+ pub fn builder(&self) -> cudarc::driver::LaunchArgs<'_> {
+ self.stream.launch_builder(&self.func)
}
}
impl CudaDevice {
- pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
- self.device.clone()
+ pub fn cuda_stream(&self) -> Arc<cudarc::driver::CudaStream> {
+ self.stream.clone()
+ }
+
+ /// When turned on, all cuda tensors **created after calling this function** will
+ /// not track uses via cuda events.
+ ///
+ /// # Safety
+ ///
+ /// It is up to the user to ensure proper synchronization between multiple streams:
+ /// - Ensure that no tensor is freed before a use on another stream is finished.
+ /// - Ensure that a tensor is not used on another stream before allocation on the
+ /// allocating stream finishes.
+ /// - Ensure that a tensor is not written to concurrently by multiple streams.
+ pub unsafe fn disable_event_tracking(&self) {
+ self.context.disable_event_tracking()
+ }
+
+ pub fn is_event_tracking(&self) -> bool {
+ self.context.is_event_tracking()
}
+ #[cfg(all(feature = "ug", not(target_arch = "wasm32")))]
pub fn compile(
&self,
func_name: &'static str,
- kernel: ug::lang::ssa::Kernel,
- ) -> Result<CudaFunction> {
+ kernel: candle_ug::lang::ssa::Kernel,
+ ) -> Result<CudaFunc> {
let mut buf = vec![];
- ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
+ candle_ug::cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
let cuda_code = String::from_utf8(buf)?;
let opts = cudarc::nvrtc::CompileOptions {
use_fast_math: Some(true),
..Default::default()
};
let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?;
- self.device.load_ptx(ptx, "ug", &[func_name]).w()?;
- let func = match self.device.get_func("ug", func_name) {
- Some(func) => func,
- None => crate::bail!("unknown function ug::{func_name}"),
- };
- Ok(func)
+ let module = self.context.load_module(ptx).w()?;
+ let func = module.load_function(func_name).w()?;
+ Ok(CudaFunc {
+ func,
+ stream: self.stream.clone(),
+ })
}
pub fn id(&self) -> DeviceId {
self.id
}
- fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result {
- let elem_count = shape.elem_count();
- let cfg = LaunchConfig::for_num_elems(elem_count as u32);
- let slice = match dtype {
- DType::U8 => {
- // SAFETY: Set later by running the fill kernel.
- let data = unsafe { self.alloc::(elem_count) }.w()?;
- let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
- let params = (&data, v as u8, elem_count);
- unsafe { func.launch(cfg, params) }.w()?;
- CudaStorageSlice::U8(data)
- }
- DType::U32 => {
- // SAFETY: Set later by running the fill kernel.
- let data = unsafe { self.alloc::(elem_count) }.w()?;
- let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
- let params = (&data, v as u32, elem_count);
- unsafe { func.launch(cfg, params) }.w()?;
- CudaStorageSlice::U32(data)
- }
- DType::I64 => {
- // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
- let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
- let params = (&data, v as i64, elem_count);
- unsafe { func.launch(cfg, params) }.w()?;
- CudaStorageSlice::I64(data)
- }
- DType::BF16 => {
- // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
- let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
- let params = (&data, bf16::from_f64(v), elem_count);
- unsafe { func.launch(cfg, params) }.w()?;
- CudaStorageSlice::BF16(data)
- }
- DType::F16 => {
- // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
- let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
- let params = (&data, f16::from_f64(v), elem_count);
- unsafe { func.launch(cfg, params) }.w()?;
- CudaStorageSlice::F16(data)
- }
- DType::F32 => {
- // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
- let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
- let params = (&data, v as f32, elem_count);
- unsafe { func.launch(cfg, params) }.w()?;
- CudaStorageSlice::F32(data)
- }
- DType::F64 => {
- // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
- let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
- let params = (&data, v, elem_count);
- unsafe { func.launch(cfg, params) }.w()?;
- CudaStorageSlice::F64(data)
- }
- };
- Ok(CudaStorage {
- slice,
- device: self.clone(),
+ pub fn get_or_load_custom_func(
+ &self,
+ fn_name: &str,
+ module_name: &str,
+ ptx: &str,
+    ) -> Result<CudaFunc> {
+ let ms = self.custom_modules.read().unwrap();
+ if let Some(mdl) = ms.get(module_name).as_ref() {
+ let func = mdl.load_function(fn_name).w()?;
+ return Ok(CudaFunc {
+ func,
+ stream: self.stream.clone(),
+ });
+ }
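+        // Drop the read lock before taking the write lock; upgrading in place could deadlock.
+        // If another thread loaded the same module in the meantime, the insert below simply
+        // replaces it with a freshly loaded copy.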
+ drop(ms);
+ let mut ms = self.custom_modules.write().unwrap();
+ let cuda_module = self.context.load_module(ptx.into()).w()?;
+ ms.insert(module_name.to_string(), cuda_module.clone());
+ let func = cuda_module.load_function(fn_name).w()?;
+ Ok(CudaFunc {
+ func,
+ stream: self.stream.clone(),
})
}
-    pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
- if !self.has_func(module_name, module_name) {
- // Leaking the string here is a bit sad but we need a &'static str and this is only
- // done once per kernel name.
- let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
- self.load_ptx(ptx.into(), module_name, &[static_module_name])
- .map_err(|cuda| CudaError::Load {
- cuda,
- module_name: module_name.to_string(),
- })
- .w()?;
+    pub fn get_or_load_func(&self, fn_name: &str, mdl: &kernels::Module) -> Result<CudaFunc> {
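+        // Fast path: reuse the module if this kernel family has already been loaded.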
+ let ms = self.modules.read().unwrap();
+ if let Some(mdl) = ms.mdls[mdl.index()].as_ref() {
+ let func = mdl.load_function(fn_name).w()?;
+ return Ok(CudaFunc {
+ func,
+ stream: self.stream.clone(),
+ });
}
- self.get_func(module_name, module_name)
- // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
- // able to only build the error value if needed.
- .ok_or(CudaError::MissingKernel {
- module_name: module_name.to_string(),
- })
- .w()
+ drop(ms);
+ let mut ms = self.modules.write().unwrap();
+ let cuda_module = self.context.load_module(mdl.ptx().into()).w()?;
+ ms.mdls[mdl.index()] = Some(cuda_module.clone());
+ let func = cuda_module.load_function(fn_name).w()?;
+ Ok(CudaFunc {
+ func,
+ stream: self.stream.clone(),
+ })
+ }
+
+    pub fn cublas_handle(&self) -> Arc<cudarc::cublas::CudaBlas> {
+ self.blas.clone()
}
}
impl CudaDevice {
    pub fn new_with_stream(ordinal: usize) -> Result<Self> {
- let device = cudarc::driver::CudaDevice::new_with_stream(ordinal).w()?;
- let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
- let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
+ let context = cudarc::driver::CudaContext::new(ordinal).w()?;
+ let stream = context.new_stream().w()?;
+ let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
+ let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
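+        // One lazily-filled slot per built-in kernel module; PTX is loaded on first use.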
+ let module_store = ModuleStore {
+ mdls: [const { None }; kernels::ALL_IDS.len()],
+ };
Ok(Self {
id: DeviceId::new(),
- device,
+ context,
+ stream,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
+ modules: Arc::new(std::sync::RwLock::new(module_store)),
+ custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
+ seed_value: Arc::new(RwLock::new(299792458)),
})
}
}
@@ -183,14 +275,22 @@ impl BackendDevice for CudaDevice {
type Storage = CudaStorage;
    fn new(ordinal: usize) -> Result<Self> {
- let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
- let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
- let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
+ let context = cudarc::driver::CudaContext::new(ordinal).w()?;
+ let stream = context.default_stream();
+ let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
+ let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
+ let module_store = ModuleStore {
+ mdls: [const { None }; kernels::ALL_IDS.len()],
+ };
Ok(Self {
id: DeviceId::new(),
- device,
+ context,
+ stream,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
+ modules: Arc::new(std::sync::RwLock::new(module_store)),
+ custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
+ seed_value: Arc::new(RwLock::new(299792458)),
})
}
@@ -198,13 +298,18 @@ impl BackendDevice for CudaDevice {
// We do not call set_seed but instead create a new curand object. This ensures that the
// state will be identical and the same random numbers will be generated.
let mut curand = self.curand.lock().unwrap();
- curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
+ curand.0 = cudarc::curand::CudaRng::new(seed, self.stream.clone()).w()?;
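+        // Keep a copy of the seed so `get_current_seed` can report it.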
+ *self.seed_value.write().unwrap() = seed;
Ok(())
}
+    fn get_current_seed(&self) -> Result<u64> {
+ Ok(*self.seed_value.read().unwrap())
+ }
+
fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Cuda {
- gpu_id: self.device.ordinal(),
+ gpu_id: self.context.ordinal(),
}
}
@@ -216,33 +321,50 @@ impl BackendDevice for CudaDevice {
let elem_count = shape.elem_count();
let slice = match dtype {
DType::U8 => {
-                let data = self.alloc_zeros::<u8>(elem_count).w()?;
+                let data = self.alloc_zeros::<u8>(elem_count)?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
-                let data = self.alloc_zeros::<u32>(elem_count).w()?;
+                let data = self.alloc_zeros::<u32>(elem_count)?;
CudaStorageSlice::U32(data)
}
+ DType::I16 => {
+                let data = self.alloc_zeros::<i16>(elem_count)?;
+ CudaStorageSlice::I16(data)
+ }
+ DType::I32 => {
+                let data = self.alloc_zeros::<i32>(elem_count)?;
+ CudaStorageSlice::I32(data)
+ }
DType::I64 => {
-                let data = self.alloc_zeros::<i64>(elem_count).w()?;
+                let data = self.alloc_zeros::<i64>(elem_count)?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
-                let data = self.alloc_zeros::<bf16>(elem_count).w()?;
+                let data = self.alloc_zeros::<bf16>(elem_count)?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
-                let data = self.alloc_zeros::<f16>(elem_count).w()?;
+                let data = self.alloc_zeros::<f16>(elem_count)?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
-                let data = self.alloc_zeros::<f32>(elem_count).w()?;
+                let data = self.alloc_zeros::<f32>(elem_count)?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
-                let data = self.alloc_zeros::<f64>(elem_count).w()?;
+                let data = self.alloc_zeros::<f64>(elem_count)?;
CudaStorageSlice::F64(data)
}
+ DType::F8E4M3 => {
+                let data = self.alloc_zeros::<F8E4M3>(elem_count)?;
+ CudaStorageSlice::F8E4M3(data)
+ }
+ DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
+ return Err(
+ CudaError::InternalError("Dummy types not supported in CUDA backend").into(),
+ )
+ }
};
Ok(CudaStorage {
slice,
@@ -256,23 +378,34 @@ impl BackendDevice for CudaDevice {
let slice = match dtype {
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
// cudarc changes.
- DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
- Err(CudaError::UnsupportedDtype {
- dtype,
- op: "rand_uniform",
- })
- .w()?
- }
+ DType::U8
+ | DType::U32
+ | DType::I16
+ | DType::I32
+ | DType::I64
+ | DType::F16
+ | DType::BF16 => Err(CudaError::UnsupportedDtype {
+ dtype,
+ op: "rand_uniform",
+ })
+ .w()?,
DType::F32 => {
-                let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
+                let mut data = unsafe { self.alloc::<f32>(elem_count)? };
curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
-                let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
+                let mut data = unsafe { self.alloc::<f64>(elem_count)? };
curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F64(data)
}
+ DType::F8E4M3 | DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
+ Err(CudaError::UnsupportedDtype {
+ dtype,
+ op: "rand_uniform",
+ })
+ .w()?
+ }
};
let slice = if lo == 0. && up == 1.0 {
slice
@@ -300,15 +433,19 @@ impl BackendDevice for CudaDevice {
elem_count
};
let slice = match dtype {
- DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
- Err(CudaError::UnsupportedDtype {
- dtype,
- op: "rand_normal",
- })
- .w()?
- }
+ DType::U8
+ | DType::U32
+ | DType::I16
+ | DType::I32
+ | DType::I64
+ | DType::F16
+ | DType::BF16 => Err(CudaError::UnsupportedDtype {
+ dtype,
+ op: "rand_normal",
+ })
+ .w()?,
DType::F32 => {
-                let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
+                let mut data = unsafe { self.alloc::<f32>(elem_count_round)? };
curand
.0
.fill_with_normal(&mut data, mean as f32, std as f32)
@@ -316,10 +453,17 @@ impl BackendDevice for CudaDevice {
CudaStorageSlice::F32(data)
}
DType::F64 => {
-                let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
+                let mut data = unsafe { self.alloc::<f64>(elem_count_round)? };
curand.0.fill_with_normal(&mut data, mean, std).w()?;
CudaStorageSlice::F64(data)
}
+ DType::F8E4M3 | DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
+ Err(CudaError::UnsupportedDtype {
+ dtype,
+ op: "rand_normal",
+ })
+ .w()?
+ }
};
Ok(CudaStorage {
slice,
@@ -327,41 +471,54 @@ impl BackendDevice for CudaDevice {
})
}
-    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
- self.const_impl(1., shape, dtype)
- }
-
    unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let slice = match dtype {
DType::U8 => {
- let data = self.alloc::