Add MiniLM-L12-v2 sentence transformer #89
base: main
Conversation
Implements the all-MiniLM-L12-v2 model using Burn's built-in TransformerEncoder and burn-store for weight loading from safetensors.
- Load config from HuggingFace's config.json via serde
- Key remapping from HuggingFace BERT to Burn TransformerEncoder
- Mean pooling for sentence embeddings (see the sketch below)
- Example with HuggingFace download and cosine similarity
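As a rough illustration of the pooling step, here is a minimal mask-aware mean-pooling sketch on Burn tensors; the actual `mean_pooling` in `minilm-burn/src/pooling.rs` may differ in signature and detail:

```rust
use burn::prelude::*;

/// Mask-aware mean pooling: average hidden states over real (non-padding)
/// tokens only. `hidden` is [batch, seq, hidden]; `mask` is [batch, seq]
/// with 1.0 for real tokens and 0.0 for padding.
fn mean_pooling<B: Backend>(hidden: Tensor<B, 3>, mask: Tensor<B, 2>) -> Tensor<B, 2> {
    let [batch, seq, hidden_size] = hidden.dims();
    // Broadcast the attention mask across the hidden dimension.
    let mask = mask.reshape([batch, seq, 1]).repeat_dim(2, hidden_size);
    // Sum masked hidden states, then divide by the per-sentence token count.
    let summed = (hidden * mask.clone()).sum_dim(1).reshape([batch, hidden_size]);
    let counts = mask.sum_dim(1).reshape([batch, hidden_size]).clamp_min(1e-9);
    summed / counts
}
```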
Reformatted code in loader.rs, model.rs, and pooling.rs for improved readability and consistency. Adjusted import order and indentation, and expanded some array initializations for clarity in tests. No functional changes were made.
Results measured on Apple M3 Max showing performance comparison across all supported backends.
Pull request overview
Introduces a new minilm-burn crate implementing the all-MiniLM-L12-v2 sentence-transformer model on top of Burn, with support for multiple backends, pretrained weight loading from Hugging Face, and documentation/examples/benchmarks.
Changes:
- Add MiniLM-specific embedding, encoder, pooling, and normalization modules plus a `MiniLmModel` configuration and forward pass.
- Implement HF Hub-based weight and tokenizer loading, along with a `pretrained` convenience API, examples, and benchmarks across backends.
- Add integration tests against Python `sentence-transformers` outputs and update repo-level documentation/README to list the new model.
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `minilm-burn/src/embedding.rs` | Defines `MiniLmEmbeddingsConfig` and `MiniLmEmbeddings` (word/position/token-type embeddings + layer norm + dropout) matching the MiniLM/BERT-style embedding stack. |
| `minilm-burn/src/model.rs` | Adds `MiniLmConfig`, `MiniLmModel`, and `MiniLmOutput`, wiring Burn's `TransformerEncoder` to MiniLM's config and attention-mask semantics. |
| `minilm-burn/src/pooling.rs` | Implements `mean_pooling` and `normalize_l2` utilities plus a unit test for mean pooling on the ndarray backend. |
| `minilm-burn/src/loader.rs` | Introduces `LoadError`, HF safetensor key remapping and loading, HF Hub download utilities, config loading, and `MiniLmModel::pretrained`. |
| `minilm-burn/src/lib.rs` | Exposes the MiniLM public API and adds crate-level documentation and a usage example. |
| `minilm-burn/tests/integration_test.rs` | Adds ndarray-based integration tests that compare MiniLM Rust embeddings and cosine similarities against Python `sentence-transformers` references. |
| `minilm-burn/scripts/generate_reference.py` | Script to generate reference embeddings and cosine similarities from Python `sentence-transformers` for use in integration tests. |
| `minilm-burn/scripts/debug_embeddings.py` | Small helper script to inspect raw MiniLM embeddings and norms in Python for debugging. |
| `minilm-burn/examples/inference.rs` | Demonstrates end-to-end inference with the pretrained MiniLM model, tokenization, pooling, and cosine similarity computation on the ndarray backend. |
| `minilm-burn/benches/inference.rs` | Adds Criterion benchmarks for forward passes, batching, full pipeline, and pooling/normalization across multiple backends. |
| `minilm-burn/README.md` | Documents the new crate's usage, features, testing strategy, and benchmark results. |
| `minilm-burn/Cargo.toml` | Declares the new crate, its features (including multi-backend and pretrained support), and dependencies (Burn, burn-store, tokenizers, hf-hub, tokio, etc.). |
| `README.md` | Updates the root repository overview and tables to include the MiniLM model and its subcrate, and switches to reference-style links. |
- Fix doc example to use `MiniLmModel::pretrained` (not `MiniLmConfig`)
- Update `HfModelFiles` doc to reflect struct with 3 fields
- Fix generate_reference.py to use `normalize_embeddings=True`
- Use `dirs::cache_dir()` for platform-appropriate default location (see the sketch after this list)
- Allow custom cache path via `pretrained(device, Some(path))`
- Downloads to `~/.cache/burn-models/` (Linux) or `~/Library/Caches/burn-models/` (macOS)
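A plausible shape for that cache resolution, using the `dirs` crate as the commit describes; function names are illustrative, not the PR's exact code:

```rust
use std::path::PathBuf;

/// Default cache root: dirs::cache_dir() resolves to ~/.cache on Linux and
/// ~/Library/Caches on macOS, with a burn-models subdirectory appended.
fn default_cache_dir() -> PathBuf {
    dirs::cache_dir()
        .unwrap_or_else(|| PathBuf::from("."))
        .join("burn-models")
}

/// A caller-supplied path overrides the default, mirroring
/// `pretrained(device, Some(path))`.
fn resolve_cache_dir(custom: Option<PathBuf>) -> PathBuf {
    custom.unwrap_or_else(default_cache_dir)
}
```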
- Remove hardcoded hidden_size (384); derive it from tensor dims (sketched below)
- Add `normalize_l2` to example (matches sentence-transformers default)
- Remove debug_embeddings.py script
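Deriving the hidden size could look like the following; a sketch only, with illustrative names:

```rust
use burn::prelude::*;

/// Derive hidden_size from the word-embedding table instead of hardcoding 384.
/// The table is laid out as [vocab_size, hidden_size].
fn hidden_size_from_weights<B: Backend>(word_embedding: &Tensor<B, 2>) -> usize {
    let [_vocab_size, hidden_size] = word_embedding.dims();
    hidden_size
}
```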
`PyTorchToBurnAdapter` handles the `weight`→`gamma` and `bias`→`beta` renames automatically.
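Illustrative only, here is the flavor of HF-BERT → Burn `TransformerEncoder` key remapping this PR performs; the exact key names in loader.rs may differ, and the `weight`→`gamma` / `bias`→`beta` renames are left to burn-store's `PyTorchToBurnAdapter` as noted above:

```rust
/// Map a Hugging Face BERT parameter key onto Burn's TransformerEncoder
/// layout (mha, pwff, norm_1, norm_2). Replacement order matters: the
/// attention-output keys must be handled before the generic output keys.
fn remap_hf_key(key: &str) -> String {
    key.replace("encoder.layer.", "encoder.layers.")
        .replace("attention.self.query", "mha.query")
        .replace("attention.self.key", "mha.key")
        .replace("attention.self.value", "mha.value")
        .replace("attention.output.dense", "mha.output")
        .replace("attention.output.LayerNorm", "norm_1")
        .replace("intermediate.dense", "pwff.linear_inner")
        .replace("output.dense", "pwff.linear_outer")
        .replace("output.LayerNorm", "norm_2")
}
```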
- Add `MiniLmVariant` enum (L6, L12) for model selection (sketched below)
- L6: 6 layers, faster inference
- L12: 12 layers, better quality (default)
- Update `pretrained()` to accept a variant parameter
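A hedged sketch of what that enum might look like; the actual definition and the `hf_repo` helper are assumptions, though the two Hugging Face repo ids are the real sentence-transformers models:

```rust
/// Model variant selection for MiniLM, as described in the commit above.
#[derive(Debug, Clone, Copy, Default)]
pub enum MiniLmVariant {
    /// all-MiniLM-L6-v2: 6 layers, faster inference.
    L6,
    /// all-MiniLM-L12-v2: 12 layers, better quality (the default).
    #[default]
    L12,
}

impl MiniLmVariant {
    pub fn num_layers(self) -> usize {
        match self {
            Self::L6 => 6,
            Self::L12 => 12,
        }
    }

    /// Hugging Face Hub repository to download weights from.
    pub fn hf_repo(self) -> &'static str {
        match self {
            Self::L6 => "sentence-transformers/all-MiniLM-L6-v2",
            Self::L12 => "sentence-transformers/all-MiniLM-L12-v2",
        }
    }
}
```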
L6 is ~2x faster than L12 across all backends:
- ndarray: 53ms vs 105ms
- wgpu: 18ms vs 35ms
- tch-cpu: 14ms vs 27ms
Replaces use of `equal_elem(0)` with comparison to a zeros tensor for creating the padding mask. This ensures compatibility with tensor operations and device placement.
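In Burn terms, the change amounts to something like the following (a sketch, assuming pad token id 0):

```rust
use burn::prelude::*;

/// Build the padding mask by comparing token ids to a zeros tensor created
/// on the same device, rather than calling `equal_elem(0)`.
fn padding_mask<B: Backend>(input_ids: Tensor<B, 2, Int>) -> Tensor<B, 2, Bool> {
    let zeros = input_ids.zeros_like();
    input_ids.equal(zeros) // true where the position is padding
}
```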
Refactored lines where MiniLmModel is loaded to improve code readability by reducing line length and aligning with Rust formatting conventions. No functional changes were made.
Adds `minilm-burn` crate implementing the all-MiniLM-L12-v2 sentence transformer model.

Features
- `MiniLmModel::pretrained(&device)` convenience loading
- Config loaded from HuggingFace's config.json via serde

Usage
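The snippet below is a hedged reconstruction from the APIs named in this PR (`pretrained`, `MiniLmVariant`, `mean_pooling`, `normalize_l2`); the argument order of `pretrained` and the `MiniLmOutput` field name are assumptions:

```rust
use burn::backend::NdArray;
use burn::prelude::*;
use minilm_burn::{mean_pooling, normalize_l2, MiniLmModel, MiniLmVariant};

fn main() {
    let device = Default::default();

    // First call downloads config/weights/tokenizer from the HF Hub and
    // caches them; argument order here is an assumption.
    let model = MiniLmModel::<NdArray>::pretrained(&device, MiniLmVariant::L12, None)
        .expect("failed to load pretrained MiniLM");

    // Dummy batch standing in for tokenizer output: shape [1, 4].
    let input_ids = Tensor::<NdArray, 2, Int>::from_ints([[101, 7592, 2088, 102]], &device);
    let attention_mask = Tensor::<NdArray, 2>::ones([1, 4], &device);

    // Encode, mean-pool over real tokens, then L2-normalize
    // (matching the sentence-transformers default).
    let output = model.forward(input_ids, attention_mask.clone());
    let embedding = normalize_l2(mean_pooling(output.hidden_states, attention_mask));
    println!("embedding dims: {:?}", embedding.dims()); // expect [1, 384]
}
```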
Benchmarks (Apple M3 Max)
Per-backend numbers are listed earlier in this thread (L6 vs L12) and in minilm-burn/README.md.

Testing
`cargo test --features ndarray`