
Conversation

@antimora
Collaborator

Adds a minilm-burn crate implementing the all-MiniLM-L12-v2 sentence-transformer model.

Features

  • Load pretrained weights from HuggingFace with a simple API: MiniLmModel::pretrained(&device)
  • Mean pooling and L2 normalization for sentence embeddings
  • Multi-backend support: ndarray, wgpu, tch-cpu, tch-gpu, cuda
  • Config loaded from HuggingFace's config.json via serde
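For reference, the config.json loading mentioned above can be pictured as deserializing HuggingFace's BERT-style fields into a plain struct. The struct and loader below are an illustrative sketch, not the crate's actual types; the field names follow the standard HF BERT schema:

```rust
use serde::Deserialize;
use std::{fs, path::Path};

// Hypothetical mirror of the HuggingFace BERT-style config.json fields a
// MiniLM model needs; unknown fields are ignored by serde's default behavior.
#[derive(Debug, Deserialize)]
struct HfBertConfig {
    vocab_size: usize,
    hidden_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    intermediate_size: usize,
    max_position_embeddings: usize,
}

fn load_config(path: &Path) -> Result<HfBertConfig, Box<dyn std::error::Error>> {
    let json = fs::read_to_string(path)?;
    Ok(serde_json::from_str(&json)?)
}
```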

Usage

// Load pretrained weights and tokenizer from the HF Hub (cached locally).
let (model, tokenizer) = MiniLmModel::<B>::pretrained(&device)?;
// Run the encoder; `None` omits the optional token-type ids.
let output = model.forward(input_ids, attention_mask.clone(), None);
// Average token states under the attention mask, then L2-normalize.
let embeddings = mean_pooling(output.hidden_states, attention_mask);
let embeddings = normalize_l2(embeddings);
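For intuition, mean_pooling and normalize_l2 compute the following, shown here on plain slices rather than Burn tensors. This is an illustrative sketch of the math, not the crate's implementation:

```rust
// Mask-weighted mean over token vectors: padded positions (mask == 0)
// contribute nothing to the sum or the count.
fn mean_pool(token_states: &[Vec<f32>], mask: &[u8]) -> Vec<f32> {
    let hidden = token_states[0].len();
    let mut sum = vec![0.0f32; hidden];
    let mut count = 0.0f32;
    for (state, &m) in token_states.iter().zip(mask) {
        if m != 0 {
            for (s, v) in sum.iter_mut().zip(state) {
                *s += *v;
            }
            count += 1.0;
        }
    }
    sum.iter().map(|s| s / count.max(1e-9)).collect()
}

// Scale the vector to unit length so cosine similarity becomes a dot product.
fn normalize_l2(v: &mut [f32]) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
    for x in v.iter_mut() {
        *x /= norm;
    }
}
```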

Benchmarks (Apple M3 Max)

| Benchmark          | ndarray | wgpu  | tch-cpu |
| ------------------ | ------- | ----- | ------- |
| forward (batch=1)  | 102 ms  | 35 ms | 26 ms   |
| forward (batch=16) | 1.54 s  | 73 ms | 130 ms  |

Testing

  • Unit tests: cargo test --features ndarray
  • Integration tests verify outputs match Python sentence-transformers within 1e-4 tolerance
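A sketch of the kind of per-component tolerance check those integration tests describe (the helper name is illustrative, not the tests' actual code):

```rust
// Compare each embedding component against the Python reference,
// failing if any component drifts beyond the tolerance (e.g. 1e-4).
fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
    assert_eq!(actual.len(), expected.len(), "length mismatch");
    for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
        assert!(
            (a - e).abs() <= tol,
            "component {i}: {a} vs {e} exceeds tol {tol}"
        );
    }
}
```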

Implements the all-MiniLM-L12-v2 model using Burn's built-in TransformerEncoder and burn-store for weight loading from safetensors.

- Load config from HuggingFace's config.json via serde
- Key remapping from HuggingFace BERT to Burn TransformerEncoder
- Mean pooling for sentence embeddings
- Example with HuggingFace download and cosine similarity
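For reference, since the embeddings are L2-normalized, the example's cosine similarity reduces to a dot product. A minimal sketch, not the example's exact code:

```rust
// Cosine similarity of two unit-length embeddings is just their dot
// product; assumes both slices have the same dimensionality.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
```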

Reformatted code in loader.rs, model.rs, and pooling.rs for improved readability and consistency. Adjusted import order and indentation, and expanded some array initializations for clarity in tests. No functional changes were made.

Results were measured on an Apple M3 Max, showing a performance comparison across all supported backends.

Copilot AI left a comment


Pull request overview

Introduces a new minilm-burn crate implementing the all-MiniLM-L12-v2 sentence-transformer model on top of Burn, with support for multiple backends, pretrained weight loading from Hugging Face, and documentation/examples/benchmarks.

Changes:

  • Add MiniLM-specific embedding, encoder, pooling, and normalization modules plus a MiniLmModel configuration and forward pass.
  • Implement HF Hub-based weight and tokenizer loading, along with a pretrained convenience API, examples, and benchmarks across backends.
  • Add integration tests against Python sentence-transformers outputs and update repo-level documentation/README to list the new model.

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| minilm-burn/src/embedding.rs | Defines MiniLmEmbeddingsConfig and MiniLmEmbeddings (word/position/token-type embeddings + layer norm + dropout) matching the MiniLM/BERT-style embedding stack. |
| minilm-burn/src/model.rs | Adds MiniLmConfig, MiniLmModel, and MiniLmOutput, wiring Burn's TransformerEncoder to MiniLM's config and attention mask semantics. |
| minilm-burn/src/pooling.rs | Implements mean_pooling and normalize_l2 utilities plus a unit test for mean pooling on the ndarray backend. |
| minilm-burn/src/loader.rs | Introduces LoadError, HF safetensor key remapping and loading, HF Hub download utilities, config loading, and MiniLmModel::pretrained. |
| minilm-burn/src/lib.rs | Exposes the MiniLM public API and adds crate-level documentation and a usage example. |
| minilm-burn/tests/integration_test.rs | Adds ndarray-based integration tests that compare MiniLM Rust embeddings and cosine similarities against Python sentence-transformers references. |
| minilm-burn/scripts/generate_reference.py | Script to generate reference embeddings and cosine similarities from Python sentence-transformers for use in integration tests. |
| minilm-burn/scripts/debug_embeddings.py | Small helper script to inspect raw MiniLM embeddings and norms in Python for debugging. |
| minilm-burn/examples/inference.rs | Demonstrates end-to-end inference with the pretrained MiniLM model, tokenization, pooling, and cosine similarity computation on the ndarray backend. |
| minilm-burn/benches/inference.rs | Adds Criterion benchmarks for forward passes, batching, full pipeline, and pooling/normalization across multiple backends. |
| minilm-burn/README.md | Documents the new crate's usage, features, testing strategy, and benchmark results. |
| minilm-burn/Cargo.toml | Declares the new crate, its features (including multi-backend and pretrained support), and dependencies (Burn, burn-store, tokenizers, hf-hub, tokio, etc.). |
| README.md | Updates the root repository overview and tables to include the MiniLM model and its subcrate, and switches to reference-style links. |


- Fix doc example to use MiniLmModel::pretrained (not MiniLmConfig)
- Update HfModelFiles doc to reflect struct with 3 fields
- Fix generate_reference.py to use normalize_embeddings=True
- Use dirs::cache_dir() for platform-appropriate default location
- Allow custom cache path via pretrained(device, Some(path))
- Downloads to ~/.cache/burn-models/ (Linux) or ~/Library/Caches/burn-models/ (macOS)
- Remove hardcoded hidden_size (384), derive from tensor dims
- Add normalize_l2 to example (matches sentence-transformers default)
- Remove debug_embeddings.py script

PyTorchToBurnAdapter handles the weight→gamma and bias→beta renames automatically.
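To picture the remapping step, something along these lines translates HuggingFace BERT parameter names to Burn TransformerEncoder names. The substitutions shown are assumptions about the naming on both sides, not the crate's actual table; the gamma/beta renames are left to PyTorchToBurnAdapter as noted above:

```rust
// Hypothetical HF-BERT → Burn TransformerEncoder key remapping; the
// concrete source and target names here are illustrative assumptions.
fn remap_key(hf_key: &str) -> String {
    hf_key
        .replace("bert.encoder.layer.", "encoder.layers.")
        .replace("attention.self.query", "mha.query")
        .replace("attention.self.key", "mha.key")
        .replace("attention.self.value", "mha.value")
        .replace("attention.output.dense", "mha.output")
}
```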
- Add MiniLmVariant enum (L6, L12) for model selection
- L6: 6 layers, faster inference
- L12: 12 layers, better quality (default)
- Update pretrained() to accept variant parameter

L6 is ~2x faster than L12 across all backends:
- ndarray: 53ms vs 105ms
- wgpu: 18ms vs 35ms
- tch-cpu: 14ms vs 27ms
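A sketch of selecting a variant; the parameter order of pretrained() is assumed from the commit notes above, with the cache-path argument being the one described earlier:

```rust
use burn::tensor::backend::Backend;
// Items from this PR's crate; the exact signature is an assumption.
use minilm_burn::{MiniLmModel, MiniLmVariant};

// Load the faster 6-layer variant; `None` keeps the default cache dir.
fn load_fast<B: Backend>(device: &B::Device) -> Result<(), Box<dyn std::error::Error>> {
    let (_model, _tokenizer) =
        MiniLmModel::<B>::pretrained(device, MiniLmVariant::L6, None)?;
    Ok(())
}
```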
Replaces use of `equal_elem(0)` with comparison to a zeros tensor for creating the padding mask. This ensures compatibility with tensor operations and device placement.
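A minimal sketch of that mask construction, assuming pad token id 0 and using Burn's zeros/equal tensor API:

```rust
use burn::tensor::{backend::Backend, Bool, Int, Tensor};

// Build a boolean padding mask by comparing token ids against a zeros
// tensor created on the same device (true where the id is the pad id 0).
fn padding_mask<B: Backend>(input_ids: Tensor<B, 2, Int>) -> Tensor<B, 2, Bool> {
    let zeros = Tensor::zeros(input_ids.dims(), &input_ids.device());
    input_ids.equal(zeros)
}
```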
Refactored lines where MiniLmModel is loaded to improve code readability by reducing line length and aligning with Rust formatting conventions. No functional changes were made.
