From 836cbe29e1a8c9945bcf9689bbdd642a566f2701 Mon Sep 17 00:00:00 2001 From: Samuel Cedric Date: Sun, 16 Nov 2025 01:26:26 +0700 Subject: [PATCH] fix: add automatic PyTorch to SafeTensors conversion to prevent HeaderTooLarge (#1) --- src/embedder/mod.rs | 41 ++++++++++++++++++++++++++++++++++-- src/model/downloader.rs | 46 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 2 deletions(-) diff --git a/src/embedder/mod.rs b/src/embedder/mod.rs index 6cbc39e..0ee585d 100644 --- a/src/embedder/mod.rs +++ b/src/embedder/mod.rs @@ -1,6 +1,6 @@ use crate::error::{Error, Result}; use crate::model::ModelInfo; -use candle_core::{Device, Tensor}; +use candle_core::{pickle, Device, Tensor}; use serde_json::Value; use std::path::PathBuf; use std::sync::Arc; @@ -16,6 +16,8 @@ impl Embedder { pub fn load(model_info: &ModelInfo, device: Device) -> Result { tracing::info!("Loading model from: {:?}", model_info.model_path); + Self::ensure_safetensors_converted(&model_info.model_path)?; + let config_path = model_info.model_path.join("config.json"); let config_content = std::fs::read_to_string(&config_path) .map_err(|e| Error::ModelLoadFailed(format!("Failed to read config: {}", e)))?; @@ -95,7 +97,6 @@ impl Embedder { .map_err(|e| Error::ModelLoadFailed(format!("Failed to load safetensors: {}", e)))? }; - // Get list of tensors and find embedding weight let tensor_list = safetensors.tensors(); let embedding_weight_name = tensor_list .iter() @@ -130,4 +131,40 @@ impl Embedder { pub fn embedding_dim(&self) -> usize { self.embedding_dim } + + fn ensure_safetensors_converted(model_dir: &PathBuf) -> Result<()> { + let pytorch_file = model_dir.join("pytorch_model.bin"); + let safetensors_file = model_dir.join("model.safetensors"); + + if safetensors_file.exists() { + return Ok(()); + } + + if !pytorch_file.exists() { + return Ok(()); + } + + tracing::info!("Converting pytorch_model.bin to model.safetensors..."); + + let tensors_vec = pickle::read_all(&pytorch_file) + .map_err(|e| Error::ModelLoadFailed(format!("Failed to read PyTorch file: {}", e)))?; + + tracing::info!("Loading {} tensors from PyTorch model", tensors_vec.len()); + + let tensors: std::collections::HashMap<_, _> = tensors_vec.into_iter().collect(); + + candle_core::safetensors::save(&tensors, &safetensors_file) + .map_err(|e| Error::ModelLoadFailed(format!("Failed to save SafeTensors: {}", e)))?; + + tracing::info!("✓ Converted to SafeTensors format"); + + // Remove the old PyTorch file to save space + if let Err(e) = std::fs::remove_file(&pytorch_file) { + tracing::warn!("Could not remove pytorch_model.bin: {}", e); + } else { + tracing::info!("Removed pytorch_model.bin to save space"); + } + + Ok(()) + } } diff --git a/src/model/downloader.rs b/src/model/downloader.rs index e450493..7ce9ed8 100644 --- a/src/model/downloader.rs +++ b/src/model/downloader.rs @@ -1,7 +1,9 @@ use crate::config::Config; use crate::error::{Error, Result}; use crate::model::{ModelInfo, ModelRegistry}; +use candle_core::pickle; use hf_hub::api::sync::Api; +use std::path::Path; pub struct ModelDownloader { config: Config, @@ -40,6 +42,9 @@ impl ModelDownloader { .parent() .ok_or_else(|| Error::DownloadFailed("Invalid model path".to_string()))?; + // Auto-convert PyTorch to SafeTensors if needed + Self::ensure_safetensors(model_dir)?; + let name = alias.clone().unwrap_or_else(|| { hf_repo_id .split('/') @@ -64,4 +69,45 @@ impl ModelDownloader { Ok(model_info) } + + fn ensure_safetensors(model_dir: &Path) -> Result<()> { + let pytorch_file = model_dir.join("pytorch_model.bin"); + let safetensors_file = model_dir.join("model.safetensors"); + + // If safetensors exists, we're good + if safetensors_file.exists() { + return Ok(()); + } + + // If pytorch file doesn't exist, nothing to convert + if !pytorch_file.exists() { + return Ok(()); + } + + tracing::info!("Converting pytorch_model.bin to model.safetensors..."); + + // Read PyTorch file and load all tensors + let tensors_vec = pickle::read_all(&pytorch_file) + .map_err(|e| Error::ModelLoadFailed(format!("Failed to read PyTorch file: {}", e)))?; + + tracing::info!("Loading {} tensors from PyTorch model", tensors_vec.len()); + + // Convert to HashMap + let tensors: std::collections::HashMap<_, _> = tensors_vec.into_iter().collect(); + + // Save as safetensors + candle_core::safetensors::save(&tensors, &safetensors_file) + .map_err(|e| Error::ModelLoadFailed(format!("Failed to save SafeTensors: {}", e)))?; + + tracing::info!("✓ Converted to SafeTensors format"); + + // Remove the old PyTorch file to save space + if let Err(e) = std::fs::remove_file(&pytorch_file) { + tracing::warn!("Could not remove pytorch_model.bin: {}", e); + } else { + tracing::info!("Removed pytorch_model.bin to save space"); + } + + Ok(()) + } }