
Conversation


@antimora antimora commented Dec 14, 2025

Summary

Comprehensive migration of burn-models from burn 0.19.1 to the burn main branch, replacing burn-import with the new burn-store crate for model weight loading and storage.

Changes by Model

bert-burn

  • Update burn dependency to git main branch
  • Migrate loader from burn-import to burn-store API

llama-burn

  • Update burn/burn-store dependencies to git main branch
  • Replace burn-import with burn-store (PytorchStore, SafetensorsStore, BurnpackStore)
  • Add SafeTensors format support with PyTorchToBurnAdapter for HuggingFace models (see the sketch after this list)
  • Fix TinyLlama weight permutation using burn-store's direct tensor access API
  • Refactor key mappings using KeyRemapper::from_patterns() for cleaner code
  • Add ndarray backend feature
  • Add test_tiny.rs example for testing TinyLlama import
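
As a rough sketch of the SafeTensors path above (the from_file, with_from_adapter, and remap builder calls and the pattern shown are assumptions based on the burn-store names used in this PR, not the exact llama-burn code):

```rust
use burn_store::{KeyRemapper, ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore};

// One illustrative HuggingFace -> Burn pattern; llama-burn centralizes the full list.
let remapper = KeyRemapper::from_patterns(vec![(
    r"^model\.layers\.(\d+)\.self_attn\.q_proj\.weight$",
    "layers.$1.attention.wq.weight",
)]);

let mut store = SafetensorsStore::from_file("model.safetensors")
    .with_from_adapter(PyTorchToBurnAdapter) // adapt PyTorch-style tensors to Burn conventions
    .remap(remapper); // assumed builder method for attaching the key remapper

let mut model = config.init::<B>(&device); // `config`, `B`, and `device` come from the caller
model.load_from(&mut store)?;
```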

mobilenetv2-burn

  • Update burn dependency to git main branch
  • Migrate from NamedMpkFileRecorder to BurnpackStore

resnet-burn

  • Update burn dependency to git main branch
  • Migrate from NamedMpkFileRecorder to BurnpackStore
  • Update inference example to use burn-store API

squeezenet-burn

  • Update burn dependency to git main branch
  • Update build script for burn-store compatibility (see the sketch below)
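
For illustration, a minimal sketch of the save side such a build script would use (BurnpackStore and save_into per the migration table below; from_file, the output path, and the model variable are assumptions):

```rust
use burn_store::{BurnpackStore, ModuleSnapshot};

// Write the converted weights in burnpack format; note the new .bpk extension.
let mut store = BurnpackStore::from_file(out_dir.join("squeezenet1.bpk"));
model
    .save_into(&mut store)
    .expect("Failed to save weights in burnpack format");
```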

yolox-burn

  • Update burn dependency to git main branch
  • Migrate from NamedMpkFileRecorder to BurnpackStore

Key Migration Patterns

  Old API                            New API
  --------------------------------   ---------------------------
  burn-import                        burn-store
  PyTorchFileRecorder                PytorchStore
  NamedMpkFileRecorder               BurnpackStore
  SafetensorsFileRecorder            SafetensorsStore (new)
  model.load_file(path, recorder)    model.load_from(&mut store)
  model.save_file(path, recorder)    model.save_into(&mut store)
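
As a hedged before/after sketch of the loading pattern (paths are placeholders, and the PytorchStore::from_file constructor plus the ModuleSnapshot trait import are assumptions based on burn-store's documented usage):

```rust
use burn_store::{ModuleSnapshot, PytorchStore};

// Old (burn-import): load a record through a recorder, then apply it.
// let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
//     .load("weights.pt".into(), &device)?;
// let model = config.init::<B>(&device).load_record(record);

// New (burn-store): initialize the model, then load weights in place.
let mut store = PytorchStore::from_file("weights.pt");
let mut model = config.init::<B>(&device);
model.load_from(&mut store)?;
```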

Test Plan

  • bert-burn compiles
  • llama-burn compiles and TinyLlama generates correct output
  • mobilenetv2-burn inference example runs correctly (predicts "Labrador retriever")
  • resnet-burn inference example runs correctly (predicts "Labrador retriever")
  • squeezenet-burn compiles
  • yolox-burn inference example runs correctly (detects person, bicycle, dog)

NOTES

  1. We still need to update Cargo.toml to point to the published Burn crates
  2. We need to update to the burnpack version (@laggui will do it)

  • Simplifies tensor loading and model record construction by returning structs directly instead of assigning to intermediate variables. Updates dependency to use the latest 'burn' from the main branch. Cleans up redundant code and improves consistency in function returns.
  • Switches burn and burn-import dependencies to use the main branch from the GitHub repository. Refactors the Sampler enum to box TopP and adds a constructor for TopP sampling. Updates transformer forward method to accept mutable slices for cache. Cleans up unused code and improves code clarity in llama.rs and chat.rs.
  • Switched burn and burn-import dependencies to use the latest code from the tracel-ai/burn GitHub repository on the main branch. Updated the inference example to clarify type conversions for index and score extraction. Added example metadata for the inference example in Cargo.toml.
  • Switched burn and burn-import dependencies to use the latest main branch from GitHub. Updated multi_hot to use IndexingUpdateOp::Add and adjusted training builder to pass LearningStrategy as a parameter. Added clippy allow for large enum variant in ResidualBlock.
  • Switches Burn, burn-store, and burn-import dependencies to the latest main branch from GitHub. Updates build.rs to remove deprecated RecordType and half_precision usage, and changes the generated weights file extension to .bpk for compatibility with the latest Burn API.
  • Switched burn and burn-import dependencies to use the latest main branch from GitHub for both regular and dev dependencies. Added an example for inference requiring the pretrained feature. Introduced minor code improvements: allowed large enum variant in blocks.rs, imported Vec in boxes.rs, removed an unnecessary enumerate in boxes.rs, allowed dead_code for Weights struct, and adjusted imports in yolox.rs for feature-gated code.
  • Replaces usage of burn-import with burn-store for loading and saving model weights across llama-burn, mobilenetv2-burn, resnet-burn, and yolox-burn. Updates Cargo.toml dependencies and feature flags, refactors model loading logic to use burn-store's PytorchStore and BurnpackStore, and adapts key remapping APIs. This unifies model serialization and deserialization, and prepares for future burn ecosystem updates.
  • Adds support for loading model weights from SafeTensors files in addition to PyTorch formats, including key remapping and weight permutation for TinyLlama. Introduces a new 'test_tiny' example for testing TinyLlama with both 'tch-cpu' and 'ndarray' backends. Updates dependencies and features in Cargo.toml, and simplifies the decode method in SentiencePieceTokenizer.
  • Centralizes HuggingFace-to-Burn tensor key remapping patterns into a single location and applies them via KeyRemapper for both SafeTensors and PyTorch checkpoints. This reduces code duplication and simplifies the addition or modification of key mapping rules.
@antimora antimora requested a review from laggui December 14, 2025 17:26
@antimora antimora changed the title from "Migrate to burn-store and add SafeTensors support for HuggingFace model import" to "Migrate burn-models to burn main branch with burn-store" on Dec 14, 2025
@laggui laggui left a comment

Tried to execute the very first example in the repo and it failed to compile:

cargo run --example infer-embedding --release --features wgpu,fusion,safetensors
   Compiling burn-cubecl-fusion v0.20.0-pre.5 (https://github.com/tracel-ai/burn?branch=main#6c9ba5cf)
error[E0412]: cannot find type `Simple` in this scope
   --> /home/laggui/.cargo/git/checkouts/burn-6c277d792b0d5d7a/6c9ba5c/crates/burn-cubecl-fusion/src/optim/matmul/optimization.rs:172:36
    |
172 |         if arg.execute_fused::<BT, Simple>(context).is_err() {
    |                                    ^^^^^^ not found in this scope
    |
help: there is an enum variant `crate::optim::matmul::FusedMatmulSelector::Simple`; try using the variant's enum
    |
172 -         if arg.execute_fused::<BT, Simple>(context).is_err() {
172 +         if arg.execute_fused::<BT, crate::optim::matmul::FusedMatmulSelector>(context).is_err() {
    |
help: you might be missing a type parameter
    |
133 | impl<R: Runtime, Simple> MatmulOptimization<R> {

That's an issue with burn, but just reinforces the point that we should only have models point to a released version.

/edit: fixed this issue in tracel-ai/burn#4193

Comment on lines 596 to 645
.map(|snapshot| {
    let path = snapshot.full_path();

    // Check if this is a wq or wk weight that needs permutation
    if path.contains(".wq.weight") {
        // Permute wq weight
        let data = snapshot.to_data().expect("Failed to get tensor data");
        let shape = data.shape.clone();
        let dim1 = shape[0];
        let dim2 = shape[1];

        // Create tensor, permute, get data back
        let tensor: Tensor<B, 2> = Tensor::from_data(data, device);
        let permuted = tensor
            .reshape([dim1, n_heads, 2, dim2 / n_heads / 2])
            .swap_dims(2, 3)
            .reshape([dim1, dim2]);
        let permuted_data = permuted.to_data();

        TensorSnapshot::from_data(
            permuted_data,
            snapshot.path_stack.clone().unwrap_or_default(),
            snapshot.container_stack.clone().unwrap_or_default(),
            snapshot.tensor_id.unwrap_or_else(ParamId::new),
        )
    } else if path.contains(".wk.weight") {
        // Permute wk weight
        let data = snapshot.to_data().expect("Failed to get tensor data");
        let shape = data.shape.clone();
        let dim1 = shape[0];
        let wk_dim = d_model * n_kv_heads / n_heads;

        // Create tensor, permute, get data back
        let tensor: Tensor<B, 2> = Tensor::from_data(data, device);
        let permuted = tensor
            .reshape([dim1, n_kv_heads, 2, wk_dim / n_kv_heads / 2])
            .swap_dims(2, 3)
            .reshape([dim1, wk_dim]);
        let permuted_data = permuted.to_data();

        TensorSnapshot::from_data(
            permuted_data,
            snapshot.path_stack.clone().unwrap_or_default(),
            snapshot.container_stack.clone().unwrap_or_default(),
            snapshot.tensor_id.unwrap_or_else(ParamId::new),
        )
    } else {
        // Keep other tensors unchanged
        snapshot
    }
@laggui (Member) commented:

That feels a lot more verbose and possibly error-prone (especially since we have to check for the "path" names) compared to the previous usage 😅

That's a good signal to improve the UX though.

@antimora (Collaborator, Author) commented:

Agreed! I've refactored to extract helper functions (permute_rotary_weights, permute_attention_weight), which makes the call site cleaner:

permute_rotary_weights(&mut llama.model, n_heads, n_kv_heads, d_model, device);
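
Roughly, each helper wraps the reshape/swap_dims/reshape sequence from the inline version (a sketch of the shared piece only; the exact signatures are in the diff):

```rust
use burn::tensor::{backend::Backend, Tensor, TensorData};

// Sketch: permute one rotary attention weight.
// `out_dim` is the weight's second dimension: d_model for wq,
// d_model * n_kv_heads / n_heads for wk.
fn permute_attention_weight<B: Backend>(
    data: TensorData,
    heads: usize,
    out_dim: usize,
    device: &B::Device,
) -> TensorData {
    let rows = data.shape[0];
    let tensor: Tensor<B, 2> = Tensor::from_data(data, device);
    tensor
        .reshape([rows, heads, 2, out_dim / heads / 2])
        .swap_dims(2, 3)
        .reshape([rows, out_dim])
        .to_data()
}
```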

The verbosity stems from burn-store's ModuleAdapter not having access to:

  • Model config (n_heads, d_model)
  • Device (for tensor operations)

A potential future improvement could be a "post-load transformation hook" in burn-store that receives both the snapshot and model context. Would you like me to open an issue for this?

@laggui (Member) commented:

The call site is cleaner since it has been refactored into a function, but it still requires a lot more manipulation than before. So I think there's still room for improvement in the API!

  • Extracted the TinyLlama rotary weight permutation code into dedicated helper functions for clarity and maintainability. The new `permute_rotary_weights` function and its helpers handle the permutation of wq/wk tensors, replacing the previous inline logic in `LlamaConfig::load_record`.
  • Replaced git dependencies on burn and burn-store with version 0.20.0-pre.6 across all projects. Updated resnet-burn finetune example to use the new SupervisedTraining and Learner APIs, reflecting recent changes in the burn training interface.
@antimora antimora requested a review from laggui December 29, 2025 06:15
@laggui laggui left a comment

Reverted the llama load changes since the pre-trained weights are still uploaded online in the mpk format.

All models are working with the current PR.

I don't remember what your benchmarks initially showed on loading speed, but the PyTorchFileRecorder (while definitely not fast) appears to load faster than the burn-store equivalent. In my tests, it consistently loaded the Llama 3.2 1B pth model in 7s, versus 11s for the store.

@laggui laggui merged commit a30fbda into tracel-ai:main Jan 16, 2026
2 checks passed