Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

64 changes: 62 additions & 2 deletions src/backends/gguf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ impl GgufBackend {
let top_k = params.top_k;
let top_p = params.top_p;
let seed = params.seed;
let stop_sequences = params.stop_sequences.clone();

// Perform inference in spawn_blocking since LlamaContext is !Send
let response = tokio::task::spawn_blocking(move || {
Expand Down Expand Up @@ -228,6 +229,7 @@ impl GgufBackend {

// Generate tokens one by one
let mut output_tokens = Vec::new();
let mut generated_text = String::new();
let max_new_tokens = max_tokens as usize;

debug!(
Expand All @@ -239,10 +241,15 @@ impl GgufBackend {
// Get logits for sampling - collect iterator to vec
let candidates_llama: Vec<_> = context.candidates().collect();

// Compute softmax probabilities from raw logits
let logits: Vec<f32> = candidates_llama.iter().map(|c| c.logit()).collect();
let probs = GgufBackend::softmax(&logits);

// Convert LlamaTokenData to our TokenCandidate format
let candidates: Vec<(i32, f32, f32)> = candidates_llama
.iter()
.map(|c| (c.id().0, c.logit(), c.p()))
.zip(probs.iter())
.map(|(c, &p)| (c.id().0, c.logit(), p))
.collect();

// Use configured sampling strategy
Expand All @@ -256,6 +263,19 @@ impl GgufBackend {
break;
}

// Accumulate text and check stop sequences before committing token to output
if !stop_sequences.is_empty() {
if let Ok(tok_str) =
model.token_to_str(LlamaToken(next_token), Special::Tokenize)
{
generated_text.push_str(&tok_str);
if stop_sequences.iter().any(|s| generated_text.contains(s)) {
debug!("Stop sequence matched, stopping generation");
break;
}
}
}

output_tokens.push(next_token);

// Prepare next batch with the sampled token
Expand Down Expand Up @@ -320,6 +340,7 @@ impl GgufBackend {
let top_k = params.top_k;
let top_p = params.top_p;
let seed = params.seed;
let stop_sequences = params.stop_sequences.clone();

// Create streaming channel
let stream_config = StreamConfig {
Expand Down Expand Up @@ -424,6 +445,7 @@ impl GgufBackend {
// Generate tokens and stream them one by one
let max_new_tokens = max_tokens as usize;
let mut sequence = 0u32;
let mut generated_text = String::new();

debug!(
"🔀 Starting streaming token generation with strategy: {:?}, temp: {:.2}",
Expand All @@ -434,10 +456,15 @@ impl GgufBackend {
// Get logits for sampling
let candidates_llama: Vec<_> = context.candidates().collect();

// Compute softmax probabilities from raw logits
let logits: Vec<f32> = candidates_llama.iter().map(|c| c.logit()).collect();
let probs = GgufBackend::softmax(&logits);

// Convert LlamaTokenData to our sampling format
let candidates: Vec<(i32, f32, f32)> = candidates_llama
.iter()
.map(|c| (c.id().0, c.logit(), c.p()))
.zip(probs.iter())
.map(|(c, &p)| (c.id().0, c.logit(), p))
.collect();

// Sample next token
Expand Down Expand Up @@ -466,6 +493,15 @@ impl GgufBackend {
llama_cpp_2::model::Special::Tokenize,
) {
Ok(token_str) => {
// Check stop sequences on accumulated text
if !stop_sequences.is_empty() {
generated_text.push_str(&token_str);
if stop_sequences.iter().any(|s| generated_text.contains(s)) {
debug!("Stop sequence matched, stopping generation");
break;
}
}

let stream_token = StreamToken {
content: token_str.clone(),
sequence,
Expand Down Expand Up @@ -524,6 +560,19 @@ impl GgufBackend {

Ok(Box::pin(result_stream))
}

/// Convert raw logits into a probability distribution using a
/// numerically stable softmax.
///
/// The maximum logit is subtracted before exponentiating so that large
/// logits cannot overflow `exp` to infinity (the largest term becomes
/// `exp(0) == 1`, and every other term is `<= 1`).
///
/// Returns an empty vector for empty input. If the normalizer is not a
/// positive number — e.g. a NaN logit, or every logit is `-inf` (where
/// `max_logit - max_logit` is NaN) — an all-zero vector is returned
/// instead of propagating NaN into the sampling candidates.
fn softmax(logits: &[f32]) -> Vec<f32> {
    if logits.is_empty() {
        return Vec::new();
    }
    // f32::max ignores NaN operands, so max_logit is the largest real logit.
    let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&l| (l - max_logit).exp()).collect();
    let sum: f32 = exps.iter().sum();
    // `sum <= 0.0` alone misses NaN (every comparison with NaN is false),
    // which would let a NaN distribution reach the sampler; check explicitly.
    if sum.is_nan() || sum <= 0.0 {
        return vec![0.0; logits.len()];
    }
    exps.iter().map(|&e| e / sum).collect()
}
}

#[async_trait::async_trait]
Expand Down Expand Up @@ -628,6 +677,7 @@ impl InferenceBackend for GgufBackend {
self.model_info = Some(model_info.clone());

info!("✅ GGUF model loaded successfully with Metal GPU support");

Ok(())
}

Expand All @@ -653,6 +703,11 @@ impl InferenceBackend for GgufBackend {
return Err(InfernoError::Backend("Model not loaded".to_string()).into());
}

// Best-effort: record this inference run in the local model registry
if let Some(info) = &self.model_info {
crate::models::record_model_usage(&info.path).await;
}

let start_time = Instant::now();
info!("Starting GGUF inference");

Expand Down Expand Up @@ -699,6 +754,11 @@ impl InferenceBackend for GgufBackend {
return Err(InfernoError::Backend("Model not loaded".to_string()).into());
}

// Best-effort: record this inference run in the local model registry
if let Some(info) = &self.model_info {
crate::models::record_model_usage(&info.path).await;
}

info!("Starting GGUF streaming inference");
self.generate_stream(input, params).await
}
Expand Down
11 changes: 11 additions & 0 deletions src/backends/onnx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,7 @@ impl InferenceBackend for OnnxBackend {
"ONNX model loaded successfully (type: {:?}, GPU: {})",
self.model_type, self.config.gpu_enabled
);

Ok(())
}

Expand Down Expand Up @@ -594,6 +595,11 @@ impl InferenceBackend for OnnxBackend {
return Err(InfernoError::Backend("Model not loaded".to_string()).into());
}

// Best-effort: record this inference run in the local model registry
if let Some(info) = &self.model_info {
crate::models::record_model_usage(&info.path).await;
}

let start_time = Instant::now();
info!("Starting ONNX inference");

Expand Down Expand Up @@ -712,6 +718,11 @@ impl InferenceBackend for OnnxBackend {
return Err(InfernoError::Backend("Model not loaded".to_string()).into());
}

// Best-effort: record this inference run in the local model registry
if let Some(info) = &self.model_info {
crate::models::record_model_usage(&info.path).await;
}

info!("Starting ONNX streaming inference");

let session = self.session.as_ref().unwrap().clone();
Expand Down
Loading
Loading