# thrml-rs

GPU-accelerated probabilistic graphical models in Rust.

thrml-rs is a pure Rust implementation of GPU-accelerated sampling for probabilistic graphical models (PGMs), ported from Extropic's THRML library, with a few tweaks.
## Features

- GPU Acceleration: Multiple backend support:
  - WGPU (default): Metal (macOS), Vulkan (Linux/Windows)
  - CUDA: Native NVIDIA GPU support
- Block Gibbs Sampling: Efficient parallel sampling for PGMs
- Energy-Based Models: Ising models, discrete EBMs, Gaussian PGMs
- Mixed Variable Types: Spin, categorical, and continuous nodes
- Deterministic RNG: Reproducible sampling with ChaCha8-based key splitting (see the sketch after this list)
- Moment Estimation: Built-in observers for computing statistics
- Training Support: Contrastive divergence, KL gradient estimation
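
Key splitting is what makes the sampling reproducible without shared RNG state: every consumer derives its own independent stream from a parent key. A minimal, library-agnostic sketch of the idea on top of ChaCha8 (`RngKey` implements its own scheme; this `split` is illustrative, not the crate's API):

```rust
use rand::{RngCore, SeedableRng};
use rand_chacha::ChaCha8Rng;

/// Derive two independent child seeds from a parent seed (JAX-style).
/// Illustrative only; thrml's RngKey has its own splitting scheme.
fn split(seed: u64) -> (u64, u64) {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    (rng.next_u64(), rng.next_u64())
}

fn main() {
    let (sampler_key, init_key) = split(42);
    // The same parent seed always yields the same children, so runs reproduce exactly.
    println!("sampler: {sampler_key}, init: {init_key}");
}
```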
## Quick Start

```rust
use thrml_core::{Node, NodeType, Block, backend::init_gpu_device};
use thrml_models::ising::{IsingEBM, IsingSamplingProgram, hinton_init};
use thrml_samplers::{RngKey, SamplingSchedule};
use burn::tensor::Tensor;

fn main() {
    // Initialize the GPU device (WGPU by default).
    let device = init_gpu_device();

    // Create a 5-node Ising chain.
    let nodes: Vec<Node> = (0..5).map(|_| Node::new(NodeType::Spin)).collect();
    let edges: Vec<_> = nodes.windows(2)
        .map(|w| (w[0].clone(), w[1].clone()))
        .collect();

    // Define per-node biases and per-edge coupling weights.
    let biases = Tensor::from_data([0.1f32, 0.2, 0.0, -0.1, 0.3], &device);
    let weights = Tensor::from_data([0.5f32, -0.3, 0.4, 0.2], &device);
    let beta = Tensor::from_data([1.0f32], &device);

    // Create the Ising model.
    let model = IsingEBM::new(nodes.clone(), edges, biases, weights, beta);

    // Initialize the state using Hinton's method.
    let key = RngKey::new(42);
    let blocks = vec![Block::new(nodes).unwrap()];
    let _init_state = hinton_init(key, &model, &blocks, &[], &device);

    println!("Model initialized with {} nodes", model.nodes().len());
}
```
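
The `SamplingSchedule` and `IsingSamplingProgram` imports above drive the actual Gibbs sweeps; see `examples/ising_chain.rs` for the real call sequence. For intuition about what block Gibbs does on a chain: even and odd sites are conditionally independent given each other, so each sweep updates them as two parallel blocks. A minimal, library-agnostic sketch with plain arrays and the `rand` crate (illustrative only, not thrml's API):

```rust
use rand::Rng;

/// One block-Gibbs sweep over a +/-1 spin chain: update all even sites
/// given the odd ones, then all odd sites given the even ones.
fn sweep(spins: &mut [i8], h: &[f32], w: &[f32], beta: f32, rng: &mut impl Rng) {
    for parity in [0usize, 1] {
        for i in (parity..spins.len()).step_by(2) {
            // Local field: bias plus couplings to the left/right neighbours.
            let mut field = h[i];
            if i > 0 {
                field += w[i - 1] * spins[i - 1] as f32;
            }
            if i + 1 < spins.len() {
                field += w[i] * spins[i + 1] as f32;
            }
            // Conditional P(s_i = +1 | rest) = sigmoid(2 * beta * field).
            let p = 1.0 / (1.0 + (-2.0 * beta * field).exp());
            spins[i] = if rng.gen::<f32>() < p { 1 } else { -1 };
        }
    }
}

fn main() {
    let (h, w) = ([0.1, 0.2, 0.0, -0.1, 0.3], [0.5, -0.3, 0.4, 0.2]);
    let mut spins = [1i8; 5];
    let mut rng = rand::thread_rng();
    let (mut sum, n) = (0.0f32, 10_000);
    for _ in 0..n {
        sweep(&mut spins, &h, &w, 1.0, &mut rng);
        sum += spins[0] as f32; // moment estimation: accumulate <s_0>
    }
    println!("<s_0> ~ {:.3}", sum / n as f32);
}
```

The running average here plays the role of thrml's Moments observer, just on the CPU.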

## Crates

| Crate | Description |
|---|---|
| `thrml-core` | Core types: `Node`, `Block`, `BlockSpec`, GPU backend |
| `thrml-samplers` | Sampling algorithms: Gibbs, Bernoulli, Softmax, Gaussian |
| `thrml-models` | Model implementations: Ising, Discrete EBM, Continuous factors |
| `thrml-observers` | Observation utilities: State, Moments |
| `thrml-examples` | Example programs and utilities |

## Installation

```toml
[dependencies]
thrml-core = "0.1"
thrml-samplers = "0.1"
thrml-models = "0.1"
thrml-observers = "0.1"
```

The sphere branch includes additional crates and features not yet published to crates.io:
```toml
[dependencies]
thrml-core = { git = "https://github.com/SashimiSaketoro/thrml-rs", branch = "sphere" }
thrml-samplers = { git = "https://github.com/SashimiSaketoro/thrml-rs", branch = "sphere" }
thrml-models = { git = "https://github.com/SashimiSaketoro/thrml-rs", branch = "sphere" }
thrml-observers = { git = "https://github.com/SashimiSaketoro/thrml-rs", branch = "sphere" }
# Experimental: hyperspherical navigation (sphere branch only)
thrml-sphere = { git = "https://github.com/SashimiSaketoro/thrml-rs", branch = "sphere" }
```

## Feature Flags

| Feature | Backend | Use Case |
|---|---|---|
| `gpu` (default) | WGPU | Metal (macOS), Vulkan (Linux), DX12 (Windows) |
| `cuda` | CUDA + WGPU | NVIDIA GPUs with native CUDA |
| `cpu` | ndarray + WGPU | Development/testing without GPU, or CPU fallback |

## Building

```sh
# Default: WGPU backend (Metal on macOS, Vulkan on Linux)
cargo build --release
# Enable CUDA support alongside WGPU
cargo build --release --features cuda
# Enable CPU backend (useful for testing or systems without GPU)
cargo build --release --features cpu
```

## Requirements

- Rust 1.89+ (stable) - required by Burn 0.19
- WGPU backend: GPU with Metal (macOS) or Vulkan (Linux/Windows) support
- CUDA backend: NVIDIA GPU with CUDA toolkit installed

## Hardware-Aware Runtime

thrml-rs is designed to run from laptops to DGX-class servers. The core crates share a
common runtime abstraction:

- `ComputeBackend`: Selects CPU / GPU / hybrid execution
- `PrecisionMode`: Chooses between `GpuFast`, `CpuPrecise`, or `Adaptive` routing
- `OpType`: Tags operations (Ising sampling, distance, navigator steps) for precision-aware routing

| Tier | Examples | FP64 | Default Profile |
|---|---|---|---|
| Apple Silicon | M1–M4 Pro/Max/Ultra | CPU only | CpuFp64Strict - GPU for throughput, CPU for precision |
| Consumer GPU | RTX 3080–5090, RDNA3/4 | Weak | GpuMixed - GPU FP32, CPU f64 for corrections |
| HPC GPU | H100, H200, B200, DGX Spark | Native | GpuHpcFp64 - Full f64 on GPU |
| CPU Only | Servers without GPU | Native | CpuFp64Strict - All operations on CPU |

```rust
use thrml_core::compute::{ComputeBackend, RuntimePolicy, OpType};

// Auto-detect hardware and create the appropriate backend.
let policy = RuntimePolicy::detect();
let backend = ComputeBackend::from_policy(&policy);
println!("Detected: {:?}", policy.tier);   // e.g., AppleSilicon
println!("Profile: {:?}", policy.profile); // e.g., CpuFp64Strict

// Precision-aware routing
if backend.use_cpu(OpType::IsingSampling, None) {
    // High-precision CPU f64 path (Apple Silicon, consumer GPU)
} else {
    // Fast GPU path (HPC GPUs with native f64)
}
```

`ComputeBackend::default()` auto-detects your hardware. For explicit control:

```rust
// Force specific profiles
let apple = RuntimePolicy::apple_silicon();
let hpc = RuntimePolicy::nvidia_hopper(); // H100/H200
let spark = RuntimePolicy::nvidia_spark(); // DGX Spark / GB10
```

## Sphere Branch (Experimental)

| Crate | Description |
|---|---|
| `thrml-sphere` | Hyperspherical navigation, ROOTS indexing, multi-cone EBM |

What it does:
- SphereEBM: Langevin dynamics to place embeddings on a hypersphere (a library-agnostic sketch follows this list)
- NavigatorEBM: EBM with learnable weights for similarity, radial alignment, path length
- MultiConeNavigator: Spawns cones from ROOTS peaks, allocates budget per cone
- RootsIndex: Compresses inner shells ~3000:1 for coarse routing
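
For intuition, this kind of sphere placement boils down to projected Langevin dynamics: step along the negative energy gradient, inject Gaussian noise, then renormalize back onto the unit sphere. A minimal, library-agnostic sketch using the `rand` and `rand_distr` crates (illustrative only, not `SphereEBM`'s API):

```rust
use rand::Rng;
use rand_distr::StandardNormal;

/// One projected Langevin step on the unit sphere:
/// x <- normalize(x - eta * grad_E(x) + sqrt(2 * eta) * noise).
fn langevin_step(x: &mut [f64], grad: &[f64], eta: f64, rng: &mut impl Rng) {
    for i in 0..x.len() {
        let noise: f64 = rng.sample(StandardNormal);
        x[i] += -eta * grad[i] + (2.0 * eta).sqrt() * noise;
    }
    // Renormalization is the simplest retraction back onto the sphere;
    // the actual integrator in thrml-sphere may differ.
    let norm = x.iter().map(|v| v * v).sum::<f64>().sqrt();
    x.iter_mut().for_each(|v| *v /= norm);
}

fn main() {
    let mut x = vec![1.0, 0.0, 0.0];
    let grad = vec![0.0, -0.5, 0.2]; // toy energy gradient at x
    let mut rng = rand::thread_rng();
    for _ in 0..100 {
        langevin_step(&mut x, &grad, 0.01, &mut rng);
    }
    println!("{x:?}");
}
```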
Training: contrastive divergence with hard negatives, PCD, curriculum scheduling (the core CD update is sketched below).
BLT Integration: Works with blt-burn — BLT provides embeddings, thrml-sphere provides the navigator.
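
The contrastive-divergence update at the heart of that training is simple: move each coupling toward the data correlation and away from the model correlation, i.e. dw_ij = lr * (<s_i s_j>_data - <s_i s_j>_model). A library-agnostic sketch of one such step for chain couplings (not thrml's training API; hard negatives, PCD chains, and curricula layer on top of this):

```rust
/// One CD weight update for a spin chain:
/// w[i] += lr * (<s_i s_{i+1}>_data - <s_i s_{i+1}>_model).
/// `data` and `model` are batches of +/-1 configurations; in CD-1 the
/// model batch comes from one Gibbs sweep started at the data.
fn cd_step(w: &mut [f32], data: &[Vec<i8>], model: &[Vec<i8>], lr: f32) {
    for i in 0..w.len() {
        let corr = |batch: &[Vec<i8>]| {
            batch.iter().map(|s| (s[i] * s[i + 1]) as f32).sum::<f32>() / batch.len() as f32
        };
        w[i] += lr * (corr(data) - corr(model));
    }
}

fn main() {
    let mut w = vec![0.0f32; 4];
    let data = vec![vec![1i8, 1, 1, -1, -1]; 8];  // toy "dataset"
    let model = vec![vec![1i8, -1, 1, -1, 1]; 8]; // toy negative samples
    cd_step(&mut w, &data, &model, 0.1);
    println!("updated weights: {w:?}");
}
```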

```rust
use thrml_sphere::{RuntimeConfig, BudgetConfig, SphereConfig, RootsConfig};
let runtime = RuntimeConfig::auto();
println!("Hardware: {:?}, Budget: {:.1} GB", runtime.policy.tier, runtime.budget_gb());
let roots_cfg = RootsConfig::default().with_partitions(64);
let budget = runtime.budget.with_max_cones(8);
```

For the full API, see docs/api/sphere.md.

## Examples

See the examples/ directory:

```sh
# Simple Ising chain demonstration
cargo run --release --example ising_chain
# Spin models with performance benchmarking
cargo run --release --example spin_models
# Categorical variable sampling with visualization
cargo run --release --example categorical_sampling
# Full API walkthrough tutorial
cargo run --release --example full_api_walkthrough
# Gaussian PGM sampling (continuous nodes)
cargo run --release --example gaussian_pgm
# Mixed Gaussian-Bernoulli model
cargo run --release --example gaussian_bernoulli_ebm
# Full MNIST training with contrastive divergence
cargo run --release --example train_mnist
```

## Documentation

- API Documentation - Published crates (main branch)
- Core API - Nodes, blocks, compute backends, metrics, text utilities
- Samplers API - Gibbs, Bernoulli, Gaussian, max-cut algorithms
- Models API - Ising, discrete EBMs, continuous factors
- Sphere API - Hyperspherical navigation (sphere branch)
- BLT Pipeline - BLT → thrml-sphere integration
- Architecture Guide
- Examples README

## Performance

THRML-RS leverages the Burn deep learning framework for GPU acceleration:

| Backend | Platform | GPU Support |
|---|---|---|
| WGPU-Metal | macOS | Apple Silicon, AMD, Intel |
| WGPU-Vulkan | Linux/Windows | NVIDIA, AMD, Intel |
| CUDA | Linux/Windows | NVIDIA (native) |
Key optimizations:
- Native Metal acceleration on Apple Silicon
- CUDA for maximum performance on NVIDIA GPUs
- Efficient tensor operations with automatic batching
- Fused GPU kernels for sampling operations

## Contributing

Contributions are welcome! Please see CONTRIBUTING.md for guidelines.

## License

Licensed under either of:

- Apache License, Version 2.0 (LICENSE-APACHE)
- MIT license (LICENSE-MIT)
at your option.

## Acknowledgments

This project is inspired by Extropic's THRML library. THRML-RS is an independent Rust implementation providing the same functionality with native GPU acceleration.