Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions src/embedder/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::error::{Error, Result};
use crate::model::ModelInfo;
use candle_core::{Device, Tensor};
use candle_core::{pickle, Device, Tensor};
use serde_json::Value;
use std::path::PathBuf;
use std::sync::Arc;
Expand All @@ -16,6 +16,8 @@
pub fn load(model_info: &ModelInfo, device: Device) -> Result<Self> {
tracing::info!("Loading model from: {:?}", model_info.model_path);

Self::ensure_safetensors_converted(&model_info.model_path)?;

let config_path = model_info.model_path.join("config.json");
let config_content = std::fs::read_to_string(&config_path)
.map_err(|e| Error::ModelLoadFailed(format!("Failed to read config: {}", e)))?;
Expand Down Expand Up @@ -95,7 +97,6 @@
.map_err(|e| Error::ModelLoadFailed(format!("Failed to load safetensors: {}", e)))?
};

// Get list of tensors and find embedding weight
let tensor_list = safetensors.tensors();
let embedding_weight_name = tensor_list
.iter()
Expand Down Expand Up @@ -130,4 +131,40 @@
pub fn embedding_dim(&self) -> usize {
self.embedding_dim
}

fn ensure_safetensors_converted(model_dir: &PathBuf) -> Result<()> {

Check failure on line 135 in src/embedder/mod.rs

View workflow job for this annotation

GitHub Actions / test

writing `&PathBuf` instead of `&Path` involves a new object where a slice will do
let pytorch_file = model_dir.join("pytorch_model.bin");
let safetensors_file = model_dir.join("model.safetensors");

if safetensors_file.exists() {
return Ok(());
}

if !pytorch_file.exists() {
return Ok(());
}

tracing::info!("Converting pytorch_model.bin to model.safetensors...");

let tensors_vec = pickle::read_all(&pytorch_file)
.map_err(|e| Error::ModelLoadFailed(format!("Failed to read PyTorch file: {}", e)))?;

tracing::info!("Loading {} tensors from PyTorch model", tensors_vec.len());

let tensors: std::collections::HashMap<_, _> = tensors_vec.into_iter().collect();

candle_core::safetensors::save(&tensors, &safetensors_file)
.map_err(|e| Error::ModelLoadFailed(format!("Failed to save SafeTensors: {}", e)))?;

tracing::info!("✓ Converted to SafeTensors format");

// Remove the old PyTorch file to save space
if let Err(e) = std::fs::remove_file(&pytorch_file) {
tracing::warn!("Could not remove pytorch_model.bin: {}", e);
} else {
tracing::info!("Removed pytorch_model.bin to save space");
}

Ok(())
}
}
46 changes: 46 additions & 0 deletions src/model/downloader.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use crate::config::Config;
use crate::error::{Error, Result};
use crate::model::{ModelInfo, ModelRegistry};
use candle_core::pickle;
use hf_hub::api::sync::Api;
use std::path::Path;

pub struct ModelDownloader {
config: Config,
Expand Down Expand Up @@ -40,6 +42,9 @@ impl ModelDownloader {
.parent()
.ok_or_else(|| Error::DownloadFailed("Invalid model path".to_string()))?;

// Auto-convert PyTorch to SafeTensors if needed
Self::ensure_safetensors(model_dir)?;

let name = alias.clone().unwrap_or_else(|| {
hf_repo_id
.split('/')
Expand All @@ -64,4 +69,45 @@ impl ModelDownloader {

Ok(model_info)
}

/// Guarantees that `model_dir` contains a `model.safetensors`, converting a
/// `pytorch_model.bin` into one when necessary and removing the original
/// afterwards to reclaim disk space.
///
/// # Errors
///
/// Returns `Error::ModelLoadFailed` when the PyTorch checkpoint cannot be
/// read or the SafeTensors file cannot be written.
fn ensure_safetensors(model_dir: &Path) -> Result<()> {
    let source = model_dir.join("pytorch_model.bin");
    let target = model_dir.join("model.safetensors");

    // Nothing to do when the target already exists, or when there is no
    // PyTorch checkpoint to convert from.
    if target.exists() || !source.exists() {
        return Ok(());
    }

    tracing::info!("Converting pytorch_model.bin to model.safetensors...");

    // Deserialize every tensor stored in the pickled checkpoint.
    let loaded = pickle::read_all(&source)
        .map_err(|e| Error::ModelLoadFailed(format!("Failed to read PyTorch file: {}", e)))?;

    tracing::info!("Loading {} tensors from PyTorch model", loaded.len());

    // Re-serialize the (name, tensor) pairs in SafeTensors format.
    let by_name: std::collections::HashMap<_, _> = loaded.into_iter().collect();
    candle_core::safetensors::save(&by_name, &target)
        .map_err(|e| Error::ModelLoadFailed(format!("Failed to save SafeTensors: {}", e)))?;

    tracing::info!("✓ Converted to SafeTensors format");

    // Best-effort removal of the now-redundant PyTorch file; failure is
    // only worth a warning since the conversion itself succeeded.
    match std::fs::remove_file(&source) {
        Ok(()) => tracing::info!("Removed pytorch_model.bin to save space"),
        Err(e) => tracing::warn!("Could not remove pytorch_model.bin: {}", e),
    }

    Ok(())
}
}
Loading