diff --git a/Cargo.lock b/Cargo.lock index 5fc9434f..deb2f87f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5532,6 +5532,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", + "stakai", "stakpak-shared", "tempfile", "tokio", @@ -5664,6 +5665,7 @@ dependencies = [ "serde_json", "shlex", "similar", + "stakai", "stakpak-api", "stakpak-shared", "syntect", diff --git a/cli/src/commands/acp/server.rs b/cli/src/commands/acp/server.rs index 2584ec01..5e12ea1b 100644 --- a/cli/src/commands/acp/server.rs +++ b/cli/src/commands/acp/server.rs @@ -4,12 +4,13 @@ use agent_client_protocol::{self as acp, Client as AcpClient, SessionNotificatio use futures_util::StreamExt; use stakpak_api::models::ApiStreamError; use stakpak_api::{AgentClient, AgentClientConfig, AgentProvider, StakpakConfig}; +use stakpak_api::{Model, ModelLimit}; use stakpak_mcp_client::McpClient; use stakpak_shared::models::integrations::mcp::CallToolResultExt; use stakpak_shared::models::integrations::openai::{ - AgentModel, ChatCompletionChoice, ChatCompletionResponse, ChatCompletionStreamResponse, - ChatMessage, FinishReason, FunctionCall, FunctionCallDelta, MessageContent, Role, Tool, - ToolCall, ToolCallResultProgress, ToolCallResultStatus, + ChatCompletionChoice, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage, + FinishReason, FunctionCall, FunctionCallDelta, MessageContent, Role, Tool, ToolCall, + ToolCallResultProgress, ToolCallResultStatus, }; use stakpak_shared::models::llm::LLMTokenUsage; use std::cell::Cell; @@ -22,6 +23,8 @@ use uuid::Uuid; pub struct StakpakAcpAgent { config: AppConfig, client: Arc, + /// Default model to use for chat completions + model: Model, session_update_tx: mpsc::UnboundedSender<(acp::SessionNotification, oneshot::Sender<()>)>, next_session_id: Cell, mcp_client: Option>, @@ -86,6 +89,37 @@ impl StakpakAcpAgent { Arc::new(client) }; + // Get default model - use smart_model from config or first available model + let model = if let Some(smart_model_str) = &config.smart_model { + // Parse the smart_model string to determine provider + let provider = if smart_model_str.starts_with("anthropic/") + || smart_model_str.contains("claude") + { + "anthropic" + } else if smart_model_str.starts_with("openai/") || smart_model_str.contains("gpt") { + "openai" + } else if smart_model_str.starts_with("google/") || smart_model_str.contains("gemini") { + "google" + } else { + "stakpak" + }; + Model::custom(smart_model_str.clone(), provider) + } else { + // Use first available model from client + let models = client.list_models().await; + models.into_iter().next().unwrap_or_else(|| { + // Fallback default: Claude Opus via Stakpak + Model::new( + "anthropic/claude-opus-4-5-20251101", + "Claude Opus 4.5", + "stakpak", + true, + None, + ModelLimit::default(), + ) + }) + }; + // Initialize MCP client and tools (optional for ACP) let (mcp_client, mcp_tools, tools) = match Self::initialize_mcp_server_and_tools(&config).await { @@ -114,6 +148,7 @@ impl StakpakAcpAgent { Ok(Self { config, client, + model, session_update_tx, next_session_id: Cell::new(0), mcp_client, @@ -1081,7 +1116,7 @@ impl StakpakAcpAgent { id: "".to_string(), object: "".to_string(), created: 0, - model: AgentModel::Smart.to_string(), + model: "claude-sonnet-4-5".to_string(), choices: vec![], usage: LLMTokenUsage { prompt_tokens: 0, @@ -1372,6 +1407,7 @@ impl StakpakAcpAgent { let agent = StakpakAcpAgent { config: self.config.clone(), client: self.client.clone(), + model: self.model.clone(), session_update_tx: tx.clone(), 
next_session_id: self.next_session_id.clone(), mcp_client, @@ -1492,6 +1528,7 @@ impl Clone for StakpakAcpAgent { Self { config: self.config.clone(), client: self.client.clone(), + model: self.model.clone(), session_update_tx: self.session_update_tx.clone(), next_session_id: Cell::new(self.next_session_id.get()), mcp_client: self.mcp_client.clone(), @@ -1736,7 +1773,7 @@ impl acp::Agent for StakpakAcpAgent { let (stream, _request_id) = self .client - .chat_completion_stream(AgentModel::Smart, messages, tools_option.clone(), None) + .chat_completion_stream(self.model.clone(), messages, tools_option.clone(), None) .await .map_err(|e| { log::error!("Chat completion stream failed: {e}"); @@ -1932,7 +1969,7 @@ impl acp::Agent for StakpakAcpAgent { let (follow_up_stream, _request_id) = self .client .chat_completion_stream( - AgentModel::Smart, + self.model.clone(), current_messages.clone(), tools_option.clone(), None, diff --git a/cli/src/commands/agent/run/mode_async.rs b/cli/src/commands/agent/run/mode_async.rs index f4293a11..82a783da 100644 --- a/cli/src/commands/agent/run/mode_async.rs +++ b/cli/src/commands/agent/run/mode_async.rs @@ -11,10 +11,10 @@ use crate::commands::agent::run::tooling::run_tool_call; use crate::config::AppConfig; use crate::utils::agents_md::AgentsMdInfo; use crate::utils::local_context::LocalContext; -use stakpak_api::{AgentClient, AgentClientConfig, AgentProvider, models::ListRuleBook}; +use stakpak_api::{AgentClient, AgentClientConfig, AgentProvider, Model, models::ListRuleBook}; use stakpak_mcp_server::EnabledToolsConfig; use stakpak_shared::local_store::LocalStore; -use stakpak_shared::models::integrations::openai::{AgentModel, ChatMessage}; +use stakpak_shared::models::integrations::openai::ChatMessage; use stakpak_shared::models::llm::LLMTokenUsage; use stakpak_shared::models::subagent::SubagentConfigs; use std::time::Instant; @@ -35,7 +35,7 @@ pub struct RunAsyncConfig { pub enable_mtls: bool, pub system_prompt: Option, pub enabled_tools: EnabledToolsConfig, - pub model: AgentModel, + pub model: Model, pub agents_md: Option, } diff --git a/cli/src/commands/agent/run/mode_interactive.rs b/cli/src/commands/agent/run/mode_interactive.rs index 80a8d78f..070ec28d 100644 --- a/cli/src/commands/agent/run/mode_interactive.rs +++ b/cli/src/commands/agent/run/mode_interactive.rs @@ -19,12 +19,12 @@ use crate::utils::check_update::get_latest_cli_version; use crate::utils::local_context::LocalContext; use reqwest::header::HeaderMap; use stakpak_api::models::ApiStreamError; -use stakpak_api::{AgentClient, AgentClientConfig, AgentProvider, models::ListRuleBook}; +use stakpak_api::{AgentClient, AgentClientConfig, AgentProvider, Model, models::ListRuleBook}; use stakpak_mcp_server::EnabledToolsConfig; use stakpak_shared::models::integrations::mcp::CallToolResultExt; use stakpak_shared::models::integrations::openai::{ - AgentModel, ChatMessage, MessageContent, Role, ToolCall, ToolCallResultStatus, + ChatMessage, MessageContent, Role, ToolCall, ToolCallResultStatus, }; use stakpak_shared::models::llm::{LLMTokenUsage, PromptTokensDetails}; use stakpak_shared::models::subagent::SubagentConfigs; @@ -57,7 +57,7 @@ pub struct RunInteractiveConfig { pub allowed_tools: Option>, pub auto_approve: Option>, pub enabled_tools: EnabledToolsConfig, - pub model: AgentModel, + pub model: Model, pub agents_md: Option, } @@ -118,8 +118,8 @@ pub async fn run_interactive( }); let editor_command = ctx.editor.clone(); - let model_clone = model.clone(); let auth_display_info_for_tui = 
ctx.get_auth_display_info(); + let model_for_tui = model.clone(); let tui_handle = tokio::spawn(async move { let latest_version = get_latest_cli_version().await; stakpak_tui::run_tui( @@ -135,7 +135,7 @@ pub async fn run_interactive( allowed_tools.as_ref(), current_profile_for_tui, rulebook_config_for_tui, - model_clone, + model_for_tui, editor_command, auth_display_info_for_tui, ) @@ -288,7 +288,7 @@ pub async fn run_interactive( while let Some(output_event) = output_rx.recv().await { match output_event { - OutputEvent::SwitchModel(new_model) => { + OutputEvent::SwitchToModel(new_model) => { model = new_model; continue; } @@ -814,6 +814,16 @@ pub async fn run_interactive( .await?; continue; } + OutputEvent::RequestAvailableModels => { + // Load available models from the provider registry + let available_models = client.list_models().await; + send_input_event( + &input_tx, + InputEvent::AvailableModelsLoaded(available_models), + ) + .await?; + continue; + } } let headers = if study_mode { diff --git a/cli/src/config/app.rs b/cli/src/config/app.rs index 44b541f7..9aa8df24 100644 --- a/cli/src/config/app.rs +++ b/cli/src/config/app.rs @@ -54,6 +54,8 @@ pub struct AppConfig { pub eco_model: Option, /// Recovery model name pub recovery_model: Option, + /// New unified model field (replaces smart/eco/recovery model selection) + pub model: Option, /// Unique ID for anonymous telemetry pub anonymous_id: Option, /// Whether to collect telemetry data @@ -162,6 +164,7 @@ impl AppConfig { smart_model: profile_config.smart_model, eco_model: profile_config.eco_model, recovery_model: profile_config.recovery_model, + model: profile_config.model, anonymous_id: settings.anonymous_id, collect_telemetry: settings.collect_telemetry, editor: settings.editor, @@ -776,6 +779,47 @@ impl AppConfig { (config_provider, None, None) } + + /// Get the default Model from config + /// + /// Uses the `model` field if set, otherwise falls back to `smart_model`, + /// and finally to a default Claude Opus model. + /// + /// Searches the model catalog by ID. If the model string has a provider + /// prefix (e.g., "anthropic/claude-opus-4-5"), it searches within that + /// provider first. Otherwise, it searches all providers. 
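The catalog-miss fallback in the implementation below splits a provider-prefixed ID by hand; here is a minimal, self-contained sketch of that string handling (the model string is invented, only std is used):

```rust
fn main() {
    let model_str = "openai/gpt-5-mini-2025-08-07"; // invented example value
    let (provider, model_id) = match model_str.find('/') {
        Some(idx) => (&model_str[..idx], &model_str[idx + 1..]),
        None => ("anthropic", model_str), // same default provider as the code below
    };
    assert_eq!(provider, "openai");
    assert_eq!(model_id, "gpt-5-mini-2025-08-07");
    // Without a Stakpak API key this becomes Model::custom(model_id, provider);
    // with one, the provider is forced to "stakpak" and the ID keeps its prefix.
}
```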
+ pub fn get_default_model(&self) -> stakpak_api::Model { + let use_stakpak = self.api_key.is_some(); + + // Priority: model > smart_model > default + let model_str = self + .model + .as_ref() + .or(self.smart_model.as_ref()) + .map(|s| s.as_str()) + .unwrap_or("claude-opus-4-5-20251101"); + + // Search the model catalog + stakpak_api::find_model(model_str, use_stakpak).unwrap_or_else(|| { + // Model not found in catalog - create a custom model + // Extract provider from prefix if present + let (provider, model_id) = if let Some(idx) = model_str.find('/') { + let (prefix, rest) = model_str.split_at(idx); + (prefix, &rest[1..]) + } else { + ("anthropic", model_str) // Default to anthropic + }; + + let final_provider = if use_stakpak { "stakpak" } else { provider }; + let final_id = if use_stakpak { + format!("{}/{}", provider, model_id) + } else { + model_id.to_string() + }; + + stakpak_api::Model::custom(final_id, final_provider) + }) + } } // Conversions @@ -810,6 +854,7 @@ impl From for ProfileConfig { eco_model: config.eco_model, smart_model: config.smart_model, recovery_model: config.recovery_model, + model: config.model, } } } diff --git a/cli/src/config/profile.rs b/cli/src/config/profile.rs index 1d7d2474..67dd32ce 100644 --- a/cli/src/config/profile.rs +++ b/cli/src/config/profile.rs @@ -71,11 +71,22 @@ pub struct ProfileConfig { #[serde(skip_serializing_if = "Option::is_none")] pub anthropic: Option, - /// Eco (fast/cheap) model name + /// User's preferred model (replaces smart_model/eco_model/recovery_model) + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + + // ========================================================================= + // Legacy model fields - kept for backward compatibility during migration + // These are read but deprecated (will migrate to 'model' field) + // ========================================================================= + /// Eco (fast/cheap) model name (deprecated - use 'model') + #[serde(skip_serializing_if = "Option::is_none")] pub eco_model: Option, - /// Smart (capable) model name + /// Smart (capable) model name (deprecated - use 'model') + #[serde(skip_serializing_if = "Option::is_none")] pub smart_model: Option, - /// Recovery model name + /// Recovery model name (deprecated - use 'model') + #[serde(skip_serializing_if = "Option::is_none")] pub recovery_model: Option, } @@ -304,6 +315,12 @@ impl ProfileConfig { .gemini .clone() .or_else(|| other.and_then(|config| config.gemini.clone())), + // New unified model field + model: self + .model + .clone() + .or_else(|| other.and_then(|config| config.model.clone())), + // Legacy fields - merge for backward compatibility during transition eco_model: self .eco_model .clone() diff --git a/cli/src/config/tests.rs b/cli/src/config/tests.rs index f6dd7f25..c5d50ca1 100644 --- a/cli/src/config/tests.rs +++ b/cli/src/config/tests.rs @@ -62,6 +62,7 @@ fn sample_app_config(profile_name: &str) -> AppConfig { smart_model: None, eco_model: None, recovery_model: None, + model: None, anonymous_id: Some("test-user-id".into()), collect_telemetry: Some(true), editor: Some("nano".into()), @@ -430,6 +431,7 @@ fn save_writes_profile_and_settings() { smart_model: None, eco_model: None, recovery_model: None, + model: None, anonymous_id: Some("test-user-id".into()), collect_telemetry: Some(true), editor: Some("nano".into()), diff --git a/cli/src/main.rs b/cli/src/main.rs index 553bab4f..c5147c94 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -3,7 +3,7 @@ use names::{self, Name}; use 
rustls::crypto::CryptoProvider; use stakpak_api::{AgentClient, AgentClientConfig, AgentProvider}; use stakpak_mcp_server::EnabledToolsConfig; -use stakpak_shared::models::{integrations::openai::AgentModel, subagent::SubagentConfigs}; +use stakpak_shared::models::subagent::SubagentConfigs; use std::{ env, path::{Path, PathBuf}, @@ -122,10 +122,6 @@ struct Cli { #[arg(long = "profile")] profile: Option, - /// Choose agent model on startup (smart or eco) - #[arg(long = "model")] - model: Option, - /// Custom path to config file (overrides default ~/.stakpak/config.toml) #[arg(long = "config")] config_path: Option, @@ -430,6 +426,7 @@ async fn main() { let allowed_tools = cli.allowed_tools.or_else(|| config.allowed_tools.clone()); let auto_approve = config.auto_approve.clone(); + let default_model = config.get_default_model(); match use_async_mode { // Async mode: run continuously until no more tool calls (or max_steps=1 for single-step) @@ -452,7 +449,7 @@ async fn main() { enabled_tools: EnabledToolsConfig { slack: cli.enable_slack_tools, }, - model: cli.model.unwrap_or(AgentModel::Smart), + model: default_model.clone(), agents_md: agents_md.clone(), }, ) @@ -484,7 +481,7 @@ async fn main() { enabled_tools: EnabledToolsConfig { slack: cli.enable_slack_tools, }, - model: cli.model.unwrap_or(AgentModel::Smart), + model: default_model, agents_md, }, ) diff --git a/libs/ai/examples/anthropic_generate.rs b/libs/ai/examples/anthropic_generate.rs index 0d15aee4..9a0cddc9 100644 --- a/libs/ai/examples/anthropic_generate.rs +++ b/libs/ai/examples/anthropic_generate.rs @@ -1,6 +1,6 @@ //! Example: Basic Anthropic generation -use stakai::{GenerateRequest, Inference, Message, Role}; +use stakai::{GenerateRequest, Inference, Message, Model, Role}; #[tokio::main] async fn main() -> Result<(), Box> { @@ -9,7 +9,7 @@ async fn main() -> Result<(), Box> { let client = Inference::new(); let mut request = GenerateRequest::new( - "claude-3-5-sonnet-20241022", + Model::custom("claude-3-5-sonnet-20241022", "anthropic"), vec![ Message::new(Role::System, "You are a helpful AI assistant."), Message::new(Role::User, "Explain quantum computing in simple terms."), diff --git a/libs/ai/examples/anthropic_stream.rs b/libs/ai/examples/anthropic_stream.rs index 769df3cf..6cf5b059 100644 --- a/libs/ai/examples/anthropic_stream.rs +++ b/libs/ai/examples/anthropic_stream.rs @@ -1,7 +1,7 @@ //! Example: Streaming with Anthropic use futures::StreamExt; -use stakai::{GenerateRequest, Inference, Message, Role, StreamEvent}; +use stakai::{GenerateRequest, Inference, Message, Model, Role, StreamEvent}; #[tokio::main] async fn main() -> Result<(), Box> { @@ -10,7 +10,7 @@ async fn main() -> Result<(), Box> { let client = Inference::new(); let mut request = GenerateRequest::new( - "claude-3-5-sonnet-20241022", + Model::custom("claude-3-5-sonnet-20241022", "anthropic"), vec![Message::new( Role::User, "Write a short poem about Rust programming.", diff --git a/libs/ai/examples/anthropic_test.rs b/libs/ai/examples/anthropic_test.rs index 4949c1a3..4a5d0db5 100644 --- a/libs/ai/examples/anthropic_test.rs +++ b/libs/ai/examples/anthropic_test.rs @@ -1,6 +1,6 @@ //! 
Test Anthropic provider -use stakai::{GenerateRequest, Inference, Message, Role}; +use stakai::{GenerateRequest, Inference, Message, Model, Role}; #[tokio::main] async fn main() -> Result<(), Box> { @@ -9,7 +9,7 @@ async fn main() -> Result<(), Box> { let client = Inference::new(); let mut request = GenerateRequest::new( - "claude-3-5-sonnet-20241022", + Model::custom("claude-3-5-sonnet-20241022", "anthropic"), vec![Message::new( Role::User, "What is the capital of France? Answer in one word.", diff --git a/libs/ai/examples/basic_generate.rs b/libs/ai/examples/basic_generate.rs index 9142ca27..a2ce4459 100644 --- a/libs/ai/examples/basic_generate.rs +++ b/libs/ai/examples/basic_generate.rs @@ -1,6 +1,6 @@ //! Basic generation example -use stakai::{GenerateRequest, Inference, Message, Role}; +use stakai::{GenerateRequest, Inference, Message, Model, Role}; #[tokio::main] async fn main() -> Result<(), Box> { @@ -9,7 +9,7 @@ async fn main() -> Result<(), Box> { // Build a request let mut request = GenerateRequest::new( - "gpt-4", + Model::custom("gpt-4", "openai"), vec![ Message::new(Role::System, "You are a helpful assistant"), Message::new(Role::User, "What is Rust programming language?"), diff --git a/libs/ai/examples/config_test.rs b/libs/ai/examples/config_test.rs index 10efae60..8ae3a2e0 100644 --- a/libs/ai/examples/config_test.rs +++ b/libs/ai/examples/config_test.rs @@ -1,6 +1,6 @@ //! Test example for InferenceConfig -use stakai::{GenerateRequest, Inference, InferenceConfig, Message, Role}; +use stakai::{GenerateRequest, Inference, InferenceConfig, Message, Model, Role}; #[tokio::main] async fn main() -> Result<(), Box> { @@ -48,10 +48,13 @@ async fn main() -> Result<(), Box> { // Test 4: Try to make a request (will fail without real API key, but tests the API) println!("\n=== Test 4: Request API Test ==="); - let request = GenerateRequest::new("gpt-4", vec![Message::new(Role::User, "Hello!")]); + let request = GenerateRequest::new( + Model::custom("gpt-4", "openai"), + vec![Message::new(Role::User, "Hello!")], + ); println!("āœ“ Request created successfully"); - println!(" Model: {}", request.model); + println!(" Model: {}", request.model.id); println!(" Messages: {}", request.messages.len()); println!("\n=== All Tests Passed ==="); diff --git a/libs/ai/examples/custom_headers.rs b/libs/ai/examples/custom_headers.rs index 73a681d0..ba725006 100644 --- a/libs/ai/examples/custom_headers.rs +++ b/libs/ai/examples/custom_headers.rs @@ -1,13 +1,16 @@ //! Example: Using custom headers with providers -use stakai::{GenerateRequest, Inference, Message, Role}; +use stakai::{GenerateRequest, Inference, Message, Model, Role}; #[tokio::main] async fn main() -> Result<(), Box> { let client = Inference::new(); // Create request with custom headers using GenerateOptions - let mut request = GenerateRequest::new("gpt-4", vec![Message::new(Role::User, "Hello!")]); + let mut request = GenerateRequest::new( + Model::custom("gpt-4", "openai"), + vec![Message::new(Role::User, "Hello!")], + ); request.options.max_tokens = Some(100); request.options = request .options diff --git a/libs/ai/examples/gemini_generate.rs b/libs/ai/examples/gemini_generate.rs index dbfbe3aa..f83ec77c 100644 --- a/libs/ai/examples/gemini_generate.rs +++ b/libs/ai/examples/gemini_generate.rs @@ -1,6 +1,6 @@ //! 
Example: Basic Gemini generation -use stakai::{GenerateRequest, Inference, Message, Role}; +use stakai::{GenerateRequest, Inference, Message, Model, Role}; #[tokio::main] async fn main() -> Result<(), Box> { @@ -9,7 +9,7 @@ async fn main() -> Result<(), Box> { let client = Inference::new(); let mut request = GenerateRequest::new( - "gemini-2.0-flash-exp", + Model::custom("gemini-2.0-flash-exp", "google"), vec![ Message::new(Role::System, "You are a knowledgeable science teacher."), Message::new(Role::User, "What causes the northern lights?"), diff --git a/libs/ai/examples/gemini_stream.rs b/libs/ai/examples/gemini_stream.rs index b6e2e8f9..17b69e15 100644 --- a/libs/ai/examples/gemini_stream.rs +++ b/libs/ai/examples/gemini_stream.rs @@ -1,7 +1,7 @@ //! Example: Streaming with Gemini use futures::StreamExt; -use stakai::{GenerateRequest, Inference, Message, Role, StreamEvent}; +use stakai::{GenerateRequest, Inference, Message, Model, Role, StreamEvent}; #[tokio::main] async fn main() -> Result<(), Box> { @@ -10,7 +10,7 @@ async fn main() -> Result<(), Box> { let client = Inference::new(); let mut request = GenerateRequest::new( - "gemini-1.5-flash", + Model::custom("gemini-1.5-flash", "google"), vec![Message::new( Role::User, "Tell me an interesting fact about space exploration.", diff --git a/libs/ai/examples/multi_provider.rs b/libs/ai/examples/multi_provider.rs index 71c8da60..2606e59a 100644 --- a/libs/ai/examples/multi_provider.rs +++ b/libs/ai/examples/multi_provider.rs @@ -1,6 +1,6 @@ //! Example: Comparing responses from multiple providers -use stakai::{GenerateRequest, Inference, Message, Role}; +use stakai::{GenerateRequest, Inference, Message, Model, Role}; #[tokio::main] async fn main() -> Result<(), Box> { @@ -8,7 +8,10 @@ async fn main() -> Result<(), Box> { let question = "What is the meaning of life?"; - let mut request = GenerateRequest::new("gpt-4", vec![Message::new(Role::User, question)]); + let mut request = GenerateRequest::new( + Model::custom("gpt-4", "openai"), + vec![Message::new(Role::User, question)], + ); request.options.temperature = Some(0.7); request.options.max_tokens = Some(200); @@ -16,7 +19,7 @@ async fn main() -> Result<(), Box> { println!("{}", "=".repeat(80)); // Try OpenAI - request.model = "gpt-4".to_string(); + request.model = Model::custom("gpt-4", "openai"); if let Ok(response) = client.generate(&request).await { println!("\nšŸ¤– OpenAI GPT-4:"); println!("{}", response.text()); @@ -24,7 +27,7 @@ async fn main() -> Result<(), Box> { } // Try Anthropic - request.model = "claude-3-5-sonnet-20241022".to_string(); + request.model = Model::custom("claude-3-5-sonnet-20241022", "anthropic"); if let Ok(response) = client.generate(&request).await { println!("\nšŸ¤– Anthropic Claude:"); println!("{}", response.text()); @@ -32,7 +35,7 @@ async fn main() -> Result<(), Box> { } // Try Gemini - request.model = "gemini-2.0-flash-exp".to_string(); + request.model = Model::custom("gemini-2.0-flash-exp", "google"); if let Ok(response) = client.generate(&request).await { println!("\nšŸ¤– Google Gemini:"); println!("{}", response.text()); diff --git a/libs/ai/examples/simple.rs b/libs/ai/examples/simple.rs index bbec017a..39421630 100644 --- a/libs/ai/examples/simple.rs +++ b/libs/ai/examples/simple.rs @@ -1,12 +1,15 @@ //! 
Simplest possible example -use stakai::{GenerateRequest, Inference, Message, Role}; +use stakai::{GenerateRequest, Inference, Message, Model, Role}; #[tokio::main] async fn main() -> Result<(), Box> { let client = Inference::new(); - let request = GenerateRequest::new("gpt-4", vec![Message::new(Role::User, "What is 2+2?")]); + let request = GenerateRequest::new( + Model::custom("gpt-4", "openai"), + vec![Message::new(Role::User, "What is 2+2?")], + ); let response = client.generate(&request).await?; diff --git a/libs/ai/examples/streaming.rs b/libs/ai/examples/streaming.rs index 2ebddffb..bc86dbda 100644 --- a/libs/ai/examples/streaming.rs +++ b/libs/ai/examples/streaming.rs @@ -1,7 +1,7 @@ //! Streaming generation example use futures::StreamExt; -use stakai::{GenerateRequest, Inference, Message, Role, StreamEvent}; +use stakai::{GenerateRequest, Inference, Message, Model, Role, StreamEvent}; #[tokio::main] async fn main() -> Result<(), Box> { @@ -10,7 +10,7 @@ async fn main() -> Result<(), Box> { // Build a request let mut request = GenerateRequest::new( - "gpt-4", + Model::custom("gpt-4", "openai"), vec![Message::new( Role::User, "Write a haiku about Rust programming", diff --git a/libs/ai/examples/tool_calling_basic.rs b/libs/ai/examples/tool_calling_basic.rs index cfad7201..00fc6d8c 100644 --- a/libs/ai/examples/tool_calling_basic.rs +++ b/libs/ai/examples/tool_calling_basic.rs @@ -8,7 +8,7 @@ //! - Getting the final response use serde_json::json; -use stakai::{ContentPart, GenerateRequest, Inference, Message, Role, Tool}; +use stakai::{ContentPart, GenerateRequest, Inference, Message, Model, Role, Tool}; #[tokio::main] async fn main() -> Result<(), Box> { @@ -39,7 +39,7 @@ async fn main() -> Result<(), Box> { // 2. Create initial request with the user query and tool let mut request = GenerateRequest::new( - "gpt-4o-mini", + Model::custom("gpt-4o-mini", "openai"), vec![Message::new( Role::User, "What's the weather like in Tokyo, Japan?", @@ -101,7 +101,8 @@ async fn main() -> Result<(), Box> { )], )); - let mut request_with_result = GenerateRequest::new("gpt-4o-mini", messages); + let mut request_with_result = + GenerateRequest::new(Model::custom("gpt-4o-mini", "openai"), messages); request_with_result.options = request_with_result.options.add_tool( Tool::function("get_weather", "Get the current weather for a location").parameters(json!({ "type": "object", diff --git a/libs/ai/examples/tool_calling_streaming.rs b/libs/ai/examples/tool_calling_streaming.rs index be4da6c8..3515a029 100644 --- a/libs/ai/examples/tool_calling_streaming.rs +++ b/libs/ai/examples/tool_calling_streaming.rs @@ -9,7 +9,7 @@ use futures::StreamExt; use serde_json::json; -use stakai::{ContentPart, GenerateRequest, Inference, Message, StreamEvent, Tool}; +use stakai::{ContentPart, GenerateRequest, Inference, Message, Model, StreamEvent, Tool}; use std::collections::HashMap; #[tokio::main] @@ -49,7 +49,7 @@ async fn main() -> Result<(), Box> { // 2. Create initial request let mut request = GenerateRequest::new( - "gpt-4o-mini", + Model::custom("gpt-4o-mini", "openai"), vec![Message::new( stakai::Role::User, "What's the weather and time in Paris?", @@ -171,7 +171,7 @@ async fn main() -> Result<(), Box> { )); } - let mut follow_up = GenerateRequest::new("gpt-4o-mini", messages); + let mut follow_up = GenerateRequest::new(Model::custom("gpt-4o-mini", "openai"), messages); follow_up.options = follow_up.options.add_tool(weather_tool).add_tool(time_tool); // 7. 
Get final response diff --git a/libs/ai/src/client/mod.rs b/libs/ai/src/client/mod.rs index 06be1447..b2a1d8cb 100644 --- a/libs/ai/src/client/mod.rs +++ b/libs/ai/src/client/mod.rs @@ -6,7 +6,7 @@ mod config; pub use builder::ClientBuilder; pub use config::{ClientConfig, InferenceConfig}; -use crate::error::{Error, Result}; +use crate::error::Result; use crate::registry::ProviderRegistry; use crate::types::{GenerateRequest, GenerateResponse, GenerateStream}; @@ -72,11 +72,11 @@ impl Inference { /// # Example /// /// ```rust,no_run - /// # use stakai::{Inference, GenerateRequest, Message, Role}; + /// # use stakai::{Inference, GenerateRequest, Message, Model, Role}; /// # async fn example() -> Result<(), Box> { /// let client = Inference::new(); /// let request = GenerateRequest::new( - /// "openai/gpt-4", + /// Model::custom("gpt-4", "openai"), /// vec![Message::new(Role::User, "Hello!")] /// ); /// let response = client.generate(&request).await?; @@ -98,12 +98,11 @@ impl Inference { pub async fn generate(&self, request: &GenerateRequest) -> Result { #[cfg(feature = "tracing")] { - let (provider_id, _) = self.parse_model(&request.model)?; let span = tracing::info_span!( "chat", "gen_ai.operation.name" = "chat", - "gen_ai.provider.name" = %provider_id, - "gen_ai.request.model" = %request.model, + "gen_ai.provider.name" = %request.model.provider, + "gen_ai.request.model" = %request.model.id, "gen_ai.request.temperature" = tracing::field::Empty, "gen_ai.request.max_tokens" = tracing::field::Empty, "gen_ai.request.top_p" = tracing::field::Empty, @@ -188,12 +187,8 @@ impl Inference { /// Internal generate implementation async fn generate_internal(&self, request: &GenerateRequest) -> Result { - let (provider_id, model_id) = self.parse_model(&request.model)?; - let provider = self.registry.get_provider(&provider_id)?; - - let mut req = request.clone(); - req.model = model_id.to_string(); - provider.generate(req).await + let provider = self.registry.get_provider(&request.model.provider)?; + provider.generate(request.clone()).await } /// Generate a streaming response @@ -209,12 +204,12 @@ impl Inference { /// # Example /// /// ```rust,no_run - /// # use stakai::{Inference, GenerateRequest, Message, Role, StreamEvent}; + /// # use stakai::{Inference, GenerateRequest, Message, Model, Role, StreamEvent}; /// # use futures::StreamExt; /// # async fn example() -> Result<(), Box> { /// let client = Inference::new(); /// let request = GenerateRequest::new( - /// "openai/gpt-4", + /// Model::custom("gpt-4", "openai"), /// vec![Message::new(Role::User, "Count to 5")] /// ); /// let mut stream = client.stream(&request).await?; @@ -243,12 +238,11 @@ impl Inference { pub async fn stream(&self, request: &GenerateRequest) -> Result { #[cfg(feature = "tracing")] { - let (provider_id, _) = self.parse_model(&request.model)?; let span = tracing::info_span!( "chat", "gen_ai.operation.name" = "chat", - "gen_ai.provider.name" = %provider_id, - "gen_ai.request.model" = %request.model, + "gen_ai.provider.name" = %request.model.provider, + "gen_ai.request.model" = %request.model.id, "gen_ai.request.temperature" = tracing::field::Empty, "gen_ai.request.max_tokens" = tracing::field::Empty, "gen_ai.request.top_p" = tracing::field::Empty, @@ -304,39 +298,8 @@ impl Inference { /// Internal stream implementation async fn stream_internal(&self, request: &GenerateRequest) -> Result { - let (provider_id, model_id) = self.parse_model(&request.model)?; - let provider = self.registry.get_provider(&provider_id)?; - - let mut req = 
request.clone(); - req.model = model_id.to_string(); - provider.stream(req).await - } - - /// Parse model string into provider and model ID - pub(crate) fn parse_model<'a>(&self, model: &'a str) -> Result<(String, &'a str)> { - if let Some((provider, model_id)) = model.split_once('/') { - // Explicit provider/model format - Ok((provider.to_string(), model_id)) - } else { - // Auto-detect provider from model name - let provider = self.detect_provider(model)?; - Ok((provider, model)) - } - } - - /// Detect provider from model name using heuristics - pub(crate) fn detect_provider(&self, model: &str) -> Result { - let model_lower = model.to_lowercase(); - - if model_lower.starts_with("gpt-") || model_lower.starts_with("o1-") { - Ok("openai".to_string()) - } else if model_lower.starts_with("claude-") { - Ok("anthropic".to_string()) - } else if model_lower.starts_with("gemini-") { - Ok("google".to_string()) - } else { - Err(Error::UnknownProvider(model.to_string())) - } + let provider = self.registry.get_provider(&request.model.provider)?; + provider.stream(request.clone()).await } /// Get the provider registry diff --git a/libs/ai/src/lib.rs b/libs/ai/src/lib.rs index 90226da0..b15a0234 100644 --- a/libs/ai/src/lib.rs +++ b/libs/ai/src/lib.rs @@ -15,14 +15,14 @@ //! ## Quick Start //! //! ```rust,no_run -//! use stakai::{Inference, GenerateRequest, Message, Role}; +//! use stakai::{Inference, GenerateRequest, Message, Model, Role}; //! //! #[tokio::main] //! async fn main() -> Result<(), Box> { //! let client = Inference::new(); //! //! let request = GenerateRequest::new( -//! "gpt-4", +//! Model::custom("gpt-4", "openai"), //! vec![Message::new(Role::User, "What is Rust?")] //! ); //! @@ -76,6 +76,10 @@ pub use types::{ Message, MessageContent, MessageProviderOptions, + // Model types + Model, + ModelCost, + ModelLimit, OpenAIOptions, OutputTokenDetails, PromptCacheRetention, diff --git a/libs/ai/src/provider/trait_def.rs b/libs/ai/src/provider/trait_def.rs index cb74f3b2..be7e4db9 100644 --- a/libs/ai/src/provider/trait_def.rs +++ b/libs/ai/src/provider/trait_def.rs @@ -1,7 +1,7 @@ //! 
Provider trait definition use crate::error::Result; -use crate::types::{GenerateRequest, GenerateResponse, GenerateStream, Headers}; +use crate::types::{GenerateRequest, GenerateResponse, GenerateStream, Headers, Model}; use async_trait::async_trait; /// Trait for AI provider implementations @@ -20,8 +20,13 @@ pub trait Provider: Send + Sync { /// Generate a streaming response async fn stream(&self, request: GenerateRequest) -> Result; - /// List available models (optional) - async fn list_models(&self) -> Result> { + /// List available models with full metadata + async fn list_models(&self) -> Result> { Ok(vec![]) } + + /// Get a specific model by ID (default implementation) + async fn get_model(&self, id: &str) -> Result> { + Ok(self.list_models().await?.into_iter().find(|m| m.id == id)) + } } diff --git a/libs/ai/src/providers/anthropic/convert.rs b/libs/ai/src/providers/anthropic/convert.rs index 232085f8..c5cd0c8c 100644 --- a/libs/ai/src/providers/anthropic/convert.rs +++ b/libs/ai/src/providers/anthropic/convert.rs @@ -47,7 +47,7 @@ pub fn to_anthropic_request( let max_tokens = req .options .max_tokens - .unwrap_or_else(|| infer_max_tokens(&req.model)); + .unwrap_or_else(|| infer_max_tokens(&req.model.id)); // Convert tools to Anthropic format with cache control let tools = build_tools(&req.options.tools, &mut validator)?; @@ -79,7 +79,7 @@ pub fn to_anthropic_request( Ok(AnthropicConversionResult { request: AnthropicRequest { - model: req.model.clone(), + model: req.model.id.clone(), messages, max_tokens, system, diff --git a/libs/ai/src/providers/anthropic/mod.rs b/libs/ai/src/providers/anthropic/mod.rs index 22cb196b..900fe6ff 100644 --- a/libs/ai/src/providers/anthropic/mod.rs +++ b/libs/ai/src/providers/anthropic/mod.rs @@ -1,6 +1,7 @@ //! Anthropic provider module mod convert; +pub mod models; mod provider; mod stream; mod types; diff --git a/libs/ai/src/providers/anthropic/models.rs b/libs/ai/src/providers/anthropic/models.rs new file mode 100644 index 00000000..deca5bbb --- /dev/null +++ b/libs/ai/src/providers/anthropic/models.rs @@ -0,0 +1,137 @@ +//! Anthropic model definitions +//! +//! Static model definitions for Anthropic's Claude models with pricing and limits. 
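A hedged usage sketch of this catalog, not part of the change: look up an entry by ID and estimate a request cost with the `ModelCost` helpers introduced elsewhere in this diff. The `stakai::providers::anthropic::models` import path and the token counts are assumptions for illustration:

```rust
use stakai::providers::anthropic::models; // assumes the module is publicly reachable

fn main() {
    let sonnet = models::get_model("claude-sonnet-4-5-20250929")
        .expect("defined in this module");
    assert_eq!(sonnet.limit.context, 200_000);

    if let Some(cost) = &sonnet.cost {
        // 12_000 prompt tokens and 800 completion tokens, purely illustrative
        let usd = cost.calculate(12_000, 800);
        println!("{} would cost about ${usd:.4} for this call", sonnet.name);
    }
}
```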
+ +use crate::types::{Model, ModelCost, ModelLimit}; + +/// Provider identifier for Anthropic +pub const PROVIDER_ID: &str = "anthropic"; + +/// Get all Anthropic models +pub fn models() -> Vec { + vec![claude_haiku_4_5(), claude_sonnet_4_5(), claude_opus_4_5()] +} + +/// Get an Anthropic model by ID +pub fn get_model(id: &str) -> Option { + models().into_iter().find(|m| m.id == id) +} + +/// Get the default model for Anthropic +pub fn default_model() -> Model { + claude_sonnet_4_5() +} + +/// Claude Haiku 4.5 - Fast and affordable +pub fn claude_haiku_4_5() -> Model { + Model { + id: "claude-haiku-4-5-20251001".into(), + name: "Claude Haiku 4.5".into(), + provider: PROVIDER_ID.into(), + reasoning: false, + cost: Some(ModelCost { + input: 1.0, + output: 5.0, + cache_read: Some(0.10), + cache_write: Some(1.25), + }), + limit: ModelLimit { + context: 200_000, + output: 8_192, + }, + } +} + +/// Claude Sonnet 4.5 - Balanced performance +pub fn claude_sonnet_4_5() -> Model { + Model { + id: "claude-sonnet-4-5-20250929".into(), + name: "Claude Sonnet 4.5".into(), + provider: PROVIDER_ID.into(), + reasoning: true, + cost: Some(ModelCost { + input: 3.0, + output: 15.0, + cache_read: Some(0.30), + cache_write: Some(3.75), + }), + limit: ModelLimit { + context: 200_000, + output: 16_384, + }, + } +} + +/// Claude Opus 4.5 - Most capable +pub fn claude_opus_4_5() -> Model { + Model { + id: "claude-opus-4-5-20251101".into(), + name: "Claude Opus 4.5".into(), + provider: PROVIDER_ID.into(), + reasoning: true, + cost: Some(ModelCost { + input: 15.0, + output: 75.0, + cache_read: Some(1.50), + cache_write: Some(18.75), + }), + limit: ModelLimit { + context: 200_000, + output: 32_000, + }, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_models_list() { + let all_models = models(); + assert_eq!(all_models.len(), 3); + + // Verify all models have the anthropic provider + for model in &all_models { + assert_eq!(model.provider, PROVIDER_ID); + } + } + + #[test] + fn test_get_model() { + let sonnet = get_model("claude-sonnet-4-5-20250929"); + assert!(sonnet.is_some()); + assert_eq!(sonnet.unwrap().name, "Claude Sonnet 4.5"); + + let nonexistent = get_model("nonexistent-model"); + assert!(nonexistent.is_none()); + } + + #[test] + fn test_default_model() { + let default = default_model(); + assert_eq!(default.id, "claude-sonnet-4-5-20250929"); + } + + #[test] + fn test_model_pricing() { + let sonnet = claude_sonnet_4_5(); + let cost = sonnet.cost.unwrap(); + + assert_eq!(cost.input, 3.0); + assert_eq!(cost.output, 15.0); + assert_eq!(cost.cache_read, Some(0.30)); + assert_eq!(cost.cache_write, Some(3.75)); + } + + #[test] + fn test_reasoning_support() { + let haiku = claude_haiku_4_5(); + let sonnet = claude_sonnet_4_5(); + let opus = claude_opus_4_5(); + + assert!(!haiku.reasoning); + assert!(sonnet.reasoning); + assert!(opus.reasoning); + } +} diff --git a/libs/ai/src/providers/anthropic/provider.rs b/libs/ai/src/providers/anthropic/provider.rs index d3c9e7e8..ded9562d 100644 --- a/libs/ai/src/providers/anthropic/provider.rs +++ b/libs/ai/src/providers/anthropic/provider.rs @@ -1,11 +1,12 @@ //! 
Anthropic provider implementation use super::convert::{from_anthropic_response_with_warnings, to_anthropic_request}; +use super::models; use super::stream::create_stream; use super::types::{AnthropicConfig, AnthropicResponse}; use crate::error::{Error, Result}; use crate::provider::Provider; -use crate::types::{GenerateRequest, GenerateResponse, GenerateStream, Headers}; +use crate::types::{GenerateRequest, GenerateResponse, GenerateStream, Headers, Model}; use async_trait::async_trait; use reqwest::Client; use reqwest_eventsource::EventSource; @@ -57,42 +58,14 @@ impl Provider for AnthropicProvider { self.build_headers_with_cache(custom_headers, false) } - async fn list_models(&self) -> Result> { - let url = format!("{}models", self.config.base_url); - let headers = self.build_headers(None); - - let response = self - .client - .get(&url) - .headers(headers.to_reqwest_headers()) - .send() - .await?; - - if !response.status().is_success() { - let status = response.status(); - let error_text = response.text().await.unwrap_or_default(); - return Err(Error::provider_error(format!( - "Anthropic API error {}: {}", - status, error_text - ))); - } + async fn list_models(&self) -> Result> { + // Return static model definitions with full metadata + Ok(models::models()) + } - let resp: serde_json::Value = response.json().await?; - - // Extract model IDs from the response - // Response format: { "data": [{ "id": "model-id", ... }, ...] } - let models = resp - .get("data") - .and_then(|d| d.as_array()) - .map(|arr| { - arr.iter() - .filter_map(|m| m.get("id").and_then(|id| id.as_str())) - .map(|s| s.to_string()) - .collect() - }) - .unwrap_or_default(); - - Ok(models) + async fn get_model(&self, id: &str) -> Result> { + // Use static lookup for efficiency + Ok(models::get_model(id)) } async fn generate(&self, request: GenerateRequest) -> Result { diff --git a/libs/ai/src/providers/gemini/mod.rs b/libs/ai/src/providers/gemini/mod.rs index bd488fee..8f952d5f 100644 --- a/libs/ai/src/providers/gemini/mod.rs +++ b/libs/ai/src/providers/gemini/mod.rs @@ -1,6 +1,7 @@ //! Gemini provider module mod convert; +pub mod models; mod provider; mod stream; mod types; diff --git a/libs/ai/src/providers/gemini/models.rs b/libs/ai/src/providers/gemini/models.rs new file mode 100644 index 00000000..745d353a --- /dev/null +++ b/libs/ai/src/providers/gemini/models.rs @@ -0,0 +1,178 @@ +//! Gemini model definitions +//! +//! Static model definitions for Google's Gemini models with pricing and limits. 
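As with the Anthropic catalog, the metadata here allows programmatic selection; a hedged sketch (not part of the diff, import path assumed public as above) that picks the cheapest reasoning-capable Gemini model by input price:

```rust
use stakai::providers::gemini::models; // assumed public path, as above
use stakai::Model;

fn main() {
    // Price per 1M input tokens; models without pricing sort last.
    let input_price = |m: &Model| m.cost.as_ref().map_or(f64::MAX, |c| c.input);

    let pick = models::models()
        .into_iter()
        .filter(|m| m.reasoning)
        .min_by(|a, b| input_price(a).total_cmp(&input_price(b)))
        .expect("catalog is non-empty");

    // Gemini 2.5 Flash Lite is cheaper but has reasoning disabled above.
    assert_eq!(pick.id, "gemini-2.5-flash");
}
```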
+ +use crate::types::{Model, ModelCost, ModelLimit}; + +/// Provider identifier for Gemini (Google) +pub const PROVIDER_ID: &str = "google"; + +/// Get all Gemini models +pub fn models() -> Vec { + vec![ + gemini_3_pro(), + gemini_3_flash(), + gemini_2_5_pro(), + gemini_2_5_flash(), + gemini_2_5_flash_lite(), + ] +} + +/// Get a Gemini model by ID +pub fn get_model(id: &str) -> Option { + models().into_iter().find(|m| m.id == id) +} + +/// Get the default model for Gemini +pub fn default_model() -> Model { + gemini_3_pro() +} + +/// Gemini 3 Pro - Latest flagship model +pub fn gemini_3_pro() -> Model { + Model { + id: "gemini-3-pro-preview".into(), + name: "Gemini 3 Pro".into(), + provider: PROVIDER_ID.into(), + reasoning: true, + cost: Some(ModelCost { + input: 2.0, + output: 12.0, + cache_read: None, + cache_write: None, + }), + limit: ModelLimit { + context: 1_000_000, + output: 65_536, + }, + } +} + +/// Gemini 3 Flash - Fast latest generation +pub fn gemini_3_flash() -> Model { + Model { + id: "gemini-3-flash-preview".into(), + name: "Gemini 3 Flash".into(), + provider: PROVIDER_ID.into(), + reasoning: true, + cost: Some(ModelCost { + input: 0.50, + output: 3.0, + cache_read: None, + cache_write: None, + }), + limit: ModelLimit { + context: 1_000_000, + output: 65_536, + }, + } +} + +/// Gemini 2.5 Pro - Powerful multimodal model +pub fn gemini_2_5_pro() -> Model { + Model { + id: "gemini-2.5-pro".into(), + name: "Gemini 2.5 Pro".into(), + provider: PROVIDER_ID.into(), + reasoning: true, + cost: Some(ModelCost { + input: 1.25, + output: 10.0, + cache_read: None, + cache_write: None, + }), + limit: ModelLimit { + context: 1_000_000, + output: 65_536, + }, + } +} + +/// Gemini 2.5 Flash - Fast and efficient +pub fn gemini_2_5_flash() -> Model { + Model { + id: "gemini-2.5-flash".into(), + name: "Gemini 2.5 Flash".into(), + provider: PROVIDER_ID.into(), + reasoning: true, + cost: Some(ModelCost { + input: 0.30, + output: 2.50, + cache_read: None, + cache_write: None, + }), + limit: ModelLimit { + context: 1_000_000, + output: 65_536, + }, + } +} + +/// Gemini 2.5 Flash Lite - Smallest and most affordable +pub fn gemini_2_5_flash_lite() -> Model { + Model { + id: "gemini-2.5-flash-lite".into(), + name: "Gemini 2.5 Flash Lite".into(), + provider: PROVIDER_ID.into(), + reasoning: false, + cost: Some(ModelCost { + input: 0.10, + output: 0.40, + cache_read: None, + cache_write: None, + }), + limit: ModelLimit { + context: 1_000_000, + output: 65_536, + }, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_models_list() { + let all_models = models(); + assert_eq!(all_models.len(), 5); + + // Verify all models have the google provider + for model in &all_models { + assert_eq!(model.provider, PROVIDER_ID); + } + } + + #[test] + fn test_get_model() { + let pro = get_model("gemini-3-pro-preview"); + assert!(pro.is_some()); + assert_eq!(pro.unwrap().name, "Gemini 3 Pro"); + + let nonexistent = get_model("nonexistent-model"); + assert!(nonexistent.is_none()); + } + + #[test] + fn test_default_model() { + let default = default_model(); + assert_eq!(default.id, "gemini-3-pro-preview"); + } + + #[test] + fn test_large_context_windows() { + // All Gemini models have 1M token context + for model in models() { + assert_eq!(model.limit.context, 1_000_000); + } + } + + #[test] + fn test_reasoning_support() { + let flash_lite = gemini_2_5_flash_lite(); + let pro = gemini_3_pro(); + + assert!(!flash_lite.reasoning); + assert!(pro.reasoning); + } +} diff --git 
a/libs/ai/src/providers/gemini/provider.rs b/libs/ai/src/providers/gemini/provider.rs index 4b5f68e9..6c745cb9 100644 --- a/libs/ai/src/providers/gemini/provider.rs +++ b/libs/ai/src/providers/gemini/provider.rs @@ -1,11 +1,12 @@ //! Gemini provider implementation use super::convert::{from_gemini_response, to_gemini_request}; +use super::models; use super::stream::create_stream; use super::types::{GeminiConfig, GeminiResponse}; use crate::error::{Error, Result}; use crate::provider::Provider; -use crate::types::{GenerateRequest, GenerateResponse, GenerateStream, Headers}; +use crate::types::{GenerateRequest, GenerateResponse, GenerateStream, Headers, Model}; use async_trait::async_trait; use reqwest::Client; @@ -73,7 +74,7 @@ impl Provider for GeminiProvider { } async fn generate(&self, request: GenerateRequest) -> Result { - let url = self.get_url(&request.model, false); + let url = self.get_url(&request.model.id, false); let gemini_req = to_gemini_request(&request)?; let headers = self.build_headers(request.options.headers.as_ref()); @@ -104,7 +105,7 @@ impl Provider for GeminiProvider { } async fn stream(&self, request: GenerateRequest) -> Result { - let url = self.get_url(&request.model, true); + let url = self.get_url(&request.model.id, true); let gemini_req = to_gemini_request(&request)?; let headers = self.build_headers(request.options.headers.as_ref()); @@ -129,13 +130,13 @@ impl Provider for GeminiProvider { create_stream(response).await } - async fn list_models(&self) -> Result> { - // Gemini has a models endpoint, but for simplicity return known models - Ok(vec![ - "gemini-2.0-flash-exp".to_string(), - "gemini-1.5-pro".to_string(), - "gemini-1.5-flash".to_string(), - "gemini-1.0-pro".to_string(), - ]) + async fn list_models(&self) -> Result> { + // Return static model definitions with full metadata + Ok(models::models()) + } + + async fn get_model(&self, id: &str) -> Result> { + // Use static lookup for efficiency + Ok(models::get_model(id)) } } diff --git a/libs/ai/src/providers/openai/convert.rs b/libs/ai/src/providers/openai/convert.rs index 2715a9df..65aa5283 100644 --- a/libs/ai/src/providers/openai/convert.rs +++ b/libs/ai/src/providers/openai/convert.rs @@ -60,7 +60,7 @@ pub fn to_openai_request(req: &GenerateRequest, stream: bool) -> ChatCompletionR } }) .unwrap_or_else(|| { - if is_reasoning_model(&req.model) { + if is_reasoning_model(&req.model.id) { SystemMessageMode::Developer } else { SystemMessageMode::System @@ -74,13 +74,13 @@ pub fn to_openai_request(req: &GenerateRequest, stream: bool) -> ChatCompletionR .filter_map(|msg| to_openai_message_with_mode(msg, system_message_mode)) .collect(); - let temp = match is_reasoning_model(&req.model) { + let temp = match is_reasoning_model(&req.model.id) { false => Some(0.0), true => None, }; ChatCompletionRequest { - model: req.model.clone(), + model: req.model.id.clone(), messages, temperature: temp, max_completion_tokens: req.options.max_tokens, diff --git a/libs/ai/src/providers/openai/mod.rs b/libs/ai/src/providers/openai/mod.rs index 07084e2e..b75bd14a 100644 --- a/libs/ai/src/providers/openai/mod.rs +++ b/libs/ai/src/providers/openai/mod.rs @@ -2,6 +2,7 @@ pub mod convert; mod error; +pub mod models; mod provider; pub mod stream; pub mod types; diff --git a/libs/ai/src/providers/openai/models.rs b/libs/ai/src/providers/openai/models.rs new file mode 100644 index 00000000..605083e9 --- /dev/null +++ b/libs/ai/src/providers/openai/models.rs @@ -0,0 +1,193 @@ +//! OpenAI model definitions +//! +//! 
Static model definitions for OpenAI's GPT and reasoning models with pricing and limits. + +use crate::types::{Model, ModelCost, ModelLimit}; + +/// Provider identifier for OpenAI +pub const PROVIDER_ID: &str = "openai"; + +/// Get all OpenAI models +pub fn models() -> Vec { + vec![ + gpt_5(), + gpt_5_1(), + gpt_5_mini(), + gpt_5_nano(), + o3(), + o4_mini(), + ] +} + +/// Get an OpenAI model by ID +pub fn get_model(id: &str) -> Option { + models().into_iter().find(|m| m.id == id) +} + +/// Get the default model for OpenAI +pub fn default_model() -> Model { + gpt_5() +} + +/// GPT-5 - Main flagship model +pub fn gpt_5() -> Model { + Model { + id: "gpt-5-2025-08-07".into(), + name: "GPT-5".into(), + provider: PROVIDER_ID.into(), + reasoning: false, + cost: Some(ModelCost { + input: 1.25, + output: 10.0, + cache_read: None, + cache_write: None, + }), + limit: ModelLimit { + context: 400_000, + output: 16_384, + }, + } +} + +/// GPT-5.1 - Updated flagship model +pub fn gpt_5_1() -> Model { + Model { + id: "gpt-5.1-2025-11-13".into(), + name: "GPT-5.1".into(), + provider: PROVIDER_ID.into(), + reasoning: false, + cost: Some(ModelCost { + input: 1.50, + output: 12.0, + cache_read: None, + cache_write: None, + }), + limit: ModelLimit { + context: 400_000, + output: 16_384, + }, + } +} + +/// GPT-5 Mini - Smaller, faster model +pub fn gpt_5_mini() -> Model { + Model { + id: "gpt-5-mini-2025-08-07".into(), + name: "GPT-5 Mini".into(), + provider: PROVIDER_ID.into(), + reasoning: false, + cost: Some(ModelCost { + input: 0.25, + output: 2.0, + cache_read: None, + cache_write: None, + }), + limit: ModelLimit { + context: 400_000, + output: 16_384, + }, + } +} + +/// GPT-5 Nano - Smallest and fastest +pub fn gpt_5_nano() -> Model { + Model { + id: "gpt-5-nano-2025-08-07".into(), + name: "GPT-5 Nano".into(), + provider: PROVIDER_ID.into(), + reasoning: false, + cost: Some(ModelCost { + input: 0.05, + output: 0.40, + cache_read: None, + cache_write: None, + }), + limit: ModelLimit { + context: 400_000, + output: 16_384, + }, + } +} + +/// O3 - Advanced reasoning model +pub fn o3() -> Model { + Model { + id: "o3-2025-04-16".into(), + name: "O3".into(), + provider: PROVIDER_ID.into(), + reasoning: true, + cost: Some(ModelCost { + input: 2.0, + output: 8.0, + cache_read: None, + cache_write: None, + }), + limit: ModelLimit { + context: 200_000, + output: 100_000, + }, + } +} + +/// O4 Mini - Smaller reasoning model +pub fn o4_mini() -> Model { + Model { + id: "o4-mini-2025-04-16".into(), + name: "O4 Mini".into(), + provider: PROVIDER_ID.into(), + reasoning: true, + cost: Some(ModelCost { + input: 1.10, + output: 4.40, + cache_read: None, + cache_write: None, + }), + limit: ModelLimit { + context: 200_000, + output: 100_000, + }, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_models_list() { + let all_models = models(); + assert_eq!(all_models.len(), 6); + + // Verify all models have the openai provider + for model in &all_models { + assert_eq!(model.provider, PROVIDER_ID); + } + } + + #[test] + fn test_get_model() { + let gpt5 = get_model("gpt-5-2025-08-07"); + assert!(gpt5.is_some()); + assert_eq!(gpt5.unwrap().name, "GPT-5"); + + let nonexistent = get_model("nonexistent-model"); + assert!(nonexistent.is_none()); + } + + #[test] + fn test_default_model() { + let default = default_model(); + assert_eq!(default.id, "gpt-5-2025-08-07"); + } + + #[test] + fn test_reasoning_models() { + let gpt5 = gpt_5(); + let o3_model = o3(); + let o4_model = o4_mini(); + + assert!(!gpt5.reasoning); 
+ assert!(o3_model.reasoning); + assert!(o4_model.reasoning); + } +} diff --git a/libs/ai/src/providers/openai/provider.rs b/libs/ai/src/providers/openai/provider.rs index 9580ffc9..95412f8b 100644 --- a/libs/ai/src/providers/openai/provider.rs +++ b/libs/ai/src/providers/openai/provider.rs @@ -1,11 +1,12 @@ //! OpenAI provider implementation use super::convert::{from_openai_response, to_openai_request}; +use super::models; use super::stream::create_stream; use super::types::{ChatCompletionResponse, OpenAIConfig}; use crate::error::{Error, Result}; use crate::provider::Provider; -use crate::types::{GenerateRequest, GenerateResponse, GenerateStream, Headers}; +use crate::types::{GenerateRequest, GenerateResponse, GenerateStream, Headers, Model}; use async_trait::async_trait; use reqwest::Client; use reqwest_eventsource::EventSource; @@ -105,12 +106,13 @@ impl Provider for OpenAIProvider { create_stream(event_source).await } - async fn list_models(&self) -> Result> { - // Simplified - in production, call /v1/models endpoint - Ok(vec![ - "gpt-4".to_string(), - "gpt-4-turbo-preview".to_string(), - "gpt-3.5-turbo".to_string(), - ]) + async fn list_models(&self) -> Result> { + // Return static model definitions with full metadata + Ok(models::models()) + } + + async fn get_model(&self, id: &str) -> Result> { + // Use static lookup for efficiency + Ok(models::get_model(id)) } } diff --git a/libs/ai/src/providers/stakpak/provider.rs b/libs/ai/src/providers/stakpak/provider.rs index 06d39686..117510be 100644 --- a/libs/ai/src/providers/stakpak/provider.rs +++ b/libs/ai/src/providers/stakpak/provider.rs @@ -9,7 +9,7 @@ use crate::provider::Provider; use crate::providers::openai::convert::{from_openai_response, to_openai_request}; use crate::providers::openai::stream::create_stream; use crate::providers::openai::types::ChatCompletionResponse; -use crate::types::{GenerateRequest, GenerateResponse, GenerateStream, Headers}; +use crate::types::{GenerateRequest, GenerateResponse, GenerateStream, Headers, Model}; use async_trait::async_trait; use reqwest::Client; use reqwest_eventsource::EventSource; @@ -108,16 +108,49 @@ impl Provider for StakpakProvider { create_stream(event_source).await } - async fn list_models(&self) -> Result> { - // Stakpak supports routing to various providers - Ok(vec![ - "anthropic/claude-sonnet-4-5-20250929".to_string(), - "anthropic/claude-haiku-4-5-20250929".to_string(), - "anthropic/claude-opus-4-5-20250929".to_string(), - "openai/gpt-5".to_string(), - "openai/gpt-5-mini".to_string(), - "google/gemini-2.5-flash".to_string(), - "google/gemini-2.5-pro".to_string(), - ]) + async fn list_models(&self) -> Result> { + // Stakpak routes to other providers, so aggregate models from them + // with stakpak/ prefix for routing + use crate::providers::{anthropic, gemini, openai}; + + let mut models = Vec::new(); + + // Add Anthropic models with stakpak/anthropic/ prefix + for model in anthropic::models::models() { + models.push(Model { + id: format!("anthropic/{}", model.id), + name: model.name, + provider: "stakpak".into(), + reasoning: model.reasoning, + cost: model.cost, + limit: model.limit, + }); + } + + // Add OpenAI models with stakpak/openai/ prefix + for model in openai::models::models() { + models.push(Model { + id: format!("openai/{}", model.id), + name: model.name, + provider: "stakpak".into(), + reasoning: model.reasoning, + cost: model.cost, + limit: model.limit, + }); + } + + // Add Gemini models with stakpak/google/ prefix + for model in gemini::models::models() { + 
models.push(Model { + id: format!("google/{}", model.id), + name: model.name, + provider: "stakpak".into(), + reasoning: model.reasoning, + cost: model.cost, + limit: model.limit, + }); + } + + Ok(models) } } diff --git a/libs/ai/src/registry/mod.rs b/libs/ai/src/registry/mod.rs index 89c0e5f4..e5941634 100644 --- a/libs/ai/src/registry/mod.rs +++ b/libs/ai/src/registry/mod.rs @@ -2,6 +2,7 @@ use crate::error::{Error, Result}; use crate::provider::Provider; +use crate::types::Model; use std::collections::HashMap; use std::sync::Arc; @@ -42,6 +43,34 @@ impl ProviderRegistry { pub fn has_provider(&self, id: &str) -> bool { self.providers.contains_key(id) } + + /// Get all models from all configured providers + pub async fn models(&self) -> Result> { + let mut all_models = Vec::new(); + for provider in self.providers.values() { + all_models.extend(provider.list_models().await?); + } + Ok(all_models) + } + + /// Find a model by ID across all configured providers + pub async fn get_model(&self, id: &str) -> Result> { + for provider in self.providers.values() { + if let Some(model) = provider.get_model(id).await? { + return Ok(Some(model)); + } + } + Ok(None) + } + + /// Get all models from a specific provider + pub async fn models_for_provider(&self, provider_id: &str) -> Result> { + if let Some(provider) = self.providers.get(provider_id) { + provider.list_models().await + } else { + Err(Error::ProviderNotFound(provider_id.to_string())) + } + } } impl Default for ProviderRegistry { diff --git a/libs/ai/src/types/mod.rs b/libs/ai/src/types/mod.rs index 05f247a5..d3056303 100644 --- a/libs/ai/src/types/mod.rs +++ b/libs/ai/src/types/mod.rs @@ -4,6 +4,7 @@ mod cache; mod cache_validator; mod headers; mod message; +mod model; mod options; mod request; mod response; @@ -41,3 +42,6 @@ pub use response::{ // Stream types pub use stream::{GenerateStream, StreamEvent}; + +// Model types +pub use model::{Model, ModelCost, ModelLimit}; diff --git a/libs/ai/src/types/model.rs b/libs/ai/src/types/model.rs new file mode 100644 index 00000000..f71c1b1b --- /dev/null +++ b/libs/ai/src/types/model.rs @@ -0,0 +1,275 @@ +//! Unified model types for all AI providers +//! +//! This module provides a single `Model` struct that replaces provider-specific +//! model enums (AnthropicModel, OpenAIModel, GeminiModel) and related types. 
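To make the division of labor concrete before the definition below, a brief hedged sketch contrasting a fully specified catalog entry with `Model::custom`; the Opus figures mirror the Anthropic catalog earlier in this diff, and the local model is hypothetical:

```rust
use stakai::{Model, ModelCost, ModelLimit};

fn main() {
    // A catalog entry carries pricing and limits...
    let opus = Model::new(
        "claude-opus-4-5-20251101",
        "Claude Opus 4.5",
        "anthropic",
        true,
        Some(ModelCost::with_cache(15.0, 75.0, 1.50, 18.75)),
        ModelLimit::new(200_000, 32_000),
    );
    // ...while Model::custom is a minimal stand-in for anything unknown,
    // e.g. a hypothetical local model served through Ollama.
    let local = Model::custom("llama3", "ollama");

    assert!(opus.has_pricing());
    assert!(!local.has_pricing());
    assert_eq!(local.limit.context, 128_000); // conservative default limit
    println!("{} via {}", local, local.provider_name()); // Display prints the name
}
```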
+
+use serde::{Deserialize, Serialize};
+
+/// Unified model representation across all providers
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub struct Model {
+    /// Model identifier sent to the API (e.g., "claude-sonnet-4-5-20250929")
+    pub id: String,
+    /// Human-readable name (e.g., "Claude Sonnet 4.5")
+    pub name: String,
+    /// Provider identifier (e.g., "anthropic", "openai", "google")
+    pub provider: String,
+    /// Extended thinking/reasoning support
+    pub reasoning: bool,
+    /// Pricing per 1M tokens (None for custom/unknown models)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub cost: Option<ModelCost>,
+    /// Token limits
+    pub limit: ModelLimit,
+}
+
+impl Model {
+    /// Create a new model with all fields
+    pub fn new(
+        id: impl Into<String>,
+        name: impl Into<String>,
+        provider: impl Into<String>,
+        reasoning: bool,
+        cost: Option<ModelCost>,
+        limit: ModelLimit,
+    ) -> Self {
+        Self {
+            id: id.into(),
+            name: name.into(),
+            provider: provider.into(),
+            reasoning,
+            cost,
+            limit,
+        }
+    }
+
+    /// Create a custom model with minimal info (no pricing)
+    pub fn custom(id: impl Into<String>, provider: impl Into<String>) -> Self {
+        let id = id.into();
+        Self {
+            name: id.clone(),
+            id,
+            provider: provider.into(),
+            reasoning: false,
+            cost: None,
+            limit: ModelLimit::default(),
+        }
+    }
+
+    /// Check if this model has pricing information
+    pub fn has_pricing(&self) -> bool {
+        self.cost.is_some()
+    }
+
+    /// Get the display name (name field)
+    pub fn display_name(&self) -> &str {
+        &self.name
+    }
+
+    /// Get the model ID used for API calls
+    pub fn model_id(&self) -> &str {
+        &self.id
+    }
+
+    /// Get the provider name
+    pub fn provider_name(&self) -> &str {
+        &self.provider
+    }
+}
+
+impl std::fmt::Display for Model {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.name)
+    }
+}
+
+impl Default for Model {
+    fn default() -> Self {
+        Self {
+            id: String::new(),
+            name: String::new(),
+            provider: String::new(),
+            reasoning: false,
+            cost: None,
+            limit: ModelLimit::default(),
+        }
+    }
+}
+
+/// Pricing information per 1M tokens
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub struct ModelCost {
+    /// Cost per 1M input tokens
+    pub input: f64,
+    /// Cost per 1M output tokens
+    pub output: f64,
+    /// Cost per 1M cached input tokens (if supported)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub cache_read: Option<f64>,
+    /// Cost per 1M tokens written to cache (if supported)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub cache_write: Option<f64>,
+}
+
+impl ModelCost {
+    /// Create a new cost struct with basic input/output pricing
+    pub fn new(input: f64, output: f64) -> Self {
+        Self {
+            input,
+            output,
+            cache_read: None,
+            cache_write: None,
+        }
+    }
+
+    /// Create a cost struct with cache pricing
+    pub fn with_cache(input: f64, output: f64, cache_read: f64, cache_write: f64) -> Self {
+        Self {
+            input,
+            output,
+            cache_read: Some(cache_read),
+            cache_write: Some(cache_write),
+        }
+    }
+
+    /// Calculate cost for given token counts (in tokens, not millions)
+    pub fn calculate(&self, input_tokens: u64, output_tokens: u64) -> f64 {
+        let input_cost = (input_tokens as f64 / 1_000_000.0) * self.input;
+        let output_cost = (output_tokens as f64 / 1_000_000.0) * self.output;
+        input_cost + output_cost
+    }
+
+    /// Calculate cost with cache tokens
+    pub fn calculate_with_cache(
+        &self,
+        input_tokens: u64,
+        output_tokens: u64,
+        cache_read_tokens: u64,
+        cache_write_tokens: u64,
+    ) -> f64 {
+        let base_cost = self.calculate(input_tokens, output_tokens);
+        let cache_read_cost = self
+            .cache_read
+            .map(|rate| (cache_read_tokens as f64 / 1_000_000.0) * rate)
+            .unwrap_or(0.0);
+        let cache_write_cost = self
+            .cache_write
+            .map(|rate| (cache_write_tokens as f64 / 1_000_000.0) * rate)
+            .unwrap_or(0.0);
+        base_cost + cache_read_cost + cache_write_cost
+    }
+}
+
+/// Token limits for the model
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub struct ModelLimit {
+    /// Maximum context window size in tokens
+    pub context: u64,
+    /// Maximum output tokens
+    pub output: u64,
+}
+
+impl ModelLimit {
+    /// Create a new limit struct
+    pub fn new(context: u64, output: u64) -> Self {
+        Self { context, output }
+    }
+}
+
+impl Default for ModelLimit {
+    fn default() -> Self {
+        Self {
+            context: 128_000,
+            output: 8_192,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_model_creation() {
+        let model = Model::new(
+            "claude-sonnet-4-5-20250929",
+            "Claude Sonnet 4.5",
+            "anthropic",
+            true,
+            Some(ModelCost::with_cache(3.0, 15.0, 0.30, 3.75)),
+            ModelLimit::new(200_000, 16_384),
+        );
+
+        assert_eq!(model.id, "claude-sonnet-4-5-20250929");
+        assert_eq!(model.name, "Claude Sonnet 4.5");
+        assert_eq!(model.provider, "anthropic");
+        assert!(model.reasoning);
+        assert!(model.has_pricing());
+    }
+
+    #[test]
+    fn test_custom_model() {
+        let model = Model::custom("llama3", "ollama");
+
+        assert_eq!(model.id, "llama3");
+        assert_eq!(model.name, "llama3");
+        assert_eq!(model.provider, "ollama");
+        assert!(!model.reasoning);
+        assert!(!model.has_pricing());
+    }
+
+    #[test]
+    fn test_cost_calculation() {
+        let cost = ModelCost::new(3.0, 15.0);
+
+        // 1000 input tokens, 500 output tokens
+        let total = cost.calculate(1000, 500);
+        // (1000/1M) * 3.0 + (500/1M) * 15.0 = 0.003 + 0.0075 = 0.0105
+        assert!((total - 0.0105).abs() < 0.0001);
+    }
+
+    #[test]
+    fn test_cost_with_cache() {
+        let cost = ModelCost::with_cache(3.0, 15.0, 0.30, 3.75);
+
+        let total = cost.calculate_with_cache(1000, 500, 2000, 1000);
+        // base: 0.0105
+        // cache_read: (2000/1M) * 0.30 = 0.0006
+        // cache_write: (1000/1M) * 3.75 = 0.00375
+        // total: 0.0105 + 0.0006 + 0.00375 = 0.01485
+        assert!((total - 0.01485).abs() < 0.0001);
+    }
+
+    #[test]
+    fn test_model_display() {
+        let model = Model::new(
+            "gpt-5",
+            "GPT-5",
+            "openai",
+            false,
+            None,
+            ModelLimit::default(),
+        );
+
+        assert_eq!(format!("{}", model), "GPT-5");
+    }
+
+    #[test]
+    fn test_serialization() {
+        let model = Model::new(
+            "claude-sonnet-4-5-20250929",
+            "Claude Sonnet 4.5",
+            "anthropic",
+            true,
+            Some(ModelCost::new(3.0, 15.0)),
+            ModelLimit::new(200_000, 16_384),
+        );
+
+        let json = serde_json::to_string(&model).unwrap();
+        assert!(json.contains("\"id\":\"claude-sonnet-4-5-20250929\""));
+        assert!(json.contains("\"provider\":\"anthropic\""));
+
+        let deserialized: Model = serde_json::from_str(&json).unwrap();
+        assert_eq!(model, deserialized);
+    }
+}
diff --git a/libs/ai/src/types/request.rs b/libs/ai/src/types/request.rs
index 7cd4d38d..f363bab6 100644
--- a/libs/ai/src/types/request.rs
+++ b/libs/ai/src/types/request.rs
@@ -1,16 +1,17 @@
 //!
Request types for AI generation use super::cache::PromptCacheRetention; +use super::model::Model; use super::{GenerateOptions, Message}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; /// Request for generating AI completions -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct GenerateRequest { - /// Model identifier (can be provider-prefixed like "openai/gpt-4") + /// Model to use for generation #[serde(skip)] - pub model: String, + pub model: Model, /// Conversation messages pub messages: Vec, @@ -31,17 +32,18 @@ pub struct GenerateRequest { /// # Example /// /// ```rust - /// use stakai::GenerateRequest; + /// use stakai::{GenerateRequest, Message, Model, Role}; /// use std::collections::HashMap; /// /// let mut metadata = HashMap::new(); /// metadata.insert("user.id".to_string(), "user-123".to_string()); /// metadata.insert("session.id".to_string(), "session-456".to_string()); /// - /// let request = GenerateRequest { - /// telemetry_metadata: Some(metadata), - /// ..Default::default() - /// }; + /// let mut request = GenerateRequest::new( + /// Model::custom("gpt-4", "openai"), + /// vec![Message::new(Role::User, "Hello")] + /// ); + /// request.telemetry_metadata = Some(metadata); /// ``` #[serde(skip_serializing_if = "Option::is_none")] pub telemetry_metadata: Option>, @@ -198,9 +200,9 @@ pub struct GoogleOptions { impl GenerateRequest { /// Create a new request with model and messages - pub fn new(model: impl Into, messages: Vec) -> Self { + pub fn new(model: Model, messages: Vec) -> Self { Self { - model: model.into(), + model, messages, options: GenerateOptions::default(), provider_options: None, diff --git a/libs/ai/tests/integration/anthropic.rs b/libs/ai/tests/integration/anthropic.rs index da146c4d..fd8dd03c 100644 --- a/libs/ai/tests/integration/anthropic.rs +++ b/libs/ai/tests/integration/anthropic.rs @@ -4,7 +4,8 @@ use futures::StreamExt; use stakai::{ - GenerateRequest, Inference, InferenceConfig, Message, Role, StreamEvent, Tool, ToolChoice, + GenerateRequest, Inference, InferenceConfig, Message, Model, Role, StreamEvent, Tool, + ToolChoice, }; // ============================================================================= @@ -17,7 +18,7 @@ async fn test_anthropic_generate() { let client = Inference::new(); let mut request = GenerateRequest::new( - "claude-haiku-4-5-20251001", + Model::custom("claude-haiku-4-5-20251001", "anthropic"), vec![Message { role: Role::User, content: "Say 'Hello, World!' 
and nothing else".into(), @@ -45,7 +46,7 @@ async fn test_anthropic_generate_with_system_message() { let client = Inference::new(); let mut request = GenerateRequest::new( - "claude-haiku-4-5-20251001", + Model::custom("claude-haiku-4-5-20251001", "anthropic"), vec![ Message { role: Role::System, @@ -79,7 +80,7 @@ async fn test_anthropic_explicit_provider_prefix() { let client = Inference::new(); let mut request = GenerateRequest::new( - "anthropic/claude-haiku-4-5-20251001", + Model::custom("claude-haiku-4-5-20251001", "anthropic"), vec![Message { role: Role::User, content: "Say 'test'".into(), @@ -108,7 +109,7 @@ async fn test_anthropic_streaming() { let client = Inference::new(); let mut request = GenerateRequest::new( - "claude-haiku-4-5-20251001", + Model::custom("claude-haiku-4-5-20251001", "anthropic"), vec![Message { role: Role::User, content: "Count from 1 to 5".into(), @@ -158,7 +159,7 @@ async fn test_anthropic_streaming_with_system_message() { let client = Inference::new(); let mut request = GenerateRequest::new( - "claude-haiku-4-5-20251001", + Model::custom("claude-haiku-4-5-20251001", "anthropic"), vec![ Message { role: Role::System, @@ -220,7 +221,7 @@ async fn test_anthropic_tool_calling() { })); let mut request = GenerateRequest::new( - "claude-haiku-4-5-20251001", + Model::custom("claude-haiku-4-5-20251001", "anthropic"), vec![Message { role: Role::User, content: "What's the weather in Tokyo?".into(), @@ -272,7 +273,7 @@ async fn test_anthropic_tool_calling_streaming() { })); let mut request = GenerateRequest::new( - "claude-haiku-4-5-20251001", + Model::custom("claude-haiku-4-5-20251001", "anthropic"), vec![Message { role: Role::User, content: "What is 123 * 456?".into(), @@ -339,7 +340,7 @@ async fn test_anthropic_custom_base_url_with_messages_suffix() { let client = Inference::with_config(config).expect("Failed to create client"); let mut request = GenerateRequest::new( - "claude-haiku-4-5-20251001", + Model::custom("claude-haiku-4-5-20251001", "anthropic"), vec![Message { role: Role::User, content: "Say 'URL test passed'".into(), @@ -373,7 +374,7 @@ async fn test_anthropic_custom_base_url_without_trailing_slash() { let client = Inference::with_config(config).expect("Failed to create client"); let mut request = GenerateRequest::new( - "claude-haiku-4-5-20251001", + Model::custom("claude-haiku-4-5-20251001", "anthropic"), vec![Message { role: Role::User, content: "Say 'slash test passed'".into(), @@ -403,7 +404,7 @@ async fn test_anthropic_invalid_model_error() { let client = Inference::new(); let request = GenerateRequest::new( - "claude-invalid-model-12345", + Model::custom("claude-invalid-model-12345", "anthropic"), vec![Message { role: Role::User, content: "Test".into(), @@ -425,7 +426,7 @@ async fn test_anthropic_streaming_invalid_model_error() { let client = Inference::new(); let request = GenerateRequest::new( - "claude-invalid-model-12345", + Model::custom("claude-invalid-model-12345", "anthropic"), vec![Message { role: Role::User, content: "Test".into(), @@ -473,7 +474,7 @@ async fn test_anthropic_missing_api_key_error() { // try to make a request that would fail if let Ok(client) = client_result { let request = GenerateRequest::new( - "anthropic/claude-haiku-4-5-20251001", // Explicitly use anthropic + Model::custom("claude-haiku-4-5-20251001", "anthropic"), // Explicitly use anthropic vec![Message { role: Role::User, content: "Test".into(), @@ -505,7 +506,7 @@ async fn test_anthropic_multi_turn_conversation() { let client = Inference::new(); let mut request 
= GenerateRequest::new( - "claude-haiku-4-5-20251001", + Model::custom("claude-haiku-4-5-20251001", "anthropic"), vec![ Message { role: Role::User, @@ -554,7 +555,7 @@ async fn test_anthropic_streaming_long_response() { let client = Inference::new(); let mut request = GenerateRequest::new( - "claude-haiku-4-5-20251001", + Model::custom("claude-haiku-4-5-20251001", "anthropic"), vec![Message { role: Role::User, content: diff --git a/libs/ai/tests/integration/gemini.rs b/libs/ai/tests/integration/gemini.rs index 6e8d2104..0fad8ffe 100644 --- a/libs/ai/tests/integration/gemini.rs +++ b/libs/ai/tests/integration/gemini.rs @@ -3,7 +3,7 @@ //! Run with: cargo test --test integration -- --ignored use futures::StreamExt; -use stakai::{GenerateRequest, Inference, Message, Role, StreamEvent}; +use stakai::{GenerateRequest, Inference, Message, Model, Role, StreamEvent}; #[tokio::test] #[ignore] // Requires GEMINI_API_KEY @@ -11,7 +11,7 @@ async fn test_gemini_generate() { let client = Inference::new(); let mut request = GenerateRequest::new( - "gemini-2.5-flash-lite-preview-09-2025", + Model::custom("gemini-2.5-flash-lite-preview-09-2025", "google"), vec![Message { role: Role::User, content: "Say 'Hello, World!' and nothing else".into(), @@ -36,7 +36,7 @@ async fn test_gemini_streaming() { let client = Inference::new(); let mut request = GenerateRequest::new( - "gemini-2.5-flash-lite-preview-09-2025", + Model::custom("gemini-2.5-flash-lite-preview-09-2025", "google"), vec![Message { role: Role::User, content: "Count from 1 to 3".into(), diff --git a/libs/ai/tests/integration/openai.rs b/libs/ai/tests/integration/openai.rs index 289e1f07..fb54a1f2 100644 --- a/libs/ai/tests/integration/openai.rs +++ b/libs/ai/tests/integration/openai.rs @@ -3,7 +3,7 @@ //! Run with: cargo test --test integration -- --ignored use futures::StreamExt; -use stakai::{GenerateRequest, Inference, Message, Role, StreamEvent}; +use stakai::{GenerateRequest, Inference, Message, Model, Role, StreamEvent}; #[tokio::test] #[ignore] // Requires OPENAI_API_KEY @@ -11,7 +11,7 @@ async fn test_openai_generate() { let client = Inference::new(); let mut request = GenerateRequest::new( - "gpt-5-mini-2025-08-07", + Model::custom("gpt-5-mini-2025-08-07", "openai"), vec![Message { role: Role::User, content: "Say 'Hello, World!' 
and nothing else".into(), @@ -37,7 +37,7 @@ async fn test_openai_streaming() { let client = Inference::new(); let mut request = GenerateRequest::new( - "gpt-5-mini-2025-08-07", + Model::custom("gpt-5-mini-2025-08-07", "openai"), vec![Message { role: Role::User, content: "Count from 1 to 3".into(), @@ -79,7 +79,7 @@ async fn test_openai_with_system_message() { let client = Inference::new(); let mut request = GenerateRequest::new( - "gpt-5-mini-2025-08-07", + Model::custom("gpt-5-mini-2025-08-07", "openai"), vec![ Message { role: Role::System, @@ -112,7 +112,7 @@ async fn test_openai_explicit_provider() { let client = Inference::new(); let request = GenerateRequest::new( - "openai/gpt-3.5-turbo", + Model::custom("gpt-3.5-turbo", "openai"), vec![Message { role: Role::User, content: "Say hello".into(), @@ -134,7 +134,7 @@ async fn test_openai_temperature_variation() { // Test with temperature 0 (deterministic) let mut request = GenerateRequest::new( - "gpt-3.5-turbo", + Model::custom("gpt-3.5-turbo", "openai"), vec![Message { role: Role::User, content: "Say exactly: 'Test'".into(), diff --git a/libs/ai/tests/unit/client.rs b/libs/ai/tests/unit/client.rs index a6da9f18..38584ff3 100644 --- a/libs/ai/tests/unit/client.rs +++ b/libs/ai/tests/unit/client.rs @@ -2,7 +2,7 @@ use stakai::providers::openai::{OpenAIConfig, OpenAIProvider}; use stakai::registry::ProviderRegistry; -use stakai::{GenerateRequest, Inference, Message, Role}; +use stakai::{GenerateRequest, Inference, Message, Model, Role}; #[test] fn test_client_creation() { @@ -55,7 +55,7 @@ fn test_registry_list_providers() { #[test] fn test_request_creation() { let mut request = GenerateRequest::new( - "openai/gpt-4", + Model::custom("gpt-4", "openai"), vec![ Message { role: Role::System, @@ -82,7 +82,7 @@ fn test_request_creation() { #[test] fn test_request_with_model() { let request = GenerateRequest::new( - "gpt-4", + Model::custom("gpt-4", "openai"), vec![Message { role: Role::User, content: "Hello".into(), @@ -91,7 +91,7 @@ fn test_request_with_model() { }], ); - assert_eq!(request.model, "gpt-4"); + assert_eq!(request.model.id, "gpt-4"); assert_eq!(request.messages.len(), 1); } @@ -112,7 +112,7 @@ fn test_request_multiple_messages() { }, ]; - let request = GenerateRequest::new("openai/gpt-4", messages); + let request = GenerateRequest::new(Model::custom("gpt-4", "openai"), messages); assert_eq!(request.messages.len(), 2); } diff --git a/libs/ai/tests/unit/types.rs b/libs/ai/tests/unit/types.rs index 18f890f1..32c55124 100644 --- a/libs/ai/tests/unit/types.rs +++ b/libs/ai/tests/unit/types.rs @@ -1,5 +1,6 @@ //! 
Unit tests for core types +use stakai::Model; use stakai::types::*; #[test] @@ -85,7 +86,7 @@ fn test_content_part_image_with_detail() { #[test] fn test_generate_request_creation() { let mut request = GenerateRequest::new( - "openai/gpt-4", + Model::custom("gpt-4", "openai"), vec![Message { role: Role::User, content: "Hello".into(), @@ -104,7 +105,7 @@ fn test_generate_request_creation() { #[test] fn test_generate_request_simple() { let request = GenerateRequest::new( - "openai/gpt-4", + Model::custom("gpt-4", "openai"), vec![Message { role: Role::User, content: "Hello".into(), diff --git a/libs/api/Cargo.toml b/libs/api/Cargo.toml index 5d1cbffe..96eca439 100644 --- a/libs/api/Cargo.toml +++ b/libs/api/Cargo.toml @@ -9,6 +9,7 @@ homepage = { workspace = true } [dependencies] stakpak-shared = { workspace = true } +stakai = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } uuid = { workspace = true } diff --git a/libs/api/src/client/mod.rs b/libs/api/src/client/mod.rs index fa9b2ba4..fbf14c92 100644 --- a/libs/api/src/client/mod.rs +++ b/libs/api/src/client/mod.rs @@ -17,7 +17,7 @@ use crate::stakpak::{StakpakApiClient, StakpakApiConfig}; use libsql::Connection; use stakpak_shared::hooks::{HookRegistry, LifecycleEvent}; use stakpak_shared::models::llm::{LLMModel, LLMProviderConfig, ProviderConfig}; -use stakpak_shared::models::stakai_adapter::{StakAIClient, get_stakai_model_string}; +use stakpak_shared::models::stakai_adapter::StakAIClient; use std::path::PathBuf; use std::sync::Arc; @@ -228,16 +228,13 @@ impl AgentClient { }; // 6. Setup hook registry with context management hooks - let mut hook_registry = config.hook_registry.unwrap_or_default(); + let mut hook_registry = config + .hook_registry + .unwrap_or_else(|| HookRegistry::default()); hook_registry.register( LifecycleEvent::BeforeInference, Box::new(InlineScratchpadContextHook::new( InlineScratchpadContextHookOptions { - model_options: crate::local::ModelOptions { - smart_model: model_options.smart_model.clone(), - eco_model: model_options.eco_model.clone(), - recovery_model: model_options.recovery_model.clone(), - }, history_action_message_size_limit: Some(100), history_action_message_keep_last_n: Some(1), history_action_result_keep_last_n: Some(50), @@ -293,50 +290,6 @@ impl AgentClient { pub fn model_options(&self) -> &ModelOptions { &self.model_options } - - /// Get the model string for the given agent model type - /// - /// When Stakpak is available, routes through Stakpak provider. - /// Otherwise, uses direct provider. 
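The removed `get_model_string` helper below used to map the smart/eco/recovery tiers onto provider models and apply Stakpak routing; under the new scheme a caller builds a `stakai::Model` directly. A minimal sketch of that replacement pattern, assuming the same provider-prefix convention used by `list_models` and `find_model` later in this patch (`resolve_model` is a hypothetical helper, not part of the patch):

```rust
use stakai::Model;

/// Hypothetical replacement for the old tier mapping: build a Model and,
/// when Stakpak is in use, keep the original provider as an id prefix.
fn resolve_model(id: &str, provider: &str, use_stakpak: bool) -> Model {
    let mut m = Model::custom(id, provider);
    if use_stakpak {
        // Mirrors the Stakpak routing convention used elsewhere in this patch.
        m.id = format!("{}/{}", provider, m.id);
        m.provider = "stakpak".into();
    }
    m
}

fn main() {
    let m = resolve_model("claude-sonnet-4-5-20250929", "anthropic", true);
    assert_eq!(m.id, "anthropic/claude-sonnet-4-5-20250929");
    assert_eq!(m.provider, "stakpak");
}
```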
- pub fn get_model_string( - &self, - model: &stakpak_shared::models::integrations::openai::AgentModel, - ) -> LLMModel { - use stakpak_shared::models::integrations::openai::AgentModel; - - let base_model = match model { - AgentModel::Smart => self.model_options.smart_model.clone().unwrap_or_else(|| { - LLMModel::from("anthropic/claude-sonnet-4-5-20250929".to_string()) - }), - AgentModel::Eco => self.model_options.eco_model.clone().unwrap_or_else(|| { - LLMModel::from("anthropic/claude-haiku-4-5-20250929".to_string()) - }), - AgentModel::Recovery => self - .model_options - .recovery_model - .clone() - .unwrap_or_else(|| LLMModel::from("openai/gpt-5".to_string())), - }; - - // If Stakpak is available, route through Stakpak provider - if self.has_stakpak() { - // Get properly formatted model string with provider prefix (e.g., "anthropic/claude-sonnet-4-5") - let model_str = get_stakai_model_string(&base_model); - // Extract display name from the last segment for UI - let display_name = model_str - .rsplit('/') - .next() - .unwrap_or(&model_str) - .to_string(); - LLMModel::Custom { - provider: "stakpak".to_string(), - model: model_str, - name: Some(display_name), - } - } else { - base_model - } - } } // Debug implementation for AgentClient diff --git a/libs/api/src/client/provider.rs b/libs/api/src/client/provider.rs index 67023c92..5b4db8bc 100644 --- a/libs/api/src/client/provider.rs +++ b/libs/api/src/client/provider.rs @@ -16,16 +16,15 @@ use async_trait::async_trait; use futures_util::Stream; use reqwest::header::HeaderMap; use rmcp::model::Content; +use stakai::Model; use stakpak_shared::hooks::{HookContext, LifecycleEvent}; -use stakpak_shared::models::integrations::anthropic::AnthropicModel; use stakpak_shared::models::integrations::openai::{ - AgentModel, ChatCompletionChoice, ChatCompletionResponse, ChatCompletionStreamChoice, + ChatCompletionChoice, ChatCompletionResponse, ChatCompletionStreamChoice, ChatCompletionStreamResponse, ChatMessage, FinishReason, MessageContent, Role, Tool, }; use stakpak_shared::models::llm::{ - GenerationDelta, LLMInput, LLMMessage, LLMMessageContent, LLMModel, LLMStreamInput, + GenerationDelta, LLMInput, LLMMessage, LLMMessageContent, LLMStreamInput, }; -use stakpak_shared::models::stakai_adapter::get_stakai_model_string; use std::pin::Pin; use tokio::sync::mpsc; use uuid::Uuid; @@ -337,7 +336,7 @@ impl AgentProvider for AgentClient { async fn chat_completion( &self, - model: AgentModel, + model: Model, messages: Vec, tools: Option>, ) -> Result { @@ -380,7 +379,7 @@ impl AgentProvider for AgentClient { .state .llm_input .as_ref() - .map(|llm_input| llm_input.model.clone().to_string()) + .map(|llm_input| llm_input.model.id.clone()) .unwrap_or_default(), choices: vec![ChatCompletionChoice { index: 0, @@ -401,7 +400,7 @@ impl AgentProvider for AgentClient { async fn chat_completion_stream( &self, - model: AgentModel, + model: Model, messages: Vec, tools: Option>, _headers: Option, @@ -672,6 +671,60 @@ impl AgentProvider for AgentClient { Err("Slack integration requires Stakpak API key".to_string()) } } + + // ========================================================================= + // Models + // ========================================================================= + + async fn list_models(&self) -> Vec { + // Return all known static models directly + // No network calls - this should always be fast + use stakai::providers::{anthropic, gemini, openai}; + + let mut models = Vec::new(); + + // When using Stakpak API, models are routed through Stakpak 
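Because each catalog entry returned by `list_models` carries an optional `ModelCost`, the list can drive cost display directly. A rough sketch, assuming `client` is any initialized `AgentProvider` implementation and the token counts are illustrative only:

```rust
use stakpak_api::AgentProvider;

/// Hypothetical helper: print an approximate price per model from catalog pricing.
async fn report_costs<P: AgentProvider>(client: &P) {
    for model in client.list_models().await {
        if let Some(cost) = &model.cost {
            // ModelCost::calculate takes raw token counts, not millions.
            let dollars = cost.calculate(1_000_000, 200_000);
            println!("{} ({}): ~${:.4} for 1M in / 200k out", model.name, model.id, dollars);
        }
    }
}
```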
+ if self.has_stakpak() { + // Add all models with stakpak routing prefix + for model in anthropic::models::models() { + models.push(stakai::Model { + id: format!("anthropic/{}", model.id), + name: model.name, + provider: "stakpak".into(), + reasoning: model.reasoning, + cost: model.cost, + limit: model.limit, + }); + } + for model in openai::models::models() { + models.push(stakai::Model { + id: format!("openai/{}", model.id), + name: model.name, + provider: "stakpak".into(), + reasoning: model.reasoning, + cost: model.cost, + limit: model.limit, + }); + } + for model in gemini::models::models() { + models.push(stakai::Model { + id: format!("google/{}", model.id), + name: model.name, + provider: "stakpak".into(), + reasoning: model.reasoning, + cost: model.cost, + limit: model.limit, + }); + } + } else { + // Direct provider access - return models grouped by provider + models.extend(anthropic::models::models()); + models.extend(openai::models::models()); + models.extend(gemini::models::models()); + } + + models + } } // ============================================================================= @@ -958,31 +1011,15 @@ impl AgentClient { /// Generate a title for a new session async fn generate_session_title(&self, messages: &[ChatMessage]) -> Result { - let llm_model = if let Some(eco_model) = &self.model_options.eco_model { - eco_model.clone() - } else { - // Try to find a suitable model - LLMModel::Anthropic(AnthropicModel::Claude45Haiku) - }; - - // If Stakpak is available, route through it - let model = if self.has_stakpak() { - // Get properly formatted model string with provider prefix (e.g., "anthropic/claude-haiku-4-5") - let model_str = get_stakai_model_string(&llm_model); - // Extract display name from the last segment for UI - let display_name = model_str - .rsplit('/') - .next() - .unwrap_or(&model_str) - .to_string(); - LLMModel::Custom { - provider: "stakpak".to_string(), - model: model_str, - name: Some(display_name), - } - } else { - llm_model - }; + // Use a default haiku model for title generation + let model = Model::new( + "claude-haiku-4-5-20250929", + "Claude Haiku 4.5", + "anthropic", + false, + None, + stakai::ModelLimit::default(), + ); let llm_messages = vec![ LLMMessage { diff --git a/libs/api/src/lib.rs b/libs/api/src/lib.rs index a3676d42..450464b0 100644 --- a/libs/api/src/lib.rs +++ b/libs/api/src/lib.rs @@ -4,7 +4,7 @@ use models::*; use reqwest::header::HeaderMap; use rmcp::model::Content; use stakpak_shared::models::integrations::openai::{ - AgentModel, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage, Tool, + ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage, Tool, }; use uuid::Uuid; @@ -21,6 +21,63 @@ pub use client::{ AgentClient, AgentClientConfig, DEFAULT_STAKPAK_ENDPOINT, ModelOptions, StakpakConfig, }; +// Re-export Model types from stakai +pub use stakai::{Model, ModelCost, ModelLimit}; + +/// Find a model by ID string +/// +/// Parses the model string and searches the catalog: +/// - If format is "provider/model_id", searches within that provider +/// - If no provider prefix, searches all providers by model ID +/// +/// When `use_stakpak` is true, the returned model will have provider set to "stakpak" +/// for routing through the Stakpak API. +/// +/// Returns None if the model is not found in any catalog. 
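A usage sketch for the `find_model` helper documented above (its definition follows); it assumes the requested id exists in the bundled Anthropic catalog, hence the guard:

```rust
use stakpak_api::find_model;

fn main() {
    // Assumption: this id is present in the Anthropic catalog; if it is not,
    // find_model returns None and the block is simply skipped.
    if let Some(m) = find_model("anthropic/claude-sonnet-4-5-20250929", true) {
        // With use_stakpak = true the provider is rewritten to "stakpak" and
        // the original provider is kept as a prefix on the id.
        assert_eq!(m.provider, "stakpak");
        assert!(m.id.starts_with("anthropic/"));
    }
}
```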
+pub fn find_model(model_str: &str, use_stakpak: bool) -> Option { + use stakai::providers::{anthropic, gemini, openai}; + + // Split on first '/' to check for provider prefix + let (provider_prefix, model_id) = if let Some(idx) = model_str.find('/') { + let (prefix, rest) = model_str.split_at(idx); + (Some(prefix), &rest[1..]) // Skip the '/' + } else { + (None, model_str) + }; + + // Search for the model + let found_model = match provider_prefix { + Some("anthropic") => anthropic::models::get_model(model_id), + Some("openai") => openai::models::get_model(model_id), + Some("google") | Some("gemini") => gemini::models::get_model(model_id), + Some(_) => None, // Unknown provider prefix, will search all below + None => None, + }; + + // If not found with prefix, or no prefix given, search all providers + let found_model = found_model.or_else(|| { + anthropic::models::get_model(model_id) + .or_else(|| openai::models::get_model(model_id)) + .or_else(|| gemini::models::get_model(model_id)) + }); + + // Adjust the model for Stakpak routing if needed + found_model.map(|mut m| { + if use_stakpak { + // Prefix the ID with the original provider for Stakpak routing + let provider_for_id = match m.provider.as_str() { + "anthropic" => "anthropic", + "openai" => "openai", + "google" => "google", + _ => &m.provider, + }; + m.id = format!("{}/{}", provider_for_id, m.id); + m.provider = "stakpak".into(); + } + m + }) +} + #[async_trait] pub trait AgentProvider: Send + Sync { // Account @@ -56,13 +113,13 @@ pub trait AgentProvider: Send + Sync { // Chat async fn chat_completion( &self, - model: AgentModel, + model: Model, messages: Vec, tools: Option>, ) -> Result; async fn chat_completion_stream( &self, - model: AgentModel, + model: Model, messages: Vec, tools: Option>, headers: Option, @@ -97,4 +154,7 @@ pub trait AgentProvider: Send + Sync { &self, input: &SlackSendMessageRequest, ) -> Result, String>; + + // Models + async fn list_models(&self) -> Vec; } diff --git a/libs/api/src/local/hooks/file_scratchpad_context/mod.rs b/libs/api/src/local/hooks/file_scratchpad_context/mod.rs index 2f507d88..7f19e992 100644 --- a/libs/api/src/local/hooks/file_scratchpad_context/mod.rs +++ b/libs/api/src/local/hooks/file_scratchpad_context/mod.rs @@ -1,7 +1,6 @@ use crate::local::context_managers::file_scratchpad_context_manager::{ FileScratchpadContextManager, FileScratchpadContextManagerOptions, }; -use crate::local::{ModelOptions, ModelSet}; use crate::models::AgentState; use stakpak_shared::define_hook; use stakpak_shared::hooks::{Hook, HookAction, HookContext, HookError, LifecycleEvent}; @@ -13,12 +12,10 @@ const SCRATCHPAD_FILE: &str = ".stakpak/session/scratchpad.md"; const TODO_FILE: &str = ".stakpak/session/todo.md"; pub struct FileScratchpadContextHook { - pub model_set: ModelSet, pub context_manager: FileScratchpadContextManager, } pub struct FileScratchpadContextHookOptions { - pub model_options: ModelOptions, pub scratchpad_path: Option, pub todo_path: Option, pub history_action_message_size_limit: Option, @@ -29,7 +26,6 @@ pub struct FileScratchpadContextHookOptions { impl FileScratchpadContextHook { pub fn new(options: FileScratchpadContextHookOptions) -> Self { - let model_set: ModelSet = options.model_options.into(); let context_manager = FileScratchpadContextManager::new(FileScratchpadContextManagerOptions { scratchpad_file_path: options @@ -49,10 +45,7 @@ impl FileScratchpadContextHook { overwrite_if_different: options.overwrite_if_different.unwrap_or(true), }); - Self { - model_set, - 
context_manager, - } + Self { context_manager } } } @@ -64,7 +57,7 @@ define_hook!( return Ok(HookAction::Continue); } - let model = self.model_set.get_model(&ctx.state.agent_model); + let model = ctx.state.active_model.clone(); let tools = ctx .state diff --git a/libs/api/src/local/hooks/inline_scratchpad_context/mod.rs b/libs/api/src/local/hooks/inline_scratchpad_context/mod.rs index 4fdce628..11e24ab5 100644 --- a/libs/api/src/local/hooks/inline_scratchpad_context/mod.rs +++ b/libs/api/src/local/hooks/inline_scratchpad_context/mod.rs @@ -7,17 +7,14 @@ use crate::local::context_managers::ContextManager; use crate::local::context_managers::scratchpad_context_manager::{ ScratchpadContextManager, ScratchpadContextManagerOptions, }; -use crate::local::{ModelOptions, ModelSet}; use crate::models::AgentState; const SYSTEM_PROMPT: &str = include_str!("./system_prompt.txt"); pub struct InlineScratchpadContextHook { - pub model_set: ModelSet, pub context_manager: ScratchpadContextManager, } pub struct InlineScratchpadContextHookOptions { - pub model_options: ModelOptions, pub history_action_message_size_limit: Option, pub history_action_message_keep_last_n: Option, pub history_action_result_keep_last_n: Option, @@ -25,8 +22,6 @@ pub struct InlineScratchpadContextHookOptions { impl InlineScratchpadContextHook { pub fn new(options: InlineScratchpadContextHookOptions) -> Self { - let model_set: ModelSet = options.model_options.into(); - let context_manager = ScratchpadContextManager::new(ScratchpadContextManagerOptions { history_action_message_size_limit: options .history_action_message_size_limit @@ -39,10 +34,7 @@ impl InlineScratchpadContextHook { .unwrap_or(50), }); - Self { - model_set, - context_manager, - } + Self { context_manager } } } @@ -54,7 +46,7 @@ define_hook!( return Ok(HookAction::Continue); } - let model = self.model_set.get_model(&ctx.state.agent_model); + let model = ctx.state.active_model.clone(); let tools = ctx .state diff --git a/libs/api/src/local/mod.rs b/libs/api/src/local/mod.rs index b2d89e1c..f5d631b7 100644 --- a/libs/api/src/local/mod.rs +++ b/libs/api/src/local/mod.rs @@ -3,10 +3,6 @@ //! This module provides: //! - Database operations for local session storage //! - Lifecycle hooks for context management -//! 
- Model configuration types - -use stakpak_shared::models::integrations::openai::AgentModel; -use stakpak_shared::models::llm::LLMModel; // Sub-modules pub(crate) mod context_managers; @@ -15,51 +11,3 @@ pub mod hooks; #[cfg(test)] mod tests; - -/// Model options for the agent -#[derive(Clone, Debug, Default)] -pub struct ModelOptions { - pub smart_model: Option, - pub eco_model: Option, - pub recovery_model: Option, -} - -/// Resolved model set with default fallbacks -#[derive(Clone, Debug)] -pub struct ModelSet { - pub smart_model: LLMModel, - pub eco_model: LLMModel, - pub recovery_model: LLMModel, -} - -impl ModelSet { - /// Get the model for a given agent model type - pub fn get_model(&self, agent_model: &AgentModel) -> LLMModel { - match agent_model { - AgentModel::Smart => self.smart_model.clone(), - AgentModel::Eco => self.eco_model.clone(), - AgentModel::Recovery => self.recovery_model.clone(), - } - } -} - -impl From for ModelSet { - fn from(value: ModelOptions) -> Self { - // Default models route through Stakpak provider - let smart_model = value - .smart_model - .unwrap_or_else(|| LLMModel::from("stakpak/anthropic/claude-opus-4-5".to_string())); - let eco_model = value - .eco_model - .unwrap_or_else(|| LLMModel::from("stakpak/anthropic/claude-haiku-4-5".to_string())); - let recovery_model = value - .recovery_model - .unwrap_or_else(|| LLMModel::from("stakpak/openai/gpt-4o".to_string())); - - Self { - smart_model, - eco_model, - recovery_model, - } - } -} diff --git a/libs/api/src/models.rs b/libs/api/src/models.rs index 8c5a59e8..06c697bb 100644 --- a/libs/api/src/models.rs +++ b/libs/api/src/models.rs @@ -4,10 +4,9 @@ use chrono::{DateTime, Utc}; use rmcp::model::Content; use serde::{Deserialize, Serialize}; use serde_json::Value; +use stakai::Model; use stakpak_shared::models::{ - integrations::openai::{ - AgentModel, ChatMessage, FunctionCall, MessageContent, Role, Tool, ToolCall, - }, + integrations::openai::{ChatMessage, FunctionCall, MessageContent, Role, Tool, ToolCall}, llm::{LLMInput, LLMMessage, LLMMessageContent, LLMMessageTypedContent, LLMTokenUsage}, }; use uuid::Uuid; @@ -654,7 +653,8 @@ pub struct SlackSendMessageRequest { #[derive(Debug, Clone, Default, Serialize)] pub struct AgentState { - pub agent_model: AgentModel, + /// The active model to use for inference + pub active_model: Model, pub messages: Vec, pub tools: Option>, @@ -721,13 +721,9 @@ impl From<&LLMOutput> for ChatMessage { } impl AgentState { - pub fn new( - agent_model: AgentModel, - messages: Vec, - tools: Option>, - ) -> Self { + pub fn new(active_model: Model, messages: Vec, tools: Option>) -> Self { Self { - agent_model, + active_model, messages, tools, llm_input: None, @@ -744,8 +740,8 @@ impl AgentState { self.tools = tools; } - pub fn set_agent_model(&mut self, agent_model: AgentModel) { - self.agent_model = agent_model; + pub fn set_active_model(&mut self, model: Model) { + self.active_model = model; } pub fn set_llm_input(&mut self, llm_input: Option) { diff --git a/libs/shared/src/models/integrations/openai.rs b/libs/shared/src/models/integrations/openai.rs index 896c4b6d..3f3a963c 100644 --- a/libs/shared/src/models/integrations/openai.rs +++ b/libs/shared/src/models/integrations/openai.rs @@ -213,38 +213,6 @@ impl std::fmt::Display for OpenAIModel { } } -/// Agent model type (smart/eco/recovery) -#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)] -pub enum AgentModel { - #[serde(rename = "smart")] - #[default] - Smart, - #[serde(rename = "eco")] - Eco, - #[serde(rename 
= "recovery")] - Recovery, -} - -impl std::fmt::Display for AgentModel { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - AgentModel::Smart => write!(f, "smart"), - AgentModel::Eco => write!(f, "eco"), - AgentModel::Recovery => write!(f, "recovery"), - } - } -} - -impl From for AgentModel { - fn from(value: String) -> Self { - match value.as_str() { - "eco" => AgentModel::Eco, - "recovery" => AgentModel::Recovery, - _ => AgentModel::Smart, - } - } -} - // ============================================================================= // Message Types (used by TUI) // ============================================================================= @@ -974,7 +942,7 @@ mod tests { #[test] fn test_serialize_basic_request() { let request = ChatCompletionRequest { - model: AgentModel::Smart.to_string(), + model: "gpt-4".to_string(), messages: vec![ ChatMessage { role: Role::System, @@ -1016,7 +984,7 @@ mod tests { }; let json = serde_json::to_string(&request).unwrap(); - assert!(json.contains("\"model\":\"smart\"")); + assert!(json.contains("\"model\":\"gpt-4\"")); assert!(json.contains("\"messages\":[")); assert!(json.contains("\"role\":\"system\"")); } diff --git a/libs/shared/src/models/llm.rs b/libs/shared/src/models/llm.rs index 21bec653..d2af7e0e 100644 --- a/libs/shared/src/models/llm.rs +++ b/libs/shared/src/models/llm.rs @@ -55,6 +55,7 @@ use crate::models::{ model_pricing::{ContextAware, ModelContextInfo}, }; use serde::{Deserialize, Serialize}; +use stakai::Model; use std::collections::HashMap; use std::fmt::Display; @@ -527,7 +528,7 @@ pub struct LLMGoogleOptions { #[derive(Clone, Debug, Serialize)] pub struct LLMInput { - pub model: LLMModel, + pub model: Model, pub messages: Vec, pub max_tokens: u32, pub tools: Option>, @@ -540,7 +541,7 @@ pub struct LLMInput { #[derive(Debug)] pub struct LLMStreamInput { - pub model: LLMModel, + pub model: Model, pub messages: Vec, pub max_tokens: u32, pub stream_channel_tx: tokio::sync::mpsc::Sender, diff --git a/libs/shared/src/models/stakai_adapter.rs b/libs/shared/src/models/stakai_adapter.rs index 0d0ebf32..bad6ddf3 100644 --- a/libs/shared/src/models/stakai_adapter.rs +++ b/libs/shared/src/models/stakai_adapter.rs @@ -6,15 +6,15 @@ use crate::models::error::{AgentError, BadRequestErrorMessage}; use crate::models::llm::{ GenerationDelta, GenerationDeltaToolUse, LLMChoice, LLMCompletionResponse, LLMInput, - LLMMessage, LLMMessageContent, LLMMessageImageSource, LLMMessageTypedContent, LLMModel, + LLMMessage, LLMMessageContent, LLMMessageImageSource, LLMMessageTypedContent, LLMProviderConfig, LLMProviderOptions, LLMStreamInput, LLMTokenUsage, LLMTool, ProviderConfig, }; use futures::StreamExt; use stakai::{ AnthropicOptions, ContentPart, FinishReason, GenerateOptions, GenerateRequest, GenerateResponse, GoogleOptions, Headers, Inference, InferenceConfig, Message, MessageContent, - OpenAIOptions, ProviderOptions, ReasoningEffort, Role, StreamEvent, ThinkingOptions, Tool, - ToolFunction, Usage, providers::anthropic::AnthropicConfig as StakaiAnthropicConfig, + Model, OpenAIOptions, ProviderOptions, ReasoningEffort, Role, StreamEvent, ThinkingOptions, + Tool, ToolFunction, Usage, providers::anthropic::AnthropicConfig as StakaiAnthropicConfig, registry::ProviderRegistry, }; @@ -230,11 +230,11 @@ pub fn finish_reason_to_string(reason: &FinishReason) -> String { /// Convert CLI LLMProviderOptions to StakAI ProviderOptions pub fn to_stakai_provider_options( opts: &LLMProviderOptions, - model: &LLMModel, + model: &Model, 
) -> Option { - // Determine provider from model and options - match model { - LLMModel::Anthropic(_) => { + // Determine provider from model's provider field + match model.provider.as_str() { + "anthropic" => { if let Some(anthropic) = &opts.anthropic { let thinking = anthropic .thinking @@ -249,7 +249,7 @@ pub fn to_stakai_provider_options( None } } - LLMModel::OpenAI(_) => { + "openai" => { if let Some(openai) = &opts.openai { let reasoning_effort = openai.reasoning_effort.as_ref().and_then(|e| { match e.to_lowercase().as_str() { @@ -273,14 +273,14 @@ pub fn to_stakai_provider_options( None } } - LLMModel::Gemini(_) => opts.google.as_ref().map(|google| { + "google" | "gemini" => opts.google.as_ref().map(|google| { ProviderOptions::Google(GoogleOptions { thinking_budget: google.thinking_budget, cached_content: None, }) }), - LLMModel::Custom { .. } => { - // For custom models, try to infer from which options are set + _ => { + // For custom/unknown providers, try to infer from which options are set if let Some(anthropic) = &opts.anthropic { let thinking = anthropic .thinking @@ -554,16 +554,9 @@ fn build_provider_registry_direct(config: &LLMProviderConfig) -> Result String { - match model { - LLMModel::Anthropic(m) => format!("anthropic/{}", m), - LLMModel::Gemini(m) => format!("google/{}", m), - LLMModel::OpenAI(m) => format!("openai/{}", m), - LLMModel::Custom { - provider, model, .. - } => format!("{}/{}", provider, model), - } +/// Get model string for StakAI +pub fn get_stakai_model_string(model: &Model) -> String { + model.id.clone() } /// Wrapper around StakAI Inference for CLI usage @@ -603,7 +596,6 @@ impl StakAIClient { /// Non-streaming chat completion pub async fn chat(&self, input: LLMInput) -> Result { - let model_string = get_stakai_model_string(&input.model); let messages: Vec = input.messages.iter().map(to_stakai_message).collect(); let mut options = GenerateOptions::new().max_tokens(input.max_tokens); @@ -630,7 +622,7 @@ impl StakAIClient { .and_then(|opts| to_stakai_provider_options(opts, &input.model)); let request = GenerateRequest { - model: model_string.clone(), + model: input.model.clone(), messages, options, provider_options, @@ -641,7 +633,7 @@ impl StakAIClient { AgentError::BadRequest(BadRequestErrorMessage::InvalidAgentInput(e.to_string())) })?; - Ok(from_stakai_response(response, &model_string)) + Ok(from_stakai_response(response, &input.model.id)) } /// Streaming chat completion @@ -649,7 +641,6 @@ impl StakAIClient { &self, input: LLMStreamInput, ) -> Result { - let model_string = get_stakai_model_string(&input.model); let messages: Vec = input.messages.iter().map(to_stakai_message).collect(); let mut options = GenerateOptions::new().max_tokens(input.max_tokens); @@ -675,8 +666,9 @@ impl StakAIClient { .as_ref() .and_then(|opts| to_stakai_provider_options(opts, &input.model)); + let model_id = input.model.id.clone(); let request = GenerateRequest { - model: model_string.clone(), + model: input.model.clone(), messages, options, provider_options, @@ -814,7 +806,7 @@ impl StakAIClient { Ok(LLMCompletionResponse { id: uuid::Uuid::new_v4().to_string(), - model: model_string, + model: model_id, object: "chat.completion".to_string(), choices: vec![LLMChoice { finish_reason: Some(finish_reason), @@ -828,14 +820,16 @@ impl StakAIClient { usage: Some(final_usage), }) } + + /// Get the provider registry for model listing + pub fn registry(&self) -> &ProviderRegistry { + self.inference.registry() + } } #[cfg(test)] mod tests { use super::*; - use 
crate::models::integrations::anthropic::AnthropicModel; - use crate::models::integrations::gemini::GeminiModel; - use crate::models::integrations::openai::OpenAIModel; // ==================== Role Conversion Tests ==================== @@ -1307,37 +1301,30 @@ mod tests { #[test] fn test_model_string_anthropic() { - let model = LLMModel::Anthropic(AnthropicModel::Claude45Sonnet); + let model = Model::custom("claude-sonnet-4-5-20250929", "anthropic"); let model_str = get_stakai_model_string(&model); - assert!(model_str.starts_with("anthropic/")); - assert!(model_str.contains("claude")); + assert_eq!(model_str, "claude-sonnet-4-5-20250929"); } #[test] fn test_model_string_openai() { - let model = LLMModel::OpenAI(OpenAIModel::GPT5); + let model = Model::custom("gpt-5", "openai"); let model_str = get_stakai_model_string(&model); - assert!(model_str.starts_with("openai/")); - assert!(model_str.contains("gpt")); + assert_eq!(model_str, "gpt-5"); } #[test] fn test_model_string_gemini() { - let model = LLMModel::Gemini(GeminiModel::Gemini25Flash); + let model = Model::custom("gemini-2.5-flash", "google"); let model_str = get_stakai_model_string(&model); - assert!(model_str.starts_with("google/")); - assert!(model_str.contains("gemini")); + assert_eq!(model_str, "gemini-2.5-flash"); } #[test] fn test_model_string_custom() { - let model = LLMModel::Custom { - provider: "litellm".to_string(), - model: "claude-opus-4-5".to_string(), - name: None, - }; + let model = Model::custom("claude-opus-4-5", "litellm"); let model_str = get_stakai_model_string(&model); - assert_eq!(model_str, "litellm/claude-opus-4-5"); + assert_eq!(model_str, "claude-opus-4-5"); } // ==================== Response Conversion Tests ==================== @@ -1437,7 +1424,7 @@ mod tests { google: None, }; - let model = LLMModel::Anthropic(AnthropicModel::Claude45Sonnet); + let model = Model::custom("claude-sonnet-4-5-20250929", "anthropic"); let result = to_stakai_provider_options(&opts, &model); assert!(result.is_some()); @@ -1461,7 +1448,7 @@ mod tests { google: None, }; - let model = LLMModel::OpenAI(OpenAIModel::GPT5); + let model = Model::custom("gpt-5", "openai"); let result = to_stakai_provider_options(&opts, &model); assert!(result.is_some()); @@ -1484,7 +1471,7 @@ mod tests { }), }; - let model = LLMModel::Gemini(GeminiModel::Gemini25Flash); + let model = Model::custom("gemini-2.5-flash", "google"); let result = to_stakai_provider_options(&opts, &model); assert!(result.is_some()); @@ -1501,7 +1488,7 @@ mod tests { let opts = LLMProviderOptions::default(); - let model = LLMModel::Anthropic(AnthropicModel::Claude45Sonnet); + let model = Model::custom("claude-sonnet-4-5-20250929", "anthropic"); let result = to_stakai_provider_options(&opts, &model); assert!(result.is_none()); diff --git a/tui/Cargo.toml b/tui/Cargo.toml index ae4941d7..22670314 100644 --- a/tui/Cargo.toml +++ b/tui/Cargo.toml @@ -8,6 +8,7 @@ repository = { workspace = true } homepage = { workspace = true } [dependencies] +stakai = { workspace = true } stakpak-shared = { workspace = true } stakpak-api = { workspace = true } regex = { workspace = true } diff --git a/tui/src/app.rs b/tui/src/app.rs index 440ff381..38c1295a 100644 --- a/tui/src/app.rs +++ b/tui/src/app.rs @@ -2,6 +2,7 @@ mod events; mod types; pub use events::{InputEvent, OutputEvent}; +use stakai::Model; use stakpak_shared::models::llm::{LLMModel, LLMTokenUsage}; pub use types::*; @@ -22,7 +23,7 @@ use crate::services::textarea::{TextArea, TextAreaState}; use ratatui::layout::Size; use 
ratatui::text::Line; use stakpak_api::models::ListRuleBook; -use stakpak_shared::models::integrations::openai::{AgentModel, ToolCall, ToolCallResult}; +use stakpak_shared::models::integrations::openai::{ToolCall, ToolCallResult}; use stakpak_shared::secret_manager::SecretManager; use std::collections::HashMap; use tokio::sync::mpsc; @@ -162,6 +163,12 @@ pub struct AppState { pub filtered_rulebooks: Vec, pub rulebook_config: Option, + // ========== Model Switcher State ========== + pub show_model_switcher: bool, + pub available_models: Vec, + pub model_switcher_selected: usize, + pub current_model: Option, + // ========== Command Palette State ========== pub show_command_palette: bool, pub command_palette_selected: usize, @@ -190,7 +197,7 @@ pub struct AppState { pub is_git_repo: bool, pub auto_approve_manager: AutoApproveManager, pub allowed_tools: Option>, - pub agent_model: AgentModel, + pub model: Model, pub llm_model: Option, /// Auth display info: (config_provider, auth_provider, subscription_name) for local providers pub auth_display_info: (Option, Option, Option), @@ -240,7 +247,7 @@ pub struct AppStateOptions<'a> { pub auto_approve_tools: Option<&'a Vec>, pub allowed_tools: Option<&'a Vec>, pub input_tx: Option>, - pub agent_model: AgentModel, + pub model: Model, pub editor_command: Option, /// Auth display info: (config_provider, auth_provider, subscription_name) for local providers pub auth_display_info: (Option, Option, Option), @@ -282,7 +289,7 @@ impl AppState { auto_approve_tools, allowed_tools, input_tx, - agent_model, + model, editor_command, auth_display_info, } = options; @@ -422,6 +429,12 @@ impl AppState { rulebook_switcher_selected: 0, rulebook_search_input: String::new(), filtered_rulebooks: Vec::new(), + + // Model switcher initialization + show_model_switcher: false, + available_models: Vec::new(), + model_switcher_selected: 0, + current_model: None, // Command palette initialization show_command_palette: false, command_palette_selected: 0, @@ -448,7 +461,7 @@ impl AppState { prompt_tokens_details: None, }, context_usage_percent: 0, - agent_model, + model, llm_model: None, // Side panel initialization diff --git a/tui/src/app/events.rs b/tui/src/app/events.rs index 1c23044b..e0a7f701 100644 --- a/tui/src/app/events.rs +++ b/tui/src/app/events.rs @@ -1,7 +1,8 @@ use ratatui::style::Color; +use stakai::Model; use stakpak_api::models::ListRuleBook; use stakpak_shared::models::{ - integrations::openai::{AgentModel, ToolCall, ToolCallResult, ToolCallResultProgress}, + integrations::openai::{ToolCall, ToolCallResult, ToolCallResultProgress}, llm::{LLMModel, LLMTokenUsage}, }; use uuid::Uuid; @@ -136,6 +137,12 @@ pub enum InputEvent { // Model events StreamModel(LLMModel), + // Model switcher events + ShowModelSwitcher, + AvailableModelsLoaded(Vec), + ModelSwitcherSelect, + ModelSwitcherCancel, + // Side panel events ToggleSidePanel, SidePanelNextSection, @@ -162,5 +169,6 @@ pub enum OutputEvent { RequestRulebookUpdate(Vec), RequestCurrentRulebooks, RequestTotalUsage, - SwitchModel(AgentModel), + RequestAvailableModels, + SwitchToModel(Model), } diff --git a/tui/src/event_loop.rs b/tui/src/event_loop.rs index 4b468ffa..2726e2d1 100644 --- a/tui/src/event_loop.rs +++ b/tui/src/event_loop.rs @@ -14,7 +14,8 @@ use crossterm::event::{ }; use crossterm::{execute, terminal::EnterAlternateScreen}; use ratatui::{Terminal, backend::CrosstermBackend}; -use stakpak_shared::models::integrations::openai::{AgentModel, ToolCallResultStatus}; +use stakai::Model; +use 
stakpak_shared::models::integrations::openai::ToolCallResultStatus; use std::io; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; @@ -48,7 +49,7 @@ pub async fn run_tui( allowed_tools: Option<&Vec>, current_profile_name: String, rulebook_config: Option, - agent_model: AgentModel, + model: Model, editor_command: Option, auth_display_info: (Option, Option, Option), ) -> io::Result<()> { @@ -90,7 +91,7 @@ pub async fn run_tui( auto_approve_tools, allowed_tools, input_tx: Some(internal_tx.clone()), - agent_model, + model, editor_command, auth_display_info, }); diff --git a/tui/src/services/commands.rs b/tui/src/services/commands.rs index 5d97b4e4..6756617e 100644 --- a/tui/src/services/commands.rs +++ b/tui/src/services/commands.rs @@ -13,8 +13,8 @@ use crate::constants::SUMMARIZE_PROMPT_BASE; use crate::services::auto_approve::AutoApprovePolicy; use crate::services::helper_block::{ push_clear_message, push_error_message, push_help_message, push_issue_message, - push_memorize_message, push_model_message, push_status_message, push_styled_message, - push_support_message, push_usage_message, render_system_message, welcome_messages, + push_memorize_message, push_status_message, push_styled_message, push_support_message, + push_usage_message, render_system_message, welcome_messages, }; use crate::services::message::{Message, MessageContent}; use crate::{InputEvent, OutputEvent}; @@ -25,7 +25,7 @@ use ratatui::{ text::{Line, Span, Text}, widgets::{Block, Borders, Paragraph}, }; -use stakpak_shared::models::integrations::openai::AgentModel; + use stakpak_shared::models::llm::LLMTokenUsage; use stakpak_shared::models::model_pricing::ContextAware; use tokio::sync::mpsc::Sender; @@ -185,8 +185,8 @@ pub fn get_all_commands() -> Vec { ), Command::new( "Switch Model", - "Switch model (smart/eco)", - "/model", + "Switch to a different AI model", + "Ctrl+M", CommandAction::SwitchModel, ), ] @@ -201,7 +201,7 @@ pub fn commands_to_helper_commands() -> Vec { }, HelperCommand { command: "/model", - description: "Switch model (smart/eco)", + description: "Open model switcher to change AI model", }, HelperCommand { command: "/clear", @@ -301,21 +301,16 @@ pub fn execute_command(command_id: CommandId, ctx: CommandContext) -> Result<(), ctx.state.show_helper_dropdown = false; Ok(()) } - "/model" => match switch_model(ctx.state) { - Ok(()) => { - let _ = ctx - .output_tx - .try_send(OutputEvent::SwitchModel(ctx.state.agent_model.clone())); - push_model_message(ctx.state); - ctx.state.text_area.set_text(""); - ctx.state.show_helper_dropdown = false; - Ok(()) - } - Err(e) => { - push_error_message(ctx.state, &e, None); - Err(e) - } - }, + "/model" => { + // Show model switcher popup + ctx.state.show_model_switcher = true; + ctx.state.model_switcher_selected = 0; + ctx.state.text_area.set_text(""); + ctx.state.show_helper_dropdown = false; + // Request available models from the output handler + let _ = ctx.output_tx.try_send(OutputEvent::RequestAvailableModels); + Ok(()) + } "/clear" => { push_clear_message(ctx.state); ctx.state.text_area.set_text(""); @@ -479,26 +474,6 @@ pub fn execute_command(command_id: CommandId, ctx: CommandContext) -> Result<(), // ========== Helper Functions ========== -pub fn switch_model(state: &mut AppState) -> Result<(), String> { - match state.agent_model { - AgentModel::Smart => { - // TODO: Check if context exceeds eco model context window size - state.agent_model = AgentModel::Eco; - Ok(()) - } - AgentModel::Eco => { - // TODO: Check if context exceeds smart model 
context window size - state.agent_model = AgentModel::Smart; - Ok(()) - } - AgentModel::Recovery => { - // TODO: Check if context exceeds recovery model context window size - state.agent_model = AgentModel::Smart; - Ok(()) - } - } -} - /// Terminate any active shell command before switching sessions fn terminate_active_shell(state: &mut AppState) { if let Some(cmd) = &state.active_shell_command { diff --git a/tui/src/services/handlers/mod.rs b/tui/src/services/handlers/mod.rs index 9d057acc..fcab3d7c 100644 --- a/tui/src/services/handlers/mod.rs +++ b/tui/src/services/handlers/mod.rs @@ -47,6 +47,7 @@ pub fn update( | InputEvent::ProfileSwitchFailed(_) | InputEvent::RulebooksLoaded(_) | InputEvent::CurrentRulebooksLoaded(_) + | InputEvent::AvailableModelsLoaded(_) | InputEvent::Quit | InputEvent::AttemptQuit => { // Allow these events through @@ -111,6 +112,41 @@ pub fn update( } } + // Intercept keys for Model Switcher Popup + if state.show_model_switcher { + match event { + InputEvent::HandleEsc => { + popup::handle_model_switcher_cancel(state); + return; + } + InputEvent::Up | InputEvent::ScrollUp => { + // Navigate up in model list + if state.model_switcher_selected > 0 { + state.model_switcher_selected -= 1; + } + return; + } + InputEvent::Down | InputEvent::ScrollDown => { + // Navigate down in model list + if state.model_switcher_selected < state.available_models.len().saturating_sub(1) { + state.model_switcher_selected += 1; + } + return; + } + InputEvent::InputSubmitted => { + popup::handle_model_switcher_select(state, output_tx); + return; + } + InputEvent::AvailableModelsLoaded(_) => { + // Let this fall through to the main handler + } + _ => { + // Consume other events to prevent side effects + return; + } + } + } + // Intercept keys for Approval Bar (inline approval) // Controls: ←→ navigate, Space toggle, Enter confirm all, Esc reject all // Don't intercept if collapsed messages popup is showing @@ -630,6 +666,20 @@ pub fn update( popup::handle_toggle_more_shortcuts(state); } + // Model switcher handlers + InputEvent::ShowModelSwitcher => { + popup::handle_show_model_switcher(state, output_tx); + } + InputEvent::AvailableModelsLoaded(models) => { + popup::handle_available_models_loaded(state, models); + } + InputEvent::ModelSwitcherSelect => { + popup::handle_model_switcher_select(state, output_tx); + } + InputEvent::ModelSwitcherCancel => { + popup::handle_model_switcher_cancel(state); + } + // Side panel handlers InputEvent::ToggleSidePanel => { popup::handle_toggle_side_panel(state); diff --git a/tui/src/services/handlers/popup.rs b/tui/src/services/handlers/popup.rs index 49134a76..db0c7bef 100644 --- a/tui/src/services/handlers/popup.rs +++ b/tui/src/services/handlers/popup.rs @@ -1,6 +1,6 @@ //! Popup Event Handlers //! -//! Handles all popup-related events including profile switcher, rulebook switcher, command palette, shortcuts, collapsed messages, and context popup. +//! Handles all popup-related events including profile switcher, rulebook switcher, model switcher, command palette, shortcuts, collapsed messages, and context popup. 
use crate::app::{AppState, OutputEvent}; use crate::services::detect_term::AdaptiveColors; @@ -9,6 +9,7 @@ use crate::services::message::{ Message, get_wrapped_collapsed_message_lines_cached, invalidate_message_lines_cache, }; use ratatui::style::{Color, Style}; +use stakai::Model; use stakpak_api::models::ListRuleBook; use tokio::sync::mpsc::Sender; @@ -830,3 +831,74 @@ pub fn handle_file_changes_popup_mouse_click(state: &mut AppState, col: u16, row } } } + +// ========== Model Switcher Handlers ========== + +/// Handle show model switcher event +pub fn handle_show_model_switcher(state: &mut AppState, output_tx: &Sender) { + // Don't show model switcher if input is blocked or dialog is open + if state.profile_switching_in_progress + || state.is_dialog_open + || state.approval_bar.is_visible() + { + return; + } + + // Clear any pending input + state.text_area.set_text(""); + + // Request available models from the backend + let _ = output_tx.try_send(OutputEvent::RequestAvailableModels); + + state.show_model_switcher = true; + state.model_switcher_selected = 0; +} + +/// Handle available models loaded event +pub fn handle_available_models_loaded(state: &mut AppState, models: Vec) { + state.available_models = models; + + // Pre-select current model if available + if let Some(current) = &state.current_model { + if let Some(idx) = state + .available_models + .iter() + .position(|m| m.id == current.id) + { + state.model_switcher_selected = idx; + } + } +} + +/// Handle model switcher select event +pub fn handle_model_switcher_select(state: &mut AppState, output_tx: &Sender) { + if state.show_model_switcher && !state.available_models.is_empty() { + if state.model_switcher_selected < state.available_models.len() { + let selected_model = state.available_models[state.model_switcher_selected].clone(); + + // Don't switch if already on this model + if state + .current_model + .as_ref() + .is_some_and(|m| m.id == selected_model.id) + { + state.show_model_switcher = false; + return; + } + + // Update current model + state.current_model = Some(selected_model.clone()); + + // Close the switcher + state.show_model_switcher = false; + + // Send request to switch model + let _ = output_tx.try_send(OutputEvent::SwitchToModel(selected_model.clone())); + } + } +} + +/// Handle model switcher cancel event +pub fn handle_model_switcher_cancel(state: &mut AppState) { + state.show_model_switcher = false; +} diff --git a/tui/src/services/helper_block.rs b/tui/src/services/helper_block.rs index 038deaa9..036cd585 100644 --- a/tui/src/services/helper_block.rs +++ b/tui/src/services/helper_block.rs @@ -3,7 +3,6 @@ use crate::constants::{EXCEEDED_API_LIMIT_ERROR, EXCEEDED_API_LIMIT_ERROR_MESSAG use crate::services::message::{Message, MessageContent, invalidate_message_lines_cache}; use ratatui::style::{Color, Modifier, Style}; use ratatui::text::{Line, Span}; -use stakpak_shared::models::integrations::openai::AgentModel; use uuid::Uuid; pub fn get_stakpak_version() -> String { @@ -230,36 +229,6 @@ pub fn push_memorize_message(state: &mut AppState) { invalidate_message_lines_cache(state); } -pub fn push_model_message(state: &mut AppState) { - let mut line = Vec::new(); - line.push(Span::styled( - "Switched to ", - Style::default().fg(Color::DarkGray), - )); - match state.agent_model { - AgentModel::Smart => { - line.push(Span::styled("smart", Style::default().fg(Color::Cyan))); - } - AgentModel::Eco => { - line.push(Span::styled("eco", Style::default().fg(Color::LightGreen))); - } - AgentModel::Recovery => { - 
-            line.push(Span::styled(
-                "recovery",
-                Style::default().fg(Color::LightBlue),
-            ));
-        }
-    }
-    line.push(Span::styled(" model", Style::default().fg(Color::DarkGray)));
-
-    state.messages.push(Message {
-        id: uuid::Uuid::new_v4(),
-        content: MessageContent::Styled(Line::from(line)),
-        is_collapsed: None,
-    });
-    invalidate_message_lines_cache(state);
-}
-
 pub fn push_help_message(state: &mut AppState) {
     use ratatui::style::{Color, Modifier, Style};
     use ratatui::text::{Line, Span};
diff --git a/tui/src/services/mod.rs b/tui/src/services/mod.rs
index bd49e261..1efe6d79 100644
--- a/tui/src/services/mod.rs
+++ b/tui/src/services/mod.rs
@@ -17,6 +17,7 @@ pub mod image_upload;
 pub mod markdown_renderer;
 pub mod message;
 pub mod message_pattern;
+pub mod model_switcher;
 pub mod placeholder_prompts;
 pub mod profile_switcher;
 pub mod rulebook_switcher;
diff --git a/tui/src/services/model_switcher.rs b/tui/src/services/model_switcher.rs
new file mode 100644
index 00000000..8c2fc71c
--- /dev/null
+++ b/tui/src/services/model_switcher.rs
@@ -0,0 +1,256 @@
+//! Model Switcher UI Component
+//!
+//! Provides a popup UI for switching between available AI models.
+//! Accessible via Ctrl+G or the /model command.
+
+use crate::app::AppState;
+use ratatui::{
+    Frame,
+    layout::{Constraint, Direction, Layout, Rect},
+    style::{Color, Modifier, Style},
+    text::{Line, Span},
+    widgets::{Block, Borders, List, ListItem, ListState, Paragraph},
+};
+
+/// Render the model switcher popup
+pub fn render_model_switcher_popup(f: &mut Frame, state: &AppState) {
+    // Calculate popup size (45% width to fit model names and costs, 60% height)
+    let area = centered_rect(45, 60, f.area());
+
+    // Clear background
+    f.render_widget(ratatui::widgets::Clear, area);
+
+    // Show loading message if no models are available yet
+    if state.available_models.is_empty() {
+        let block = Block::default()
+            .borders(Borders::ALL)
+            .border_style(Style::default().fg(Color::Cyan))
+            .title(" Switch Model ");
+
+        let loading = Paragraph::new(vec![
+            Line::from(""),
+            Line::from(Span::styled(
+                " Loading models...",
+                Style::default().fg(Color::Yellow),
+            )),
+            Line::from(""),
+            Line::from(Span::styled(
+                " Press ESC to cancel",
+                Style::default().fg(Color::DarkGray),
+            )),
+        ]);
+
+        f.render_widget(block, area);
+        let inner = Rect {
+            x: area.x + 1,
+            y: area.y + 1,
+            width: area.width - 2,
+            height: area.height - 2,
+        };
+        f.render_widget(loading, inner);
+        return;
+    }
+
+    // Group models by provider
+    let mut models_by_provider: std::collections::HashMap<&str, Vec<_>> =
+        std::collections::HashMap::new();
+    for model in &state.available_models {
+        models_by_provider
+            .entry(model.provider.as_str())
+            .or_default()
+            .push(model);
+    }
+
+    // Create list items with provider headers
+    let mut items: Vec<ListItem> = Vec::new();
+    let mut item_indices: Vec<Option<usize>> = Vec::new(); // Maps display index to model index
+    let mut model_idx = 0;
+
+    // Sort providers for consistent ordering
+    let mut providers: Vec<_> = models_by_provider.keys().collect();
+    providers.sort();
+
+    for provider in providers {
+        let models = &models_by_provider[provider];
+
+        // Provider header
+        let provider_name = match *provider {
+            "anthropic" => "Anthropic",
+            "openai" => "OpenAI",
+            "google" => "Google",
+            "stakpak" => "Stakpak",
+            _ => *provider,
+        };
+
+        items.push(ListItem::new(Line::from(vec![Span::styled(
+            format!(" {} ", provider_name),
+            Style::default()
+                .fg(Color::Yellow)
+                .add_modifier(Modifier::BOLD),
+        )])));
+        item_indices.push(None); // Header is not selectable
+
// Model items + for model in models.iter() { + let is_selected = model_idx == state.model_switcher_selected; + let is_current = state + .current_model + .as_ref() + .is_some_and(|m| m.id == model.id); + + let mut spans = vec![]; + + // Current indicator + if is_current { + spans.push(Span::styled(" ", Style::default().fg(Color::Green))); + } else { + spans.push(Span::raw(" ")); + } + + // Model name + let name_style = if is_current { + Style::default() + .fg(Color::Green) + .add_modifier(Modifier::BOLD) + } else if is_selected { + Style::default() + .fg(Color::White) + .add_modifier(Modifier::BOLD) + } else { + Style::default().fg(Color::Gray) + }; + spans.push(Span::styled(model.name.clone(), name_style)); + + // Reasoning indicator + if model.reasoning { + spans.push(Span::styled(" [R]", Style::default().fg(Color::Magenta))); + } + + // Cost if available + if let Some(cost) = &model.cost { + spans.push(Span::styled( + format!(" ${:.2}/${:.2}", cost.input, cost.output), + Style::default().fg(Color::DarkGray), + )); + } + + items.push(ListItem::new(Line::from(spans))); + item_indices.push(Some(model_idx)); + model_idx += 1; + } + } + + // Create the main block with border + let block = Block::default() + .borders(Borders::ALL) + .border_style(Style::default().fg(Color::Cyan)); + + // Split area for title, list and help text inside the block + let inner_area = Rect { + x: area.x + 1, + y: area.y + 1, + width: area.width - 2, + height: area.height - 2, + }; + + let chunks = Layout::default() + .direction(Direction::Vertical) + .constraints([ + Constraint::Length(1), // Title + Constraint::Min(3), // List + Constraint::Length(2), // Help text + ]) + .split(inner_area); + + // Render title inside the popup + let title = " Switch Model"; + let title_style = Style::default() + .fg(Color::Yellow) + .add_modifier(Modifier::BOLD); + let title_line = Line::from(Span::styled(title, title_style)); + let title_paragraph = Paragraph::new(title_line); + + f.render_widget(title_paragraph, chunks[0]); + + // Find the display index that corresponds to the selected model index + let display_selected = item_indices + .iter() + .position(|idx| *idx == Some(state.model_switcher_selected)) + .unwrap_or(0); + + // Create list with proper block and padding + let list = List::new(items) + .highlight_style(Style::default().bg(Color::Cyan).fg(Color::Black)) + .block(Block::default().borders(Borders::NONE)); + + // Create list state for highlighting + let mut list_state = ListState::default(); + list_state.select(Some(display_selected)); + + // Render list with proper padding + let list_area = Rect { + x: chunks[1].x, + y: chunks[1].y + 1, + width: chunks[1].width, + height: chunks[1].height.saturating_sub(1), + }; + + f.render_stateful_widget(list, list_area, &mut list_state); + + // Help text + let help = Paragraph::new(vec![ + Line::from(vec![ + Span::styled("↑/↓", Style::default().fg(Color::DarkGray)), + Span::styled(" navigate", Style::default().fg(Color::Cyan)), + Span::raw(" "), + Span::styled("↵", Style::default().fg(Color::DarkGray)), + Span::styled(" select", Style::default().fg(Color::Cyan)), + Span::raw(" "), + Span::styled("esc", Style::default().fg(Color::DarkGray)), + Span::styled(" cancel", Style::default().fg(Color::Cyan)), + ]), + Line::from(vec![ + Span::styled("[R]", Style::default().fg(Color::Magenta)), + Span::styled(" = reasoning support", Style::default().fg(Color::DarkGray)), + Span::raw(" "), + Span::styled("$in/$out", Style::default().fg(Color::DarkGray)), + Span::styled( + " = cost per 1M 
tokens", + Style::default().fg(Color::DarkGray), + ), + ]), + ]); + + let help_area = Rect { + x: chunks[2].x + 1, + y: chunks[2].y, + width: chunks[2].width.saturating_sub(2), + height: chunks[2].height, + }; + + f.render_widget(help, help_area); + + // Render the border last (so it's on top) + f.render_widget(block, area); +} + +/// Helper function to create a centered rect +fn centered_rect(percent_x: u16, percent_y: u16, r: Rect) -> Rect { + let popup_layout = Layout::default() + .direction(Direction::Vertical) + .constraints([ + Constraint::Percentage((100 - percent_y) / 2), + Constraint::Percentage(percent_y), + Constraint::Percentage((100 - percent_y) / 2), + ]) + .split(r); + + Layout::default() + .direction(Direction::Horizontal) + .constraints([ + Constraint::Percentage((100 - percent_x) / 2), + Constraint::Percentage(percent_x), + Constraint::Percentage((100 - percent_x) / 2), + ]) + .split(popup_layout[1])[1] +} diff --git a/tui/src/services/side_panel.rs b/tui/src/services/side_panel.rs index 0d523613..92df23bd 100644 --- a/tui/src/services/side_panel.rs +++ b/tui/src/services/side_panel.rs @@ -15,7 +15,6 @@ use ratatui::{ text::{Line, Span}, widgets::{Block, Borders, Paragraph, Wrap}, }; -use stakpak_shared::models::model_pricing::ContextAware; /// Left padding for content inside the side panel const LEFT_PADDING: &str = " "; @@ -184,19 +183,16 @@ fn render_context_section(f: &mut Frame, state: &AppState, area: Rect, collapsed ]) }; - // Token usage - let tokens = state.current_message_usage.total_tokens; - let context_info = state - .llm_model - .as_ref() - .map(|m| m.context_info()) - .unwrap_or_default(); - let max_tokens = context_info.max_tokens as u32; + // Get the active model (current_model if set, otherwise default model) + let active_model = state.current_model.as_ref().unwrap_or(&state.model); - // Show N/A when no content yet (tokens == 0) + // Token usage - use session total and active model's context limit + let tokens = state.total_session_usage.total_tokens; + let max_tokens = active_model.limit.context as u32; + + // Show tokens info if tokens == 0 { lines.push(make_row("Tokens", "N/A".to_string(), Color::DarkGray)); - lines.push(make_row("Model", "N/A".to_string(), Color::DarkGray)); } else { let percentage = if max_tokens > 0 { ((tokens as f64 / max_tokens as f64) * 100.0).round() as u32 @@ -214,29 +210,27 @@ fn render_context_section(f: &mut Frame, state: &AppState, area: Rect, collapsed ), Color::White, )); + } - // Model name - let model_name = state - .llm_model - .as_ref() - .map(|m| m.model_name()) - .unwrap_or_else(|| state.agent_model.to_string()); + // Model name - from active model + let model_name = &active_model.name; - // Truncate model name if needed, assuming label len ~10 (" Model:") - let avail_for_model = area.width as usize - 10; - let truncated_model = truncate_string(&model_name, avail_for_model); + // Truncate model name if needed, assuming label len ~10 (" Model:") + let avail_for_model = area.width as usize - 10; + let truncated_model = truncate_string(model_name, avail_for_model); - lines.push(make_row("Model", truncated_model, Color::Cyan)); - } + lines.push(make_row("Model", truncated_model, Color::Cyan)); - // Provider - show subscription, auth provider, or config provider - let provider_value = match &state.auth_display_info { - (_, Some(_), Some(subscription)) => subscription.clone(), - (_, Some(auth_provider), None) => auth_provider.clone(), - (Some(config_provider), None, None) => config_provider.clone(), - _ => 
"Remote".to_string(), + // Provider - from active model (capitalized) + let provider = { + let p = &active_model.provider; + let mut chars = p.chars(); + match chars.next() { + Some(c) => c.to_uppercase().collect::() + chars.as_str(), + None => String::new(), + } }; - lines.push(make_row("Provider", provider_value, Color::DarkGray)); + lines.push(make_row("Provider", provider, Color::DarkGray)); let paragraph = Paragraph::new(lines); f.render_widget(paragraph, area); diff --git a/tui/src/view.rs b/tui/src/view.rs index 3a5906a1..e89bdbcf 100644 --- a/tui/src/view.rs +++ b/tui/src/view.rs @@ -223,6 +223,11 @@ pub fn view(f: &mut Frame, state: &mut AppState) { crate::services::rulebook_switcher::render_rulebook_switcher_popup(f, state); } + // Render model switcher + if state.show_model_switcher { + crate::services::model_switcher::render_model_switcher_popup(f, state); + } + // Render profile switch overlay if state.profile_switching_in_progress { crate::services::profile_switcher::render_profile_switch_overlay(f, state);