Migrate burn-models to burn main branch with burn-store #85
Simplifies tensor loading and model record construction by returning structs directly instead of assigning to intermediate variables. Updates dependency to use the latest 'burn' from the main branch. Cleans up redundant code and improves consistency in function returns.
Switches burn and burn-import dependencies to use the main branch from the GitHub repository. Refactors the Sampler enum to box TopP and adds a constructor for TopP sampling. Updates transformer forward method to accept mutable slices for cache. Cleans up unused code and improves code clarity in llama.rs and chat.rs.
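A hedged sketch of the boxed `Sampler` change described in this commit; `TopP`'s fields and the constructor name below are illustrative assumptions, not the exact llama-burn code.

```rust
/// Stand-in for llama-burn's top-p sampler state (fields are illustrative).
pub struct TopP {
    p: f64,
    seed: u64,
}

pub enum Sampler {
    // Boxed so the enum stays small even though TopP carries state,
    // which also avoids clippy's large_enum_variant lint.
    TopP(Box<TopP>),
    Argmax,
}

impl Sampler {
    /// Convenience constructor mirroring the one added for top-p sampling.
    pub fn top_p(p: f64, seed: u64) -> Self {
        Self::TopP(Box::new(TopP { p, seed }))
    }
}
```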
Switched burn and burn-import dependencies to use the latest code from the tracel-ai/burn GitHub repository on the main branch. Updated the inference example to clarify type conversions for index and score extraction. Added example metadata for the inference example in Cargo.toml.
Switched burn and burn-import dependencies to use the latest main branch from GitHub. Updated multi_hot to use IndexingUpdateOp::Add and adjusted training builder to pass LearningStrategy as a parameter. Added clippy allow for large enum variant in ResidualBlock.
Switches Burn, burn-store, and burn-import dependencies to the latest main branch from GitHub. Updates build.rs to remove deprecated RecordType and half_precision usage, and changes the generated weights file extension to .bpk for compatibility with the latest Burn API.
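A minimal build.rs sketch of that change, assuming burn-import's `ModelGen` API and squeezenet's ONNX input path; the point is simply that the deprecated `record_type`/`half_precision` calls are dropped.

```rust
// build.rs sketch: the deprecated .record_type(...) and .half_precision(...)
// calls are removed; generated weights now use the .bpk extension.
use burn_import::onnx::ModelGen;

fn main() {
    ModelGen::new()
        .input("src/model/squeezenet1.onnx") // input path is an assumption
        .out_dir("model/")
        .run_from_script();
}
```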
Switched burn and burn-import dependencies to use the latest main branch from GitHub for both regular and dev dependencies. Added an example for inference requiring the pretrained feature. Introduced minor code improvements: allowed large enum variant in blocks.rs, imported Vec in boxes.rs, removed an unnecessary enumerate in boxes.rs, allowed dead_code for Weights struct, and adjusted imports in yolox.rs for feature-gated code.
Replaces usage of burn-import with burn-store for loading and saving model weights across llama-burn, mobilenetv2-burn, resnet-burn, and yolox-burn. Updates Cargo.toml dependencies and feature flags, refactors model loading logic to use burn-store's PytorchStore and BurnpackStore, and adapts key remapping APIs. This unifies model serialization and deserialization, and prepares for future burn ecosystem updates.
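A hedged sketch of the unified loading pattern this describes; the type names come from the PR, but the exact methods (`from_file`, `with_key_remapping`, `apply_from`) and trait bounds are assumptions, not a verified released burn-store API.

```rust
use burn::module::Module;
use burn::prelude::Backend;
use burn_store::{ModuleSnapshot, PytorchStore};

// Hypothetical helper: stream PyTorch weights into a Burn module through
// burn-store, remapping checkpoint keys to Burn module paths on the way in.
fn load_pytorch_weights<B: Backend, M: Module<B> + ModuleSnapshot>(model: &mut M) {
    let mut store = PytorchStore::from_file("weights/model.pth")
        // Illustrative remap rule: checkpoint prefix -> Burn module path.
        .with_key_remapping(r"^features\.", "backbone.features.");
    model
        .apply_from(&mut store)
        .expect("failed to load weights");
}
```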
Adds support for loading model weights from SafeTensors files in addition to PyTorch formats, including key remapping and weight permutation for TinyLlama. Introduces a new 'test_tiny' example for testing TinyLlama with both 'tch-cpu' and 'ndarray' backends. Updates dependencies and features in Cargo.toml, and simplifies the decode method in SentiencePieceTokenizer.
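Assuming the usual cargo example layout, the new example presumably runs per backend along the lines of `cargo run --release --example test_tiny --features tch-cpu` (or `--features ndarray`); the exact flags and any model-path arguments are assumptions, not taken from the PR.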
Centralizes HuggingFace-to-Burn tensor key remapping patterns into a single location and applies them via KeyRemapper for both SafeTensors and PyTorch checkpoints. This reduces code duplication and simplifies the addition or modification of key mapping rules.
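An abridged sketch of such a centralized table; the pairs follow the (regex, replacement) shape implied by KeyRemapper, and the specific patterns are illustrative rather than the PR's exhaustive list.

```rust
/// Shared HuggingFace -> Burn remap rules, applied via KeyRemapper for both
/// SafeTensors and PyTorch checkpoints. Patterns are illustrative/abridged.
pub const HF_TO_BURN_PATTERNS: &[(&str, &str)] = &[
    (r"^model\.layers\.([0-9]+)\.self_attn\.q_proj", "layers.$1.attention.wq"),
    (r"^model\.layers\.([0-9]+)\.self_attn\.k_proj", "layers.$1.attention.wk"),
    (r"^model\.layers\.([0-9]+)\.self_attn\.v_proj", "layers.$1.attention.wv"),
    (r"^model\.layers\.([0-9]+)\.self_attn\.o_proj", "layers.$1.attention.wo"),
    // ...feed-forward, norm, and embedding rules elided...
];
```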
Tried to execute the very first example in the repo and it failed to compile:
```
cargo run --example infer-embedding --release --features wgpu,fusion,safetensors

   Compiling burn-cubecl-fusion v0.20.0-pre.5 (https://github.com/tracel-ai/burn?branch=main#6c9ba5cf)
error[E0412]: cannot find type `Simple` in this scope
   --> /home/laggui/.cargo/git/checkouts/burn-6c277d792b0d5d7a/6c9ba5c/crates/burn-cubecl-fusion/src/optim/matmul/optimization.rs:172:36
    |
172 |     if arg.execute_fused::<BT, Simple>(context).is_err() {
    |                                ^^^^^^ not found in this scope
    |
help: there is an enum variant `crate::optim::matmul::FusedMatmulSelector::Simple`; try using the variant's enum
    |
172 -     if arg.execute_fused::<BT, Simple>(context).is_err() {
172 +     if arg.execute_fused::<BT, crate::optim::matmul::FusedMatmulSelector>(context).is_err() {
    |
help: you might be missing a type parameter
    |
133 | impl<R: Runtime, Simple> MatmulOptimization<R> {
```
That's an issue with burn, but just reinforces the point that we should only have models point to a released version.
/edit: fixed this issue in tracel-ai/burn#4193
llama-burn/src/llama.rs (outdated)
```rust
.map(|snapshot| {
    let path = snapshot.full_path();

    // Check if this is a wq or wk weight that needs permutation
    if path.contains(".wq.weight") {
        // Permute wq weight
        let data = snapshot.to_data().expect("Failed to get tensor data");
        let shape = data.shape.clone();
        let dim1 = shape[0];
        let dim2 = shape[1];

        // Create tensor, permute, get data back
        let tensor: Tensor<B, 2> = Tensor::from_data(data, device);
        let permuted = tensor
            .reshape([dim1, n_heads, 2, dim2 / n_heads / 2])
            .swap_dims(2, 3)
            .reshape([dim1, dim2]);
        let permuted_data = permuted.to_data();

        TensorSnapshot::from_data(
            permuted_data,
            snapshot.path_stack.clone().unwrap_or_default(),
            snapshot.container_stack.clone().unwrap_or_default(),
            snapshot.tensor_id.unwrap_or_else(ParamId::new),
        )
    } else if path.contains(".wk.weight") {
        // Permute wk weight
        let data = snapshot.to_data().expect("Failed to get tensor data");
        let shape = data.shape.clone();
        let dim1 = shape[0];
        let wk_dim = d_model * n_kv_heads / n_heads;

        // Create tensor, permute, get data back
        let tensor: Tensor<B, 2> = Tensor::from_data(data, device);
        let permuted = tensor
            .reshape([dim1, n_kv_heads, 2, wk_dim / n_kv_heads / 2])
            .swap_dims(2, 3)
            .reshape([dim1, wk_dim]);
        let permuted_data = permuted.to_data();

        TensorSnapshot::from_data(
            permuted_data,
            snapshot.path_stack.clone().unwrap_or_default(),
            snapshot.container_stack.clone().unwrap_or_default(),
            snapshot.tensor_id.unwrap_or_else(ParamId::new),
        )
    } else {
        // Keep other tensors unchanged
        snapshot
    }
})
```
That feels a lot more verbose and possibly error prone (especially since we have to check for the "path" names) compared to the previous usage 😅
That's a good signal to improve the UX though.
Agreed! I've refactored to extract helper functions (permute_rotary_weights, permute_attention_weight) which makes the call site cleaner:
```rust
permute_rotary_weights(&mut llama.model, n_heads, n_kv_heads, d_model, device);
```
The verbosity stems from burn-store's ModuleAdapter not having access to:
- Model config (n_heads, d_model)
- Device (for tensor operations)
A potential future improvement could be a "post-load transformation hook" in burn-store that receives both the snapshot and model context. Would you like me to open an issue for this?
The call site is cleaner since it has been refactored into a function, but it still requires a lot more manipulation than previously. So I think there's still room for improvement in the API!
Extracted the TinyLlama rotary weight permutation code into dedicated helper functions for clarity and maintainability. The new `permute_rotary_weights` function and its helpers handle the permutation of wq/wk tensors, replacing the previous inline logic in `LlamaConfig::load_record`.
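A hedged reconstruction of one such helper, distilled from the inline diff above; the signature and generic handling are assumptions rather than the PR's exact code.

```rust
use burn::prelude::*;

// Undo HuggingFace's rotary-embedding interleaving for an attention
// projection weight: split each head's output columns into two halves,
// swap them, and flatten back (mirrors the inline logic shown earlier).
fn permute_attention_weight<B: Backend>(
    tensor: Tensor<B, 2>,
    heads: usize,
    out_dim: usize,
) -> Tensor<B, 2> {
    let [dim1, _] = tensor.dims();
    tensor
        .reshape([dim1, heads, 2, out_dim / heads / 2])
        .swap_dims(2, 3)
        .reshape([dim1, out_dim])
}
```

For wq the caller would pass `n_heads` and the full projection width, while wk uses `n_kv_heads` and `d_model * n_kv_heads / n_heads`, matching the two branches of the original inline code.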
Replaced git dependencies on burn and burn-store with version 0.20.0-pre.6 across all projects. Updated resnet-burn finetune example to use the new SupervisedTraining and Learner APIs, reflecting recent changes in the burn training interface.
laggui left a comment:
Reverted the llama load changes since the pre-trained weights are still uploaded online in the mpk format.
All models are working with the current PR.
I don't remember what your benchmarks initially showed on loading speed, but the PyTorchFileRecorder (while definitely not fast) appears to load faster than the burn-store equivalent. In my tests, it consistently loaded the Llama 3.2 1B .pth model in 7s vs 11s for the store.
Summary
Comprehensive migration of burn-models from burn 0.19.1 to burn main branch, replacing burn-import with the new burn-store crate for model weight loading and storage.
Changes by Model
- bert-burn
- llama-burn
- mobilenetv2-burn
- resnet-burn
- squeezenet-burn
- yolox-burn
Key Migration Patterns
Test Plan
NOTES