From 2918b6b9973b0927121fa2dfeb11ae22bd502710 Mon Sep 17 00:00:00 2001
From: Sergei Zharinov
Date: Sat, 10 Jan 2026 13:51:23 -0300
Subject: [PATCH 1/3] feat: Add `from_bytes()` and `from_raw_parts()` to load model from raw data

---
 src/model.rs | 115 ++++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 86 insertions(+), 29 deletions(-)

diff --git a/src/model.rs b/src/model.rs
index d695bb9..13b211b 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -33,12 +33,10 @@ impl StaticModel {
         normalize: Option<bool>,
         subfolder: Option<&str>,
     ) -> Result<Self> {
-        // If provided, set HF token for authenticated downloads
         if let Some(tok) = token {
             env::set_var("HF_HUB_TOKEN", tok);
         }
 
-        // Locate tokenizer.json, model.safetensors, config.json
         let (tok_path, mdl_path, cfg_path) = {
             let base = repo_or_path.as_ref();
             if base.exists() {
@@ -61,38 +59,38 @@ impl StaticModel {
             }
         };
 
-        // Load the tokenizer
-        let tokenizer = Tokenizer::from_file(&tok_path).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;
+        let tokenizer_bytes = fs::read(&tok_path).context("failed to read tokenizer.json")?;
+        let safetensors_bytes = fs::read(&mdl_path).context("failed to read model.safetensors")?;
+        let config_bytes = fs::read(&cfg_path).context("failed to read config.json")?;
 
-        // Median-token-length hack for pre-truncation
-        let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
-        lens.sort_unstable();
-        let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);
+        Self::from_bytes(&tokenizer_bytes, &safetensors_bytes, &config_bytes, normalize)
+    }
+
+    /// Load a Model2Vec model from raw bytes.
+    ///
+    /// # Arguments
+    /// * `tokenizer_bytes` - Contents of tokenizer.json
+    /// * `safetensors_bytes` - Contents of model.safetensors
+    /// * `config_bytes` - Contents of config.json
+    /// * `normalize` - Optional flag to override normalization (default from config)
+    pub fn from_bytes(
+        tokenizer_bytes: &[u8],
+        safetensors_bytes: &[u8],
+        config_bytes: &[u8],
+        normalize: Option<bool>,
+    ) -> Result<Self> {
+        let tokenizer =
+            Tokenizer::from_bytes(tokenizer_bytes).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;
+
+        let median_token_length = Self::median_token_length(&tokenizer);
+        let unk_token_id = Self::unk_token_id(&tokenizer)?;
 
-        // Read normalize default from config.json
-        let cfg_file = std::fs::File::open(&cfg_path).context("failed to read config.json")?;
-        let cfg: Value = serde_json::from_reader(&cfg_file).context("failed to parse config.json")?;
+        let cfg: Value = serde_json::from_slice(config_bytes).context("failed to parse config.json")?;
         let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
         let normalize = normalize.unwrap_or(cfg_norm);
 
-        // Serialize the tokenizer to JSON, then parse it and get the unk_token
-        let spec_json = tokenizer
-            .to_string(false)
-            .map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
-        let spec: Value = serde_json::from_str(&spec_json)?;
-        let unk_token = spec
-            .get("model")
-            .and_then(|m| m.get("unk_token"))
-            .and_then(Value::as_str)
-            .unwrap_or("[UNK]");
-        let unk_token_id = tokenizer
-            .token_to_id(unk_token)
-            .ok_or_else(|| anyhow!("tokenizer claims unk_token='{unk_token}' but it isn't in the vocab"))?
-            as usize;
-
         // Load the safetensors
-        let model_bytes = fs::read(&mdl_path).context("failed to read model.safetensors")?;
-        let safet = SafeTensors::deserialize(&model_bytes).context("failed to parse safetensors")?;
+        let safet = SafeTensors::deserialize(safetensors_bytes).context("failed to parse safetensors")?;
         let tensor = safet
             .tensor("embeddings")
             .or_else(|_| safet.tensor("0"))
@@ -161,10 +159,69 @@ impl StaticModel {
             token_mapping,
             normalize,
             median_token_length,
-            unk_token_id: Some(unk_token_id),
+            unk_token_id,
         })
     }
 
+    /// Construct from pre-parsed parts.
+    ///
+    /// # Arguments
+    /// * `tokenizer` - Pre-deserialized tokenizer
+    /// * `embeddings` - Raw f32 embedding data
+    /// * `rows` - Number of vocabulary entries
+    /// * `cols` - Embedding dimension
+    /// * `normalize` - Whether to L2-normalize output embeddings
+    pub fn from_raw_parts(
+        tokenizer: Tokenizer,
+        embeddings: &[f32],
+        rows: usize,
+        cols: usize,
+        normalize: bool,
+    ) -> Result<Self> {
+        if embeddings.len() != rows * cols {
+            return Err(anyhow!(
+                "embeddings length {} != rows {} * cols {}",
+                embeddings.len(),
+                rows,
+                cols
+            ));
+        }
+
+        let median_token_length = Self::median_token_length(&tokenizer);
+        let unk_token_id = Self::unk_token_id(&tokenizer)?;
+
+        let embeddings = Array2::from_shape_vec((rows, cols), embeddings.to_vec())
+            .context("failed to build embeddings array")?;
+
+        Ok(Self {
+            tokenizer,
+            embeddings,
+            weights: None,
+            token_mapping: None,
+            normalize,
+            median_token_length,
+            unk_token_id,
+        })
+    }
+
+    fn median_token_length(tokenizer: &Tokenizer) -> usize {
+        let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
+        lens.sort_unstable();
+        lens.get(lens.len() / 2).copied().unwrap_or(1)
+    }
+
+    fn unk_token_id(tokenizer: &Tokenizer) -> Result<Option<usize>> {
+        let spec_json = tokenizer
+            .to_string(false)
+            .map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
+        let spec: Value = serde_json::from_str(&spec_json)?;
+        let unk_token = spec
+            .get("model")
+            .and_then(|m| m.get("unk_token"))
+            .and_then(Value::as_str);
+        Ok(unk_token.and_then(|tok| tokenizer.token_to_id(tok)).map(|id| id as usize))
+    }
+
     /// Char-level truncation to max_tokens * median_token_length
     fn truncate_str(s: &str, max_tokens: usize, median_len: usize) -> &str {
         let max_chars = max_tokens.saturating_mul(median_len);

From a08059eb849496050b0103655f0e92c5ce19815a Mon Sep 17 00:00:00 2001
From: Sergei Zharinov
Date: Sat, 10 Jan 2026 14:36:28 -0300
Subject: [PATCH 2/3] chore: Add test for `from_raw_parts`

---
 src/model.rs        |  1 +
 tests/test_model.rs | 23 +++++++++++++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/src/model.rs b/src/model.rs
index 13b211b..ced76a2 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -171,6 +171,7 @@ impl StaticModel {
     /// * `rows` - Number of vocabulary entries
     /// * `cols` - Embedding dimension
     /// * `normalize` - Whether to L2-normalize output embeddings
+    #[allow(dead_code)]
     pub fn from_raw_parts(
         tokenizer: Tokenizer,
         embeddings: &[f32],
diff --git a/tests/test_model.rs b/tests/test_model.rs
index f09b8c2..3fbb075 100644
--- a/tests/test_model.rs
+++ b/tests/test_model.rs
@@ -70,3 +70,26 @@ fn test_normalization_flag_override() {
         "Without normalization override, norm should be larger"
     );
 }
+
+/// Test from_raw_parts constructor
+#[test]
+fn test_from_raw_parts() {
+    use std::fs;
+    use tokenizers::Tokenizer;
+    use safetensors::SafeTensors;
+
+    let path = "tests/fixtures/test-model-float32";
+    let tokenizer =
+        Tokenizer::from_file(format!("{path}/tokenizer.json")).unwrap();
+    let bytes = fs::read(format!("{path}/model.safetensors")).unwrap();
+    let tensors = SafeTensors::deserialize(&bytes).unwrap();
+    let tensor = tensors.tensor("embeddings").unwrap();
+    let [rows, cols]: [usize; 2] = tensor.shape().try_into().unwrap();
+    let floats: Vec<f32> = tensor.data()
+        .chunks_exact(4)
+        .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
+        .collect();
+
+    let model = StaticModel::from_raw_parts(tokenizer, &floats, rows, cols, true).unwrap();
+    let emb = model.encode_single("hello");
+    assert!(!emb.is_empty());
+}

From e992ae04d7769d8deb05f454b86a277f09193030 Mon Sep 17 00:00:00 2001
From: Sergei Zharinov
Date: Sat, 10 Jan 2026 14:49:44 -0300
Subject: [PATCH 3/3] feat: Add `from_raw_parts()` constructor

- `from_pretrained` now delegates to `from_raw_parts`
- Fixes BPE tokenizer support (unk_token_id now optional)
---
 src/model.rs        | 91 +++++++++++++++------------------------------
 tests/test_model.rs |  2 +-
 2 files changed, 32 insertions(+), 61 deletions(-)

diff --git a/src/model.rs b/src/model.rs
index ced76a2..dff1a8a 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -33,10 +33,12 @@ impl StaticModel {
         normalize: Option<bool>,
         subfolder: Option<&str>,
     ) -> Result<Self> {
+        // If provided, set HF token for authenticated downloads
         if let Some(tok) = token {
             env::set_var("HF_HUB_TOKEN", tok);
         }
 
+        // Locate tokenizer.json, model.safetensors, config.json
         let (tok_path, mdl_path, cfg_path) = {
             let base = repo_or_path.as_ref();
             if base.exists() {
@@ -59,38 +61,18 @@ impl StaticModel {
             }
         };
 
-        let tokenizer_bytes = fs::read(&tok_path).context("failed to read tokenizer.json")?;
-        let safetensors_bytes = fs::read(&mdl_path).context("failed to read model.safetensors")?;
-        let config_bytes = fs::read(&cfg_path).context("failed to read config.json")?;
+        // Load the tokenizer
+        let tokenizer = Tokenizer::from_file(&tok_path).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;
 
-        Self::from_bytes(&tokenizer_bytes, &safetensors_bytes, &config_bytes, normalize)
-    }
-
-    /// Load a Model2Vec model from raw bytes.
-    ///
-    /// # Arguments
-    /// * `tokenizer_bytes` - Contents of tokenizer.json
-    /// * `safetensors_bytes` - Contents of model.safetensors
-    /// * `config_bytes` - Contents of config.json
-    /// * `normalize` - Optional flag to override normalization (default from config)
-    pub fn from_bytes(
-        tokenizer_bytes: &[u8],
-        safetensors_bytes: &[u8],
-        config_bytes: &[u8],
-        normalize: Option<bool>,
-    ) -> Result<Self> {
-        let tokenizer =
-            Tokenizer::from_bytes(tokenizer_bytes).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;
-
-        let median_token_length = Self::median_token_length(&tokenizer);
-        let unk_token_id = Self::unk_token_id(&tokenizer)?;
-
-        let cfg: Value = serde_json::from_slice(config_bytes).context("failed to parse config.json")?;
+        // Read normalize default from config.json
+        let cfg_file = std::fs::File::open(&cfg_path).context("failed to read config.json")?;
+        let cfg: Value = serde_json::from_reader(&cfg_file).context("failed to parse config.json")?;
         let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
         let normalize = normalize.unwrap_or(cfg_norm);
 
         // Load the safetensors
-        let safet = SafeTensors::deserialize(safetensors_bytes).context("failed to parse safetensors")?;
+        let model_bytes = fs::read(&mdl_path).context("failed to read model.safetensors")?;
+        let safet = SafeTensors::deserialize(&model_bytes).context("failed to parse safetensors")?;
         let tensor = safet
             .tensor("embeddings")
             .or_else(|_| safet.tensor("0"))
@@ -113,7 +95,6 @@ impl StaticModel {
             Dtype::I8 => raw.iter().map(|&b| f32::from(b as i8)).collect(),
             other => return Err(anyhow!("unsupported tensor dtype: {other:?}")),
         };
-        let embeddings = Array2::from_shape_vec((rows, cols), floats).context("failed to build embeddings array")?;
 
         // Load optional weights for vocabulary quantization
         let weights = match safet.tensor("weights") {
@@ -152,15 +133,7 @@ impl StaticModel {
             Err(_) => None,
         };
 
-        Ok(Self {
-            tokenizer,
-            embeddings,
-            weights,
-            token_mapping,
-            normalize,
-            median_token_length,
-            unk_token_id,
-        })
+        Self::from_raw_parts(tokenizer, &floats, rows, cols, normalize, weights, token_mapping)
     }
 
     /// Construct from pre-parsed parts.
@@ -171,13 +144,16 @@ impl StaticModel {
     /// * `rows` - Number of vocabulary entries
     /// * `cols` - Embedding dimension
     /// * `normalize` - Whether to L2-normalize output embeddings
-    #[allow(dead_code)]
+    /// * `weights` - Optional per-token weights for quantized models
+    /// * `token_mapping` - Optional token ID mapping for quantized models
     pub fn from_raw_parts(
         tokenizer: Tokenizer,
         embeddings: &[f32],
         rows: usize,
         cols: usize,
         normalize: bool,
+        weights: Option<Vec<f32>>,
+        token_mapping: Option<Vec<usize>>,
     ) -> Result<Self> {
         if embeddings.len() != rows * cols {
             return Err(anyhow!(
@@ -188,30 +164,12 @@ impl StaticModel {
             ));
         }
 
-        let median_token_length = Self::median_token_length(&tokenizer);
-        let unk_token_id = Self::unk_token_id(&tokenizer)?;
-
-        let embeddings = Array2::from_shape_vec((rows, cols), embeddings.to_vec())
-            .context("failed to build embeddings array")?;
-
-        Ok(Self {
-            tokenizer,
-            embeddings,
-            weights: None,
-            token_mapping: None,
-            normalize,
-            median_token_length,
-            unk_token_id,
-        })
-    }
-
-    fn median_token_length(tokenizer: &Tokenizer) -> usize {
+        // Median-token-length hack for pre-truncation
         let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
         lens.sort_unstable();
-        lens.get(lens.len() / 2).copied().unwrap_or(1)
-    }
+        let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);
 
-    fn unk_token_id(tokenizer: &Tokenizer) -> Result<Option<usize>> {
+        // Get unk_token from tokenizer (optional - BPE tokenizers may not have one)
         let spec_json = tokenizer
             .to_string(false)
             .map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
         let spec: Value = serde_json::from_str(&spec_json)?;
         let unk_token = spec
             .get("model")
             .and_then(|m| m.get("unk_token"))
             .and_then(Value::as_str);
-        Ok(unk_token.and_then(|tok| tokenizer.token_to_id(tok)).map(|id| id as usize))
+        let unk_token_id = unk_token.and_then(|tok| tokenizer.token_to_id(tok)).map(|id| id as usize);
+
+        let embeddings = Array2::from_shape_vec((rows, cols), embeddings.to_vec())
+            .context("failed to build embeddings array")?;
+
+        Ok(Self {
+            tokenizer,
+            embeddings,
+            weights,
+            token_mapping,
+            normalize,
+            median_token_length,
+            unk_token_id,
+        })
     }
 
     /// Char-level truncation to max_tokens * median_token_length

diff --git a/tests/test_model.rs b/tests/test_model.rs
index 3fbb075..abb3761 100644
--- a/tests/test_model.rs
+++ b/tests/test_model.rs
@@ -89,7 +89,7 @@ fn test_from_raw_parts() {
         .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
         .collect();
 
-    let model = StaticModel::from_raw_parts(tokenizer, &floats, rows, cols, true).unwrap();
+    let model = StaticModel::from_raw_parts(tokenizer, &floats, rows, cols, true, None, None).unwrap();
     let emb = model.encode_single("hello");
     assert!(!emb.is_empty());
 }
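
Usage sketch (not part of the patches): a minimal program built against the API as it stands after PATCH 3/3, mirroring `test_from_raw_parts` but feeding the constructor bytes bundled into the binary, so no filesystem layout or Hugging Face download is needed. The `model2vec_rs::model::StaticModel` import path, the `include_bytes!` fixture locations, and the exact inner types of the optional `weights`/`token_mapping` parameters are assumptions taken from the test code, not something the patches spell out.

use anyhow::{anyhow, Result};
use model2vec_rs::model::StaticModel; // assumed import path, matching how the tests use the crate
use safetensors::SafeTensors;
use tokenizers::Tokenizer;

fn main() -> Result<()> {
    // Raw bytes can come from anywhere; include_bytes! avoids any filesystem or HF Hub access.
    // The fixture paths below are illustrative only.
    let tok_bytes: &[u8] = include_bytes!("../tests/fixtures/test-model-float32/tokenizer.json");
    let st_bytes: &[u8] = include_bytes!("../tests/fixtures/test-model-float32/model.safetensors");

    let tokenizer =
        Tokenizer::from_bytes(tok_bytes).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;

    // Decode the f32 embedding matrix the same way test_from_raw_parts does.
    let tensors = SafeTensors::deserialize(st_bytes)?;
    let tensor = tensors.tensor("embeddings")?;
    let [rows, cols]: [usize; 2] = tensor.shape().try_into().unwrap();
    let floats: Vec<f32> = tensor
        .data()
        .chunks_exact(4)
        .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
        .collect();

    // Non-quantized model: pass None for the per-token weights and the token mapping.
    let model = StaticModel::from_raw_parts(tokenizer, &floats, rows, cols, true, None, None)?;
    let embedding = model.encode_single("hello world");
    println!("embedding dims: {}", embedding.len());
    Ok(())
}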