1+ use anyhow:: Result ;
2+ use candle_core:: { Device , Tensor } ;
3+ use mistralrs_core:: { BertEmbeddingModel , MistralRs , MistralRsBuilder , Pipeline , Request , RequestMessage , Response , NormalRequest } ;
4+ use serde:: { Deserialize , Serialize } ;
5+ use serde_json:: Value ;
6+ use std:: sync:: Arc ;
7+ use tokenizers:: Tokenizer ;
8+ use tokio:: sync:: { mpsc, RwLock } ;
9+
/// Engine that produces sentence embeddings via an optional BERT pipeline.
///
/// When model loading fails, `bert_pipeline` is `None` and
/// `generate_embeddings` serves placeholder vectors instead of erroring.
// NOTE(review): `derive(Clone)` requires `BertPipeline: Clone` — confirm
// that the upstream type actually implements it.
#[derive(Clone)]
pub struct EmbeddingEngine {
    // `None` when the BERT model failed to load at construction time.
    bert_pipeline: Option<mistralrs_core::embedding::bert::BertPipeline>,
    // Compute device chosen at construction (CUDA device 0 if available, else CPU).
    device: Device,
}
15+
16+ impl EmbeddingEngine {
17+ pub fn new ( ) -> Result < Self > {
18+ let device = Device :: cuda_if_available ( 0 ) . unwrap_or ( Device :: Cpu ) ;
19+
20+ // Try to load BERT embedding model
21+ let bert_pipeline = match mistralrs_core:: embedding:: bert:: BertPipeline :: new (
22+ BertEmbeddingModel :: SnowflakeArcticEmbedL ,
23+ & device,
24+ ) {
25+ Ok ( pipeline) => {
26+ tracing:: info!( "Loaded BERT embedding model successfully" ) ;
27+ Some ( pipeline)
28+ }
29+ Err ( e) => {
30+ tracing:: warn!( "Failed to load BERT embedding model: {}" , e) ;
31+ None
32+ }
33+ } ;
34+
35+ Ok ( Self {
36+ bert_pipeline,
37+ device,
38+ } )
39+ }
40+
41+ pub async fn generate_embeddings ( & self , texts : Vec < String > ) -> Result < Vec < Vec < f32 > > > {
42+ if let Some ( pipeline) = & self . bert_pipeline {
43+ let mut all_embeddings = Vec :: new ( ) ;
44+
45+ for text in texts {
46+ // Tokenize the text
47+ let encoding = pipeline. tokenizer . encode ( text. clone ( ) , true )
48+ . map_err ( |e| anyhow:: anyhow!( "Tokenization failed: {}" , e) ) ?;
49+
50+ let tokens = encoding. get_ids ( ) ;
51+ let token_ids = Tensor :: new ( tokens, & self . device ) ?
52+ . unsqueeze ( 0 ) ?; // Add batch dimension
53+
54+ // Create token type ids (all zeros for single sequence)
55+ let token_type_ids = Tensor :: zeros_like ( & token_ids) ?;
56+
57+ // Get attention mask (1 for real tokens, 0 for padding)
58+ let attention_mask = Tensor :: ones_like ( & token_ids) ?;
59+
60+ // Forward pass through the model
61+ let output = pipeline. model . forward ( & token_ids, & token_type_ids, Some ( & attention_mask) ) ?;
62+
63+ // Mean pooling over sequence dimension to get sentence embedding
64+ // output shape: [batch_size, seq_len, hidden_size]
65+ let embeddings = self . mean_pooling ( & output, & attention_mask) ?;
66+
67+ // Normalize embeddings
68+ let embeddings = self . normalize_embeddings ( & embeddings) ?;
69+
70+ // Convert to Vec<f32>
71+ let embeddings_vec = embeddings. squeeze ( 0 ) ?. to_vec1 :: < f32 > ( ) ?;
72+ all_embeddings. push ( embeddings_vec) ;
73+ }
74+
75+ Ok ( all_embeddings)
76+ } else {
77+ // Fallback: return placeholder embeddings if model not loaded
78+ // In production, this should return an error instead
79+ tracing:: warn!( "BERT model not loaded, returning placeholder embeddings" ) ;
80+ Ok ( texts. iter ( ) . map ( |_| vec ! [ 0.1 ; 1024 ] ) . collect ( ) )
81+ }
82+ }
83+
84+ fn mean_pooling ( & self , token_embeddings : & Tensor , attention_mask : & Tensor ) -> Result < Tensor > {
85+ // Expand attention_mask to match embeddings dimensions
86+ let attention_mask_expanded = attention_mask. unsqueeze ( 2 ) ?
87+ . expand ( token_embeddings. shape ( ) ) ?
88+ . to_dtype ( token_embeddings. dtype ( ) ) ?;
89+
90+ // Apply attention mask
91+ let sum_embeddings = ( token_embeddings * & attention_mask_expanded) ?
92+ . sum ( 1 ) ?
93+ . unsqueeze ( 1 ) ?;
94+
95+ // Calculate sum of attention mask (avoid division by zero)
96+ let sum_mask = attention_mask_expanded. sum ( 1 ) ?
97+ . clamp ( 1e-9 , f64:: INFINITY ) ?;
98+
99+ // Mean pooling
100+ sum_embeddings. broadcast_div ( & sum_mask)
101+ }
102+
103+ fn normalize_embeddings ( & self , embeddings : & Tensor ) -> Result < Tensor > {
104+ // L2 normalization
105+ let norm = embeddings. sqr ( ) ?
106+ . sum_keepdim ( embeddings. rank ( ) - 1 ) ?
107+ . sqrt ( ) ?
108+ . clamp ( 1e-12 , f64:: INFINITY ) ?;
109+
110+ embeddings. broadcast_div ( & norm)
111+ }
112+ }
113+
/// OpenAI-compatible embeddings request body.
#[derive(Debug, Deserialize)]
pub struct EmbeddingRequest {
    /// Text(s) or token id(s) to embed; see [`EmbeddingInput`].
    pub input: EmbeddingInput,
    /// Requested model name; optional — no default is applied here.
    pub model: Option<String>,
    /// Requested wire encoding (e.g. "float"); accepted but not acted on in
    /// this file — NOTE(review): confirm whether any handler honors it.
    pub encoding_format: Option<String>,
    /// Requested output dimensionality; accepted but not acted on here.
    pub dimensions: Option<usize>,
    /// Opaque end-user identifier passed through by clients.
    pub user: Option<String>,
}
122+
/// The `input` field of an embeddings request: a single string, an array of
/// strings, a single token-id array, or an array of token-id arrays.
///
/// `#[serde(untagged)]` tries the variants in declaration order, so a JSON
/// string array deserializes as `StringArray` and a number array as
/// `TokenArray` — keep the variant order stable.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
    String(String),
    StringArray(Vec<String>),
    TokenArray(Vec<u32>),
    TokenArrayArray(Vec<Vec<u32>>),
}
131+
132+ impl EmbeddingInput {
133+ pub fn to_string_array ( self ) -> Vec < String > {
134+ match self {
135+ EmbeddingInput :: String ( s) => vec ! [ s] ,
136+ EmbeddingInput :: StringArray ( arr) => arr,
137+ EmbeddingInput :: TokenArray ( tokens) => {
138+ vec ! [ format!( "Token array: {:?}" , tokens) ]
139+ }
140+ EmbeddingInput :: TokenArrayArray ( arrays) => {
141+ arrays. iter ( )
142+ . map ( |tokens| format ! ( "Token array: {:?}" , tokens) )
143+ . collect ( )
144+ }
145+ }
146+ }
147+ }
148+
/// OpenAI-compatible embeddings response body.
#[derive(Debug, Serialize)]
pub struct EmbeddingResponse {
    /// Object type discriminator (conventionally "list" in this API shape).
    pub object: String,
    /// One entry per input, in input order.
    pub data: Vec<EmbeddingData>,
    /// Name of the model that produced the embeddings.
    pub model: String,
    /// Token accounting for the request.
    pub usage: EmbeddingUsage,
}
156+
/// A single embedding result within an [`EmbeddingResponse`].
#[derive(Debug, Serialize)]
pub struct EmbeddingData {
    /// Object type discriminator (conventionally "embedding").
    pub object: String,
    /// Index of the corresponding input in the request.
    pub index: usize,
    /// The embedding vector itself.
    pub embedding: Vec<f32>,
}
163+
/// Token usage accounting for an embeddings request.
#[derive(Debug, Serialize)]
pub struct EmbeddingUsage {
    /// Tokens consumed by the prompt(s).
    pub prompt_tokens: usize,
    /// Total tokens (for embeddings, typically equal to `prompt_tokens`).
    pub total_tokens: usize,
}