2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

49 changes: 43 additions & 6 deletions cli/src/commands/acp/server.rs
@@ -4,12 +4,13 @@ use agent_client_protocol::{self as acp, Client as AcpClient, SessionNotificatio
use futures_util::StreamExt;
use stakpak_api::models::ApiStreamError;
use stakpak_api::{AgentClient, AgentClientConfig, AgentProvider, StakpakConfig};
use stakpak_api::{Model, ModelLimit};
use stakpak_mcp_client::McpClient;
use stakpak_shared::models::integrations::mcp::CallToolResultExt;
use stakpak_shared::models::integrations::openai::{
AgentModel, ChatCompletionChoice, ChatCompletionResponse, ChatCompletionStreamResponse,
ChatMessage, FinishReason, FunctionCall, FunctionCallDelta, MessageContent, Role, Tool,
ToolCall, ToolCallResultProgress, ToolCallResultStatus,
ChatCompletionChoice, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage,
FinishReason, FunctionCall, FunctionCallDelta, MessageContent, Role, Tool, ToolCall,
ToolCallResultProgress, ToolCallResultStatus,
};
use stakpak_shared::models::llm::LLMTokenUsage;
use std::cell::Cell;
@@ -22,6 +23,8 @@ use uuid::Uuid;
pub struct StakpakAcpAgent {
config: AppConfig,
client: Arc<dyn AgentProvider>,
/// Default model to use for chat completions
model: Model,
session_update_tx: mpsc::UnboundedSender<(acp::SessionNotification, oneshot::Sender<()>)>,
next_session_id: Cell<u64>,
mcp_client: Option<Arc<McpClient>>,
@@ -86,6 +89,37 @@ impl StakpakAcpAgent {
Arc::new(client)
};

// Get default model - use smart_model from config or first available model
let model = if let Some(smart_model_str) = &config.smart_model {
// Parse the smart_model string to determine provider
let provider = if smart_model_str.starts_with("anthropic/")
|| smart_model_str.contains("claude")
{
"anthropic"
} else if smart_model_str.starts_with("openai/") || smart_model_str.contains("gpt") {
"openai"
} else if smart_model_str.starts_with("google/") || smart_model_str.contains("gemini") {
"google"
} else {
"stakpak"
};
Model::custom(smart_model_str.clone(), provider)
} else {
// Use first available model from client
let models = client.list_models().await;
models.into_iter().next().unwrap_or_else(|| {
// Fallback default: Claude Opus via Stakpak
Model::new(
"anthropic/claude-opus-4-5-20251101",
"Claude Opus 4.5",
"stakpak",
true,
None,
ModelLimit::default(),
)
})
};

// Initialize MCP client and tools (optional for ACP)
let (mcp_client, mcp_tools, tools) =
match Self::initialize_mcp_server_and_tools(&config).await {
@@ -114,6 +148,7 @@ impl StakpakAcpAgent {
Ok(Self {
config,
client,
model,
session_update_tx,
next_session_id: Cell::new(0),
mcp_client,
@@ -1081,7 +1116,7 @@ impl StakpakAcpAgent {
id: "".to_string(),
object: "".to_string(),
created: 0,
model: AgentModel::Smart.to_string(),
model: "claude-sonnet-4-5".to_string(),
choices: vec![],
usage: LLMTokenUsage {
prompt_tokens: 0,
@@ -1372,6 +1407,7 @@ impl StakpakAcpAgent {
let agent = StakpakAcpAgent {
config: self.config.clone(),
client: self.client.clone(),
model: self.model.clone(),
session_update_tx: tx.clone(),
next_session_id: self.next_session_id.clone(),
mcp_client,
@@ -1492,6 +1528,7 @@ impl Clone for StakpakAcpAgent {
Self {
config: self.config.clone(),
client: self.client.clone(),
model: self.model.clone(),
session_update_tx: self.session_update_tx.clone(),
next_session_id: Cell::new(self.next_session_id.get()),
mcp_client: self.mcp_client.clone(),
@@ -1736,7 +1773,7 @@ impl acp::Agent for StakpakAcpAgent {

let (stream, _request_id) = self
.client
.chat_completion_stream(AgentModel::Smart, messages, tools_option.clone(), None)
.chat_completion_stream(self.model.clone(), messages, tools_option.clone(), None)
.await
.map_err(|e| {
log::error!("Chat completion stream failed: {e}");
@@ -1932,7 +1969,7 @@ impl acp::Agent for StakpakAcpAgent {
let (follow_up_stream, _request_id) = self
.client
.chat_completion_stream(
AgentModel::Smart,
self.model.clone(),
current_messages.clone(),
tools_option.clone(),
None,
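
As a reading aid for the new model resolution in `StakpakAcpAgent::new`, the provider-inference heuristic can be exercised in isolation. This is a minimal sketch; the free function and the sample inputs are illustrative and not part of the change:

```rust
/// Illustrative copy of the provider-detection logic from
/// StakpakAcpAgent::new (not a public API in this PR).
fn infer_provider(model: &str) -> &'static str {
    if model.starts_with("anthropic/") || model.contains("claude") {
        "anthropic"
    } else if model.starts_with("openai/") || model.contains("gpt") {
        "openai"
    } else if model.starts_with("google/") || model.contains("gemini") {
        "google"
    } else {
        "stakpak"
    }
}

fn main() {
    assert_eq!(infer_provider("anthropic/claude-opus-4-5"), "anthropic");
    assert_eq!(infer_provider("gpt-4o-mini"), "openai");
    assert_eq!(infer_provider("google/gemini-2.0-flash"), "google");
    assert_eq!(infer_provider("mistral-large"), "stakpak");
}
```

Note that the substring checks are permissive by design: any model string containing "claude", "gpt", or "gemini" is routed to the corresponding provider even without an explicit prefix, and everything else falls back to "stakpak".
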
6 changes: 3 additions & 3 deletions cli/src/commands/agent/run/mode_async.rs
@@ -11,10 +11,10 @@ use crate::commands::agent::run::tooling::run_tool_call;
use crate::config::AppConfig;
use crate::utils::agents_md::AgentsMdInfo;
use crate::utils::local_context::LocalContext;
use stakpak_api::{AgentClient, AgentClientConfig, AgentProvider, models::ListRuleBook};
use stakpak_api::{AgentClient, AgentClientConfig, AgentProvider, Model, models::ListRuleBook};
use stakpak_mcp_server::EnabledToolsConfig;
use stakpak_shared::local_store::LocalStore;
use stakpak_shared::models::integrations::openai::{AgentModel, ChatMessage};
use stakpak_shared::models::integrations::openai::ChatMessage;
use stakpak_shared::models::llm::LLMTokenUsage;
use stakpak_shared::models::subagent::SubagentConfigs;
use std::time::Instant;
@@ -35,7 +35,7 @@ pub struct RunAsyncConfig {
pub enable_mtls: bool,
pub system_prompt: Option<String>,
pub enabled_tools: EnabledToolsConfig,
pub model: AgentModel,
pub model: Model,
pub agents_md: Option<AgentsMdInfo>,
}

22 changes: 16 additions & 6 deletions cli/src/commands/agent/run/mode_interactive.rs
@@ -19,12 +19,12 @@ use crate::utils::check_update::get_latest_cli_version;
use crate::utils::local_context::LocalContext;
use reqwest::header::HeaderMap;
use stakpak_api::models::ApiStreamError;
use stakpak_api::{AgentClient, AgentClientConfig, AgentProvider, models::ListRuleBook};
use stakpak_api::{AgentClient, AgentClientConfig, AgentProvider, Model, models::ListRuleBook};

use stakpak_mcp_server::EnabledToolsConfig;
use stakpak_shared::models::integrations::mcp::CallToolResultExt;
use stakpak_shared::models::integrations::openai::{
AgentModel, ChatMessage, MessageContent, Role, ToolCall, ToolCallResultStatus,
ChatMessage, MessageContent, Role, ToolCall, ToolCallResultStatus,
};
use stakpak_shared::models::llm::{LLMTokenUsage, PromptTokensDetails};
use stakpak_shared::models::subagent::SubagentConfigs;
@@ -57,7 +57,7 @@ pub struct RunInteractiveConfig {
pub allowed_tools: Option<Vec<String>>,
pub auto_approve: Option<Vec<String>>,
pub enabled_tools: EnabledToolsConfig,
pub model: AgentModel,
pub model: Model,
pub agents_md: Option<AgentsMdInfo>,
}

@@ -118,8 +118,8 @@ pub async fn run_interactive(
});
let editor_command = ctx.editor.clone();

let model_clone = model.clone();
let auth_display_info_for_tui = ctx.get_auth_display_info();
let model_for_tui = model.clone();
let tui_handle = tokio::spawn(async move {
let latest_version = get_latest_cli_version().await;
stakpak_tui::run_tui(
@@ -135,7 +135,7 @@
allowed_tools.as_ref(),
current_profile_for_tui,
rulebook_config_for_tui,
model_clone,
model_for_tui,
editor_command,
auth_display_info_for_tui,
)
@@ -288,7 +288,7 @@ pub async fn run_interactive(

while let Some(output_event) = output_rx.recv().await {
match output_event {
OutputEvent::SwitchModel(new_model) => {
OutputEvent::SwitchToModel(new_model) => {
model = new_model;
continue;
}
@@ -814,6 +814,16 @@ pub async fn run_interactive(
.await?;
continue;
}
OutputEvent::RequestAvailableModels => {
// Load available models from the provider registry
let available_models = client.list_models().await;
send_input_event(
&input_tx,
InputEvent::AvailableModelsLoaded(available_models),
)
.await?;
continue;
}
}

let headers = if study_mode {
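
For context on the two new event arms, here is a simplified sketch of the round trip between the TUI and the agent loop. The enum shapes and channel wiring are reduced stand-ins (the real `OutputEvent`/`InputEvent` carry `Model` values and live in the TUI crate); only the flow is meant to match the diff:

```rust
use tokio::sync::mpsc;

// Reduced stand-ins for the real TUI events; illustrative only.
#[derive(Debug)]
enum OutputEvent {
    RequestAvailableModels,
    SwitchToModel(String),
}

#[derive(Debug)]
enum InputEvent {
    AvailableModelsLoaded(Vec<String>),
}

#[tokio::main]
async fn main() {
    let (out_tx, mut out_rx) = mpsc::unbounded_channel::<OutputEvent>();
    let (in_tx, mut in_rx) = mpsc::unbounded_channel::<InputEvent>();

    // TUI side: ask the agent loop which models are available.
    out_tx.send(OutputEvent::RequestAvailableModels).unwrap();

    // Agent loop side: answer with the model list (a stub standing in for
    // client.list_models().await).
    if let Some(OutputEvent::RequestAvailableModels) = out_rx.recv().await {
        let available = vec!["claude-opus-4-5".to_string(), "gpt-4o".to_string()];
        in_tx.send(InputEvent::AvailableModelsLoaded(available)).unwrap();
    }

    // TUI side: the user picks one and the agent loop switches its active model.
    if let Some(InputEvent::AvailableModelsLoaded(models)) = in_rx.recv().await {
        out_tx.send(OutputEvent::SwitchToModel(models[0].clone())).unwrap();
    }
    if let Some(OutputEvent::SwitchToModel(model)) = out_rx.recv().await {
        println!("active model is now {model}");
    }
}
```
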
45 changes: 45 additions & 0 deletions cli/src/config/app.rs
@@ -54,6 +54,8 @@ pub struct AppConfig {
pub eco_model: Option<String>,
/// Recovery model name
pub recovery_model: Option<String>,
/// New unified model field (replaces smart/eco/recovery model selection)
pub model: Option<String>,
/// Unique ID for anonymous telemetry
pub anonymous_id: Option<String>,
/// Whether to collect telemetry data
@@ -162,6 +164,7 @@
smart_model: profile_config.smart_model,
eco_model: profile_config.eco_model,
recovery_model: profile_config.recovery_model,
model: profile_config.model,
anonymous_id: settings.anonymous_id,
collect_telemetry: settings.collect_telemetry,
editor: settings.editor,
@@ -776,6 +779,47 @@ impl AppConfig {

(config_provider, None, None)
}

/// Get the default Model from config
///
/// Uses the `model` field if set, otherwise falls back to `smart_model`,
/// and finally to a default Claude Opus model.
///
/// Searches the model catalog by ID. If the model string has a provider
/// prefix (e.g., "anthropic/claude-opus-4-5"), it searches within that
/// provider first. Otherwise, it searches all providers.
pub fn get_default_model(&self) -> stakpak_api::Model {
let use_stakpak = self.api_key.is_some();

// Priority: model > smart_model > default
let model_str = self
.model
.as_ref()
.or(self.smart_model.as_ref())
.map(|s| s.as_str())
.unwrap_or("claude-opus-4-5-20251101");

// Search the model catalog
stakpak_api::find_model(model_str, use_stakpak).unwrap_or_else(|| {
// Model not found in catalog - create a custom model
// Extract provider from prefix if present
let (provider, model_id) = if let Some(idx) = model_str.find('/') {
let (prefix, rest) = model_str.split_at(idx);
(prefix, &rest[1..])
} else {
("anthropic", model_str) // Default to anthropic
};

let final_provider = if use_stakpak { "stakpak" } else { provider };
let final_id = if use_stakpak {
format!("{}/{}", provider, model_id)
} else {
model_id.to_string()
};

stakpak_api::Model::custom(final_id, final_provider)
})
}
}

// Conversions
@@ -810,6 +854,7 @@ impl From<AppConfig> for ProfileConfig {
eco_model: config.eco_model,
smart_model: config.smart_model,
recovery_model: config.recovery_model,
model: config.model,
}
}
}
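
To make the fallback path in `get_default_model` easier to follow, here is a sketch of how a model string that is missing from the catalog gets split into a provider and an ID. The helper and sample values are illustrative; in the PR this logic is inline in `AppConfig::get_default_model`:

```rust
/// Illustrative extraction of the catalog-miss fallback in get_default_model,
/// returning (model_id, provider) for a given input and API-key state.
fn custom_model_parts(model_str: &str, use_stakpak: bool) -> (String, String) {
    let (provider, model_id) = match model_str.find('/') {
        Some(idx) => {
            let (prefix, rest) = model_str.split_at(idx);
            (prefix, &rest[1..])
        }
        None => ("anthropic", model_str), // bare IDs default to anthropic
    };

    if use_stakpak {
        // Routed through Stakpak: the original provider survives as a prefix in the ID.
        (format!("{provider}/{model_id}"), "stakpak".to_string())
    } else {
        (model_id.to_string(), provider.to_string())
    }
}

fn main() {
    // With a Stakpak API key, the provider is "stakpak" and the ID keeps its prefix.
    assert_eq!(
        custom_model_parts("openai/gpt-4o", true),
        ("openai/gpt-4o".to_string(), "stakpak".to_string())
    );
    // Without one, the prefix selects the provider directly.
    assert_eq!(
        custom_model_parts("openai/gpt-4o", false),
        ("gpt-4o".to_string(), "openai".to_string())
    );
}
```
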
23 changes: 20 additions & 3 deletions cli/src/config/profile.rs
@@ -71,11 +71,22 @@ pub struct ProfileConfig {
#[serde(skip_serializing_if = "Option::is_none")]
pub anthropic: Option<AnthropicConfig>,

/// Eco (fast/cheap) model name
/// User's preferred model (replaces smart_model/eco_model/recovery_model)
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,

// =========================================================================
// Legacy model fields - kept for backward compatibility during migration
// These are read but deprecated (will migrate to 'model' field)
// =========================================================================
/// Eco (fast/cheap) model name (deprecated - use 'model')
#[serde(skip_serializing_if = "Option::is_none")]
pub eco_model: Option<String>,
/// Smart (capable) model name
/// Smart (capable) model name (deprecated - use 'model')
#[serde(skip_serializing_if = "Option::is_none")]
pub smart_model: Option<String>,
/// Recovery model name
/// Recovery model name (deprecated - use 'model')
#[serde(skip_serializing_if = "Option::is_none")]
pub recovery_model: Option<String>,
}

@@ -304,6 +315,12 @@ impl ProfileConfig {
.gemini
.clone()
.or_else(|| other.and_then(|config| config.gemini.clone())),
// New unified model field
model: self
.model
.clone()
.or_else(|| other.and_then(|config| config.model.clone())),
// Legacy fields - merge for backward compatibility during transition
eco_model: self
.eco_model
.clone()
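
The merge rule for the new `model` field follows the same precedence as every other profile option: the current profile's value wins, and the other config only fills the gap. A minimal sketch of that `or_else` pattern (the function name is illustrative, not part of the PR):

```rust
/// Illustrative stand-in for the Option merge used in ProfileConfig::merge.
fn merge_model(this: Option<&str>, other: Option<&str>) -> Option<String> {
    this.map(str::to_string)
        .or_else(|| other.map(str::to_string))
}

fn main() {
    // The current profile's value takes precedence...
    assert_eq!(
        merge_model(Some("openai/gpt-4o"), Some("anthropic/claude-opus-4-5")),
        Some("openai/gpt-4o".to_string())
    );
    // ...and the other config only applies when this profile leaves it unset.
    assert_eq!(
        merge_model(None, Some("anthropic/claude-opus-4-5")),
        Some("anthropic/claude-opus-4-5".to_string())
    );
    assert_eq!(merge_model(None, None), None);
}
```
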
2 changes: 2 additions & 0 deletions cli/src/config/tests.rs
@@ -62,6 +62,7 @@ fn sample_app_config(profile_name: &str) -> AppConfig {
smart_model: None,
eco_model: None,
recovery_model: None,
model: None,
anonymous_id: Some("test-user-id".into()),
collect_telemetry: Some(true),
editor: Some("nano".into()),
@@ -430,6 +431,7 @@ fn save_writes_profile_and_settings() {
smart_model: None,
eco_model: None,
recovery_model: None,
model: None,
anonymous_id: Some("test-user-id".into()),
collect_telemetry: Some(true),
editor: Some("nano".into()),
11 changes: 4 additions & 7 deletions cli/src/main.rs
@@ -3,7 +3,7 @@ use names::{self, Name};
use rustls::crypto::CryptoProvider;
use stakpak_api::{AgentClient, AgentClientConfig, AgentProvider};
use stakpak_mcp_server::EnabledToolsConfig;
use stakpak_shared::models::{integrations::openai::AgentModel, subagent::SubagentConfigs};
use stakpak_shared::models::subagent::SubagentConfigs;
use std::{
env,
path::{Path, PathBuf},
@@ -122,10 +122,6 @@ struct Cli {
#[arg(long = "profile")]
profile: Option<String>,

/// Choose agent model on startup (smart or eco)
#[arg(long = "model")]
model: Option<AgentModel>,

/// Custom path to config file (overrides default ~/.stakpak/config.toml)
#[arg(long = "config")]
config_path: Option<PathBuf>,
@@ -430,6 +426,7 @@

let allowed_tools = cli.allowed_tools.or_else(|| config.allowed_tools.clone());
let auto_approve = config.auto_approve.clone();
let default_model = config.get_default_model();

match use_async_mode {
// Async mode: run continuously until no more tool calls (or max_steps=1 for single-step)
@@ -452,7 +449,7 @@
enabled_tools: EnabledToolsConfig {
slack: cli.enable_slack_tools,
},
model: cli.model.unwrap_or(AgentModel::Smart),
model: default_model.clone(),
agents_md: agents_md.clone(),
},
)
@@ -484,7 +481,7 @@
enabled_tools: EnabledToolsConfig {
slack: cli.enable_slack_tools,
},
model: cli.model.unwrap_or(AgentModel::Smart),
model: default_model,
agents_md,
},
)