diff --git a/Cargo.lock b/Cargo.lock index 0267389343..57e7e0d4fb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -852,6 +852,26 @@ version = "0.4.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75984efb6ed102a0d42db99afb6c1948f0380d1d91808d5529916e6c08b49d8d" +[[package]] +name = "config" +version = "0.15.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b30fa8254caad766fc03cb0ccae691e14bf3bd72bfff27f72802ce729551b3d6" +dependencies = [ + "async-trait", + "convert_case 0.6.0", + "json5", + "pathdiff", + "ron", + "rust-ini", + "serde-untagged", + "serde_core", + "serde_json", + "toml 0.9.8", + "winnow", + "yaml-rust2", +] + [[package]] name = "console" version = "0.15.11" @@ -876,6 +896,35 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "const-random" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359" +dependencies = [ + "const-random-macro", +] + +[[package]] +name = "const-random-macro" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" +dependencies = [ + "getrandom 0.2.16", + "once_cell", + "tiny-keccak", +] + +[[package]] +name = "convert_case" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec182b0ca2f35d8fc196cf3404988fd8b8c739a4d270ff118a398feb0cbec1ca" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "convert_case" version = "0.10.0" @@ -1510,6 +1559,15 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "dlv-list" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "442039f5147480ba31067cb00ada1adae6892028e40e45fc5de7b7df6dcc1b5f" +dependencies = [ + "const-random", +] + [[package]] name = "document-features" version = "0.2.12" @@ -1623,6 +1681,17 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "erased-serde" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2add8a07dd6a8d93ff627029c51de145e12686fbc36ecb298ac22e74cf02dec" +dependencies = [ + "serde", + "serde_core", + "typeid", +] + [[package]] name = "errno" version = "0.3.14" @@ -1866,6 +1935,26 @@ dependencies = [ "serde_json", ] +[[package]] +name = "forge_config" +version = "0.1.0" +dependencies = [ + "config", + "derive_setters", + "dirs", + "dotenvy", + "fake", + "merge", + "pretty_assertions", + "schemars 1.2.1", + "serde", + "serde_json", + "thiserror 2.0.18", + "toml_edit", + "tracing", + "url", +] + [[package]] name = "forge_display" version = "0.1.0" @@ -2102,6 +2191,7 @@ dependencies = [ "eventsource-stream", "fake", "forge_app", + "forge_config", "forge_domain", "forge_fs", "forge_infra", @@ -2759,6 +2849,12 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.5" @@ -3478,6 +3574,17 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "json5" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96b0db21af676c1ce64250b5f40f3ce2cf27e4e47cb91ed91eb6fe9350b430c1" +dependencies = [ + "pest", + "pest_derive", + "serde", +] + [[package]] name = "jsonwebtoken" version = "10.3.0" @@ -4175,6 +4282,16 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "ordered-multimap" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49203cdcae0030493bad186b28da2fa25645fa276a51b6fec8010d281e02ef79" +dependencies = [ + "dlv-list", + "hashbrown 0.14.5", +] + [[package]] name = "outref" version = "0.5.2" @@ -5119,6 +5236,30 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "ron" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd490c5b18261893f14449cbd28cb9c0b637aebf161cd77900bfdedaff21ec32" +dependencies = [ + "bitflags 2.10.0", + "once_cell", + "serde", + "serde_derive", + "typeid", + "unicode-ident", +] + +[[package]] +name = "rust-ini" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "796e8d2b6696392a43bea58116b667fb4c29727dc5abd27d6acf338bb4f688c7" +dependencies = [ + "cfg-if", + "ordered-multimap", +] + [[package]] name = "rustc-hash" version = "2.1.1" @@ -5437,6 +5578,18 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde-untagged" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9faf48a4a2d2693be24c6289dbe26552776eb7737074e6722891fadbe6c5058" +dependencies = [ + "erased-serde", + "serde", + "serde_core", + "typeid", +] + [[package]] name = "serde_core" version = "1.0.228" @@ -6269,6 +6422,15 @@ dependencies = [ "time-core", ] +[[package]] +name = "tiny-keccak" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" +dependencies = [ + "crunchy", +] + [[package]] name = "tinystr" version = "0.8.2" @@ -6649,6 +6811,12 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "typeid" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc7d623258602320d5c55d1bc22793b57daff0ec7efc270ea7d55ce1d5f5471c" + [[package]] name = "typenum" version = "1.19.0" diff --git a/Cargo.toml b/Cargo.toml index e461963898..60a52ef73a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -91,6 +91,7 @@ tempfile = "3.27.0" termimad = "0.34.1" syntect = { version = "5", default-features = false, features = ["default-syntaxes", "default-themes", "regex-onig"] } thiserror = "2.0.18" +toml_edit = { version = "0.22", features = ["serde"] } tokio = { version = "1.50.0", features = [ "macros", "rt-multi-thread", @@ -157,3 +158,4 @@ forge_select = { path = "crates/forge_select" } forge_test_kit = { path = "crates/forge_test_kit" } forge_markdown_stream = { path = "crates/forge_markdown_stream" } +forge_config = { path = "crates/forge_config" } diff --git a/crates/forge_api/src/forge_api.rs b/crates/forge_api/src/forge_api.rs index 037ae8565f..a2dc7847a7 100644 --- a/crates/forge_api/src/forge_api.rs +++ b/crates/forge_api/src/forge_api.rs @@ -41,18 +41,9 @@ impl ForgeAPI { } impl ForgeAPI>, ForgeRepo> { - pub fn init( - restricted: bool, - cwd: PathBuf, - override_model: Option, - override_provider: Option, - ) -> Self { + pub fn init(restricted: bool, cwd: PathBuf) -> Self { let infra = Arc::new(ForgeInfra::new(restricted, cwd)); - let repo = Arc::new(ForgeRepo::new( - infra.clone(), - override_model, - override_provider, - )); + let repo = Arc::new(ForgeRepo::new(infra.clone())); let app = Arc::new(ForgeServices::new(repo.clone())); ForgeAPI::new(app, repo) } diff --git a/crates/forge_config/.forge.toml b/crates/forge_config/.forge.toml new file mode 100644 index 0000000000..678f6b283e --- /dev/null +++ b/crates/forge_config/.forge.toml @@ -0,0 +1,59 @@ +max_search_lines = 1000 +max_search_result_bytes = 10240 +max_fetch_chars = 50000 +max_stdout_prefix_lines = 100 +max_stdout_suffix_lines = 100 +max_stdout_line_chars = 500 +max_line_chars = 2000 +max_read_lines = 2000 +max_file_read_batch_size = 50 +max_file_size_bytes = 104857600 +max_image_size_bytes = 262144 +tool_timeout_secs = 300 +auto_open_dump = false +max_conversations = 100 +max_sem_search_results = 100 +sem_search_top_k = 10 +workspace_server_url = "https://api.forgecode.dev/" +max_extensions = 15 +max_parallel_file_reads = 64 +model_cache_ttl_secs = 604800 +max_requests_per_turn = 100 +max_tool_failure_per_turn = 3 +top_p = 0.8 +top_k = 30 +max_tokens = 20480 + +[retry] +initial_backoff_ms = 200 +min_delay_ms = 1000 +backoff_factor = 2 +max_attempts = 8 +status_codes = [429, 500, 502, 503, 504, 408, 522, 520, 529] +suppress_errors = false + +[http] +connect_timeout_secs = 30 +read_timeout_secs = 900 +pool_idle_timeout_secs = 90 +pool_max_idle_per_host = 5 +max_redirects = 10 +hickory = false +tls_backend = "default" +adaptive_window = true +keep_alive_interval_secs = 60 +keep_alive_timeout_secs = 10 +keep_alive_while_idle = true +accept_invalid_certs = false + +[compact] +max_tokens = 2000 +token_threshold = 100000 +retention_window = 6 +message_threshold = 200 +eviction_window = 0.2 +on_turn_end = false + +[updates] +frequency = "daily" +auto_update = true diff --git a/crates/forge_config/Cargo.toml b/crates/forge_config/Cargo.toml new file mode 100644 index 0000000000..a9b2acfb7a --- /dev/null +++ b/crates/forge_config/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "forge_config" +version = "0.1.0" +edition.workspace = true +rust-version.workspace = true + +[dependencies] +thiserror.workspace = true +config = { version = "0.15", features = ["toml"] } +derive_setters.workspace = true +dirs.workspace = true +dotenvy.workspace = true +serde.workspace = true +serde_json.workspace = true +toml_edit = { workspace = true } +url.workspace = true +fake = { version = "5.1.0", features = ["derive"] } +schemars.workspace = true +merge.workspace = true +tracing.workspace = true + +[dev-dependencies] +pretty_assertions.workspace = true diff --git a/crates/forge_config/src/auto_dump.rs b/crates/forge_config/src/auto_dump.rs new file mode 100644 index 0000000000..a40adf653a --- /dev/null +++ b/crates/forge_config/src/auto_dump.rs @@ -0,0 +1,24 @@ +use serde::{Deserialize, Serialize}; + +/// The output format used when auto-dumping a conversation on task completion. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, fake::Dummy)] +#[serde(rename_all = "snake_case")] +pub enum AutoDumpFormat { + /// Dump as a JSON file + Json, + /// Dump as an HTML file + Html, +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn test_auto_dump_format_variants() { + assert_eq!(AutoDumpFormat::Json, AutoDumpFormat::Json); + assert_eq!(AutoDumpFormat::Html, AutoDumpFormat::Html); + } +} diff --git a/crates/forge_config/src/compact.rs b/crates/forge_config/src/compact.rs new file mode 100644 index 0000000000..4b587856ca --- /dev/null +++ b/crates/forge_config/src/compact.rs @@ -0,0 +1,511 @@ +use std::fmt; +use std::ops::Deref; +use std::time::Duration; + +use derive_setters::Setters; +use schemars::JsonSchema; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +/// A newtype for temperature values with built-in validation +/// +/// Temperature controls the randomness in the model's output: +/// - Lower values (e.g., 0.1) make responses more focused, deterministic, and +/// coherent +/// - Higher values (e.g., 0.8) make responses more creative, diverse, and +/// exploratory +/// - Valid range is 0.0 to 2.0 +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, JsonSchema)] +pub struct Temperature(f32); + +impl Temperature { + /// Creates a new Temperature value, returning an error if outside the valid + /// range (0.0 to 2.0) + pub fn new(value: f32) -> Result { + if Self::is_valid(value) { + Ok(Self(value)) + } else { + Err(format!( + "temperature must be between 0.0 and 2.0, got {value}" + )) + } + } + + /// Creates a new Temperature value without validation + /// + /// # Safety + /// This function should only be used when the value is known to be valid + pub fn new_unchecked(value: f32) -> Self { + debug_assert!(Self::is_valid(value), "invalid temperature: {value}"); + Self(value) + } + + /// Returns true if the temperature value is within the valid range (0.0 to + /// 2.0) + pub fn is_valid(value: f32) -> bool { + (0.0..=2.0).contains(&value) + } + + /// Returns the inner f32 value + pub fn value(&self) -> f32 { + self.0 + } +} + +impl Deref for Temperature { + type Target = f32; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From for f32 { + fn from(temp: Temperature) -> Self { + temp.0 + } +} + +impl From for Temperature { + fn from(value: f32) -> Self { + Temperature::new_unchecked(value) + } +} + +impl fmt::Display for Temperature { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl Serialize for Temperature { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let formatted = format!("{:.1}", self.0); + let value = formatted.parse::().unwrap(); + serializer.serialize_f32(value) + } +} + +impl<'de> Deserialize<'de> for Temperature { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + use serde::de::Error; + let value = f32::deserialize(deserializer)?; + if Self::is_valid(value) { + Ok(Self(value)) + } else { + Err(Error::custom(format!( + "temperature must be between 0.0 and 2.0, got {value}" + ))) + } + } +} + +/// A newtype for top_p values with built-in validation +/// +/// Top-p (nucleus sampling) controls the diversity of the model's output: +/// - Lower values (e.g., 0.1) make responses more focused by considering only +/// the most probable tokens +/// - Higher values (e.g., 0.9) make responses more diverse by considering a +/// broader range of tokens +/// - Valid range is 0.0 to 1.0 +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, JsonSchema)] +pub struct TopP(f32); + +impl TopP { + /// Creates a new TopP value, returning an error if outside the valid + /// range (0.0 to 1.0) + pub fn new(value: f32) -> Result { + if Self::is_valid(value) { + Ok(Self(value)) + } else { + Err(format!("top_p must be between 0.0 and 1.0, got {value}")) + } + } + + /// Creates a new TopP value without validation + /// + /// # Safety + /// This function should only be used when the value is known to be valid + pub fn new_unchecked(value: f32) -> Self { + debug_assert!(Self::is_valid(value), "invalid top_p: {value}"); + Self(value) + } + + /// Returns true if the top_p value is within the valid range (0.0 to 1.0) + pub fn is_valid(value: f32) -> bool { + (0.0..=1.0).contains(&value) + } + + /// Returns the inner f32 value + pub fn value(&self) -> f32 { + self.0 + } +} + +impl Deref for TopP { + type Target = f32; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From for f32 { + fn from(top_p: TopP) -> Self { + top_p.0 + } +} + +impl fmt::Display for TopP { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl Serialize for TopP { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let formatted = format!("{:.2}", self.0); + let value = formatted.parse::().unwrap(); + serializer.serialize_f32(value) + } +} + +impl<'de> Deserialize<'de> for TopP { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + use serde::de::Error; + let value = f32::deserialize(deserializer)?; + if Self::is_valid(value) { + Ok(Self(value)) + } else { + Err(Error::custom(format!( + "top_p must be between 0.0 and 1.0, got {value}" + ))) + } + } +} + +/// A newtype for top_k values with built-in validation +/// +/// Top-k controls the number of highest probability vocabulary tokens to keep: +/// - Lower values (e.g., 10) make responses more focused by considering only +/// the top K most likely tokens +/// - Higher values (e.g., 100) make responses more diverse by considering more +/// token options +/// - Valid range is 1 to 1000 (inclusive) +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, JsonSchema)] +pub struct TopK(u32); + +impl TopK { + /// Creates a new TopK value, returning an error if outside the valid + /// range (1 to 1000) + pub fn new(value: u32) -> Result { + if Self::is_valid(value) { + Ok(Self(value)) + } else { + Err(format!("top_k must be between 1 and 1000, got {value}")) + } + } + + /// Creates a new TopK value without validation + /// + /// # Safety + /// This function should only be used when the value is known to be valid + pub fn new_unchecked(value: u32) -> Self { + debug_assert!(Self::is_valid(value), "invalid top_k: {value}"); + Self(value) + } + + /// Returns true if the top_k value is within the valid range (1 to 1000) + pub fn is_valid(value: u32) -> bool { + (1..=1000).contains(&value) + } + + /// Returns the inner u32 value + pub fn value(&self) -> u32 { + self.0 + } +} + +impl Deref for TopK { + type Target = u32; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From for u32 { + fn from(top_k: TopK) -> Self { + top_k.0 + } +} + +impl fmt::Display for TopK { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl Serialize for TopK { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_u32(self.0) + } +} + +impl<'de> Deserialize<'de> for TopK { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + use serde::de::Error; + let value = u32::deserialize(deserializer)?; + if Self::is_valid(value) { + Ok(Self(value)) + } else { + Err(Error::custom(format!( + "top_k must be between 1 and 1000, got {value}" + ))) + } + } +} + +/// A newtype for max_tokens values with built-in validation +/// +/// Max tokens controls the maximum number of tokens the model can generate: +/// - Lower values (e.g., 100) limit response length for concise outputs +/// - Higher values (e.g., 4000) allow for longer, more detailed responses +/// - Valid range is 1 to 100,000 (reasonable upper bound for most models) +/// - If not specified, the model provider's default will be used +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, JsonSchema)] +pub struct MaxTokens(u32); + +impl MaxTokens { + /// Creates a new MaxTokens value, returning an error if outside the valid + /// range (1 to 100,000) + pub fn new(value: u32) -> Result { + if Self::is_valid(value) { + Ok(Self(value)) + } else { + Err(format!( + "max_tokens must be between 1 and 100000, got {value}" + )) + } + } + + /// Creates a new MaxTokens value without validation + /// + /// # Safety + /// This function should only be used when the value is known to be valid + pub fn new_unchecked(value: u32) -> Self { + debug_assert!(Self::is_valid(value), "invalid max_tokens: {value}"); + Self(value) + } + + /// Returns true if the max_tokens value is within the valid range (1 to + /// 100,000) + pub fn is_valid(value: u32) -> bool { + (1..=100_000).contains(&value) + } + + /// Returns the inner u32 value + pub fn value(&self) -> u32 { + self.0 + } +} + +impl Deref for MaxTokens { + type Target = u32; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From for u32 { + fn from(max_tokens: MaxTokens) -> Self { + max_tokens.0 + } +} + +impl fmt::Display for MaxTokens { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl Serialize for MaxTokens { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_u32(self.0) + } +} + +impl<'de> Deserialize<'de> for MaxTokens { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + use serde::de::Error; + let value = u32::deserialize(deserializer)?; + if Self::is_valid(value) { + Ok(Self(value)) + } else { + Err(Error::custom(format!( + "max_tokens must be between 1 and 100000, got {value}" + ))) + } + } +} + +/// Frequency at which forge checks for updates +#[derive(Default, Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum UpdateFrequency { + Daily, + Weekly, + #[default] + Always, +} + +impl From for Duration { + fn from(val: UpdateFrequency) -> Self { + match val { + UpdateFrequency::Daily => Duration::from_secs(60 * 60 * 24), + UpdateFrequency::Weekly => Duration::from_secs(60 * 60 * 24 * 7), + UpdateFrequency::Always => Duration::ZERO, + } + } +} + +/// Configuration for automatic forge updates +#[derive(Debug, Clone, Serialize, Deserialize, Default, JsonSchema, Setters, PartialEq)] +#[setters(strip_option, into)] +pub struct Update { + /// How frequently forge checks for updates + pub frequency: Option, + /// Whether to automatically install updates without prompting + pub auto_update: Option, +} + +fn deserialize_percentage<'de, D>(deserializer: D) -> Result +where + D: serde::Deserializer<'de>, +{ + use serde::de::Error; + + let value = f64::deserialize(deserializer)?; + if !(0.0..=1.0).contains(&value) { + return Err(Error::custom(format!( + "percentage must be between 0.0 and 1.0, got {value}" + ))); + } + Ok(value) +} + +/// Optional tag name used when extracting summarized content during compaction +#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema, PartialEq)] +#[serde(transparent)] +pub struct SummaryTag(String); + +impl Default for SummaryTag { + fn default() -> Self { + SummaryTag("forge_context_summary".to_string()) + } +} + +impl SummaryTag { + /// Returns the inner string slice + pub fn as_str(&self) -> &str { + self.0.as_str() + } +} + +/// Configuration for automatic context compaction for all agents +#[derive(Debug, Clone, Serialize, Deserialize, Setters, JsonSchema, PartialEq)] +#[setters(strip_option, into)] +pub struct Compact { + /// Number of most recent messages to preserve during compaction. + /// These messages won't be considered for summarization. Works alongside + /// eviction_window - the more conservative limit (fewer messages to + /// compact) takes precedence. + #[serde(default)] + pub retention_window: usize, + + /// Maximum percentage of the context that can be summarized during + /// compaction. Valid values are between 0.0 and 1.0, where 0.0 means no + /// compaction and 1.0 allows summarizing all messages. Works alongside + /// retention_window - the more conservative limit (fewer messages to + /// compact) takes precedence. + #[serde(default, deserialize_with = "deserialize_percentage")] + pub eviction_window: f64, + + /// Maximum number of tokens to keep after compaction + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + + /// Maximum number of tokens before triggering compaction + #[serde(skip_serializing_if = "Option::is_none")] + pub token_threshold: Option, + + /// Maximum number of conversation turns before triggering compaction + #[serde(skip_serializing_if = "Option::is_none")] + pub turn_threshold: Option, + + /// Maximum number of messages before triggering compaction + #[serde(skip_serializing_if = "Option::is_none")] + pub message_threshold: Option, + + /// Model ID to use for compaction, useful when compacting with a + /// cheaper/faster model. If not specified, the root level model will be + /// used. + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + + /// Optional tag name to extract content from when summarizing (e.g., + /// "summary") + #[serde(skip_serializing_if = "Option::is_none")] + pub summary_tag: Option, + + /// Whether to trigger compaction when the last message is from a user + #[serde(default, skip_serializing_if = "Option::is_none")] + pub on_turn_end: Option, +} + +impl Default for Compact { + fn default() -> Self { + Self::new() + } +} + +impl Compact { + /// Creates a new compaction configuration with all optional fields unset + pub fn new() -> Self { + Self { + max_tokens: None, + token_threshold: None, + turn_threshold: None, + message_threshold: None, + summary_tag: None, + model: None, + eviction_window: 0.2, + retention_window: 0, + on_turn_end: None, + } + } +} diff --git a/crates/forge_config/src/config.rs b/crates/forge_config/src/config.rs new file mode 100644 index 0000000000..37862df654 --- /dev/null +++ b/crates/forge_config/src/config.rs @@ -0,0 +1,166 @@ +use std::path::PathBuf; + +use derive_setters::Setters; +use serde::{Deserialize, Serialize}; +use url::Url; + +use crate::reader::ConfigReader; +use crate::writer::ConfigWriter; +use crate::{ + AutoDumpFormat, Compact, HttpConfig, MaxTokens, ModelConfig, RetryConfig, Temperature, TopK, + TopP, Update, +}; + +/// Top-level Forge configuration merged from all sources (defaults, file, +/// environment). +#[derive(Default, Debug, Setters, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +#[setters(strip_option)] +pub struct ForgeConfig { + /// Configuration for the retry mechanism + pub retry: Option, + /// The maximum number of lines returned for FSSearch + pub max_search_lines: usize, + /// Maximum bytes allowed for search results + pub max_search_result_bytes: usize, + /// Maximum characters for fetch content + pub max_fetch_chars: usize, + /// Maximum lines for shell output prefix + pub max_stdout_prefix_lines: usize, + /// Maximum lines for shell output suffix + pub max_stdout_suffix_lines: usize, + /// Maximum characters per line for shell output + pub max_stdout_line_chars: usize, + /// Maximum characters per line for file read operations + pub max_line_chars: usize, + /// Maximum number of lines to read from a file + pub max_read_lines: u64, + /// Maximum number of files that can be read in a single batch operation + pub max_file_read_batch_size: usize, + /// HTTP configuration + pub http: Option, + /// Maximum file size in bytes for operations + pub max_file_size_bytes: u64, + /// Maximum image file size in bytes for binary read operations + pub max_image_size_bytes: u64, + /// Maximum execution time in seconds for a single tool call + pub tool_timeout_secs: u64, + /// Whether to automatically open HTML dump files in the browser + pub auto_open_dump: bool, + /// Path where debug request files should be written + pub debug_requests: Option, + /// Custom history file path + pub custom_history_path: Option, + /// Maximum number of conversations to show in list + pub max_conversations: usize, + /// Maximum number of results to return from initial vector search + pub max_sem_search_results: usize, + /// Top-k parameter for relevance filtering during semantic search + pub sem_search_top_k: usize, + /// URL for the indexing server + pub workspace_server_url: Option, + /// Maximum number of file extensions to include in the system prompt + pub max_extensions: usize, + /// Format for automatically creating a dump when a task is completed + pub auto_dump: Option, + /// Maximum number of files read concurrently in parallel operations + pub max_parallel_file_reads: usize, + /// TTL in seconds for the model API list cache + pub model_cache_ttl_secs: u64, + /// Default model and provider configuration used when not overridden by + /// individual agents. + #[serde(default)] + pub session: Option, + /// Provider and model to use for commit message generation + #[serde(default)] + pub commit: Option, + /// Provider and model to use for shell command suggestion generation + #[serde(default)] + pub suggest: Option, + /// API key for Forge authentication + #[serde(default)] + pub api_key: Option, + /// Display name of the API key + #[serde(default)] + pub api_key_name: Option, + /// Masked representation of the API key for display purposes + #[serde(default)] + pub api_key_masked: Option, + /// Email address associated with the Forge account + #[serde(default)] + pub email: Option, + /// Display name of the authenticated user + #[serde(default)] + pub name: Option, + /// Identifier of the authentication provider used for login + #[serde(default)] + pub auth_provider_id: Option, + + // --- Workflow fields --- + /// Configuration for automatic forge updates + #[serde(skip_serializing_if = "Option::is_none")] + pub updates: Option, + + /// Output randomness for all agents; lower values are deterministic, higher + /// values are creative (0.0–2.0). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub temperature: Option, + + /// Nucleus sampling threshold for all agents; limits token selection to the + /// top cumulative probability mass (0.0–1.0). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub top_p: Option, + + /// Top-k vocabulary cutoff for all agents; restricts sampling to the k + /// highest-probability tokens (1–1000). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub top_k: Option, + + /// Maximum tokens the model may generate per response for all agents + /// (1–100,000). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + + /// Maximum tool failures per turn before the orchestrator forces + /// completion. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_tool_failure_per_turn: Option, + + /// Maximum number of requests that can be made in a single turn. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_requests_per_turn: Option, + + /// Context compaction settings applied to all agents; falls back to each + /// agent's individual setting when absent. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub compact: Option, +} + +impl ForgeConfig { + /// Reads and merges configuration from all sources, returning the resolved + /// [`ForgeConfig`]. + /// + /// # Errors + /// + /// Returns an error if the config path cannot be resolved, the file cannot + /// be read, or deserialization fails. + pub fn read() -> crate::Result { + ConfigReader::default() + .read_defaults() + .read_legacy() + .read_global() + .read_env() + .build() + } + + /// Writes the configuration to the user config file. + /// + /// # Errors + /// + /// Returns an error if the configuration cannot be serialized or written to + /// disk. + pub fn write(&self) -> crate::Result<()> { + let path = ConfigReader::config_path(); + ConfigWriter::new(self.clone()).write(&path) + } +} diff --git a/crates/forge_config/src/error.rs b/crates/forge_config/src/error.rs new file mode 100644 index 0000000000..ed1b867b70 --- /dev/null +++ b/crates/forge_config/src/error.rs @@ -0,0 +1,18 @@ +/// Errors produced by the `forge_config` crate. +#[derive(Debug, thiserror::Error)] +pub enum Error { + /// Failed to read or parse configuration from a file or environment. + #[error("Config error: {0}")] + Config(#[from] config::ConfigError), + + /// Failed to serialize or write configuration. + #[error("Serialization error: {0}")] + Serialization(#[from] toml_edit::ser::Error), + + /// An I/O error occurred while reading or writing configuration files. + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + #[error("JSON error: {0}")] + Serde(#[from] serde_json::Error), +} diff --git a/crates/forge_config/src/http.rs b/crates/forge_config/src/http.rs new file mode 100644 index 0000000000..e25916d458 --- /dev/null +++ b/crates/forge_config/src/http.rs @@ -0,0 +1,94 @@ +use serde::{Deserialize, Serialize}; + +/// TLS version enum for configuring TLS protocol versions. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, fake::Dummy)] +#[serde(rename_all = "snake_case")] +pub enum TlsVersion { + #[serde(rename = "1.0")] + V1_0, + #[serde(rename = "1.1")] + V1_1, + #[serde(rename = "1.2")] + V1_2, + #[serde(rename = "1.3")] + V1_3, +} + +/// TLS backend option. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, fake::Dummy)] +#[serde(rename_all = "snake_case")] +pub enum TlsBackend { + #[serde(rename = "default")] + Default, + #[serde(rename = "rustls")] + Rustls, +} + +/// HTTP client configuration. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, fake::Dummy)] +#[serde(rename_all = "snake_case")] +pub struct HttpConfig { + pub connect_timeout_secs: u64, + pub read_timeout_secs: u64, + pub pool_idle_timeout_secs: u64, + pub pool_max_idle_per_host: usize, + pub max_redirects: usize, + pub hickory: bool, + pub tls_backend: TlsBackend, + /// Minimum TLS protocol version to use + pub min_tls_version: Option, + /// Maximum TLS protocol version to use + pub max_tls_version: Option, + /// Adaptive window sizing for improved flow control + pub adaptive_window: bool, + /// Keep-alive interval in seconds + pub keep_alive_interval_secs: Option, + /// Keep-alive timeout in seconds + pub keep_alive_timeout_secs: u64, + /// Keep-alive while connection is idle + pub keep_alive_while_idle: bool, + /// Accept invalid certificates + pub accept_invalid_certs: bool, + /// Paths to root certificate files + pub root_cert_paths: Option>, +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn test_http_config_fields() { + let config = HttpConfig { + connect_timeout_secs: 30, + read_timeout_secs: 900, + pool_idle_timeout_secs: 90, + pool_max_idle_per_host: 5, + max_redirects: 10, + hickory: false, + tls_backend: TlsBackend::Default, + min_tls_version: None, + max_tls_version: None, + adaptive_window: true, + keep_alive_interval_secs: Some(60), + keep_alive_timeout_secs: 10, + keep_alive_while_idle: true, + accept_invalid_certs: false, + root_cert_paths: None, + }; + assert_eq!(config.connect_timeout_secs, 30); + assert_eq!(config.adaptive_window, true); + } + + #[test] + fn test_tls_version_variants() { + assert_eq!(TlsVersion::V1_3, TlsVersion::V1_3); + } + + #[test] + fn test_tls_backend_variants() { + assert_eq!(TlsBackend::Default, TlsBackend::Default); + } +} diff --git a/crates/forge_config/src/legacy.rs b/crates/forge_config/src/legacy.rs new file mode 100644 index 0000000000..7310333814 --- /dev/null +++ b/crates/forge_config/src/legacy.rs @@ -0,0 +1,71 @@ +use std::collections::HashMap; +use std::path::PathBuf; + +use serde::Deserialize; + +use crate::{ForgeConfig, ModelConfig}; + +/// Intermediate representation of the legacy `~/forge/.config.json` format. +/// +/// This format stores the active provider as a top-level string and models as +/// a map from provider ID to model ID, which differs from the TOML config's +/// nested `session`, `commit`, and `suggest` sub-objects. +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct LegacyConfig { + /// The active provider ID (e.g. `"anthropic"`). + #[serde(default)] + provider: Option, + /// Map from provider ID to the model ID to use with that provider. + #[serde(default)] + model: HashMap, + /// Commit message generation provider/model pair. + #[serde(default)] + commit: Option, + /// Shell command suggestion provider/model pair. + #[serde(default)] + suggest: Option, +} + +/// A provider/model pair as expressed in the legacy JSON config. +#[derive(Debug, Deserialize)] +struct LegacyModelRef { + provider: Option, + model: Option, +} + +impl LegacyConfig { + /// Reads the legacy `~/forge/.config.json` file at `path`, parses it, and + /// returns the equivalent TOML representation as a [`String`]. + /// + /// # Errors + /// + /// Returns an error if the file cannot be read, the JSON is invalid, or the + /// resulting config cannot be serialized to TOML. + pub(crate) fn read(path: &PathBuf) -> crate::Result { + let contents = std::fs::read_to_string(path)?; + let config = serde_json::from_str::(&contents)?; + let forge_config = config.into_forge_config(); + let content = toml_edit::ser::to_string_pretty(&forge_config)?; + Ok(content) + } + + /// Converts a [`LegacyConfig`] into the fields of [`ForgeConfig`] that it + /// covers, leaving all other fields at their defaults. + fn into_forge_config(self) -> ForgeConfig { + let session = self.provider.as_deref().map(|provider_id| { + let model_id = self.model.get(provider_id).cloned(); + ModelConfig { provider_id: Some(provider_id.to_string()), model_id } + }); + + let commit = self + .commit + .map(|c| ModelConfig { provider_id: c.provider, model_id: c.model }); + + let suggest = self + .suggest + .map(|s| ModelConfig { provider_id: s.provider, model_id: s.model }); + + ForgeConfig { session, commit, suggest, ..Default::default() } + } +} diff --git a/crates/forge_config/src/lib.rs b/crates/forge_config/src/lib.rs new file mode 100644 index 0000000000..b0ba37a4b3 --- /dev/null +++ b/crates/forge_config/src/lib.rs @@ -0,0 +1,23 @@ +mod auto_dump; +mod compact; +mod config; +mod error; +mod http; +mod legacy; +mod model; +mod reader; +mod retry; +mod writer; + +pub use auto_dump::*; +pub use compact::*; +pub use config::*; +pub use error::Error; +pub use http::*; +pub use model::*; +pub use reader::*; +pub use retry::*; +pub use writer::*; + +/// A `Result` type alias for this crate's [`Error`] type. +pub type Result = std::result::Result; diff --git a/crates/forge_config/src/model.rs b/crates/forge_config/src/model.rs new file mode 100644 index 0000000000..7097759003 --- /dev/null +++ b/crates/forge_config/src/model.rs @@ -0,0 +1,18 @@ +use derive_setters::Setters; +use serde::{Deserialize, Serialize}; + +/// A type alias for a provider identifier string. +pub type ProviderId = String; + +/// A type alias for a model identifier string. +pub type ModelId = String; + +/// Pairs a provider and model together for a specific operation. +#[derive(Default, Debug, Setters, Clone, PartialEq, Serialize, Deserialize, fake::Dummy)] +#[setters(strip_option, into)] +pub struct ModelConfig { + /// The provider to use for this operation. + pub provider_id: Option, + /// The model to use for this operation. + pub model_id: Option, +} diff --git a/crates/forge_config/src/reader.rs b/crates/forge_config/src/reader.rs new file mode 100644 index 0000000000..786496bc08 --- /dev/null +++ b/crates/forge_config/src/reader.rs @@ -0,0 +1,190 @@ +use std::path::PathBuf; +use std::sync::LazyLock; + +use config::ConfigBuilder; +use config::builder::DefaultState; + +use crate::ForgeConfig; +use crate::legacy::LegacyConfig; + +/// Loads all `.env` files found while walking up from the current working +/// directory to the root, with priority given to closer (lower) directories. +/// Executed at most once per process. +static LOAD_DOT_ENV: LazyLock<()> = LazyLock::new(|| { + let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")); + let mut paths = vec![]; + let mut current = PathBuf::new(); + + for component in cwd.components() { + current.push(component); + paths.push(current.clone()); + } + + paths.reverse(); + + for path in paths { + let env_file = path.join(".env"); + if env_file.is_file() { + dotenvy::from_path(&env_file).ok(); + } + } +}); + +/// Merges [`ForgeConfig`] from layered sources using a builder pattern. +#[derive(Default)] +pub struct ConfigReader { + builder: ConfigBuilder, +} + +impl ConfigReader { + /// Returns the path to the legacy JSON config file + /// (`~/.forge/.config.json`). + pub fn config_legacy_path() -> PathBuf { + Self::base_path().join(".config.json") + } + + /// Returns the path to the primary TOML config file + /// (`~/.forge/.forge.toml`). + pub fn config_path() -> PathBuf { + Self::base_path().join(".forge.toml") + } + + /// Returns the base directory for all Forge config files (`~/forge`). + pub fn base_path() -> PathBuf { + dirs::home_dir().unwrap_or(PathBuf::from(".")).join("forge") + } + + /// Adds the provided TOML string as a config source without touching the + /// filesystem. + pub fn read_toml(mut self, contents: &str) -> Self { + self.builder = self + .builder + .add_source(config::File::from_str(contents, config::FileFormat::Toml)); + + self + } + + /// Adds the embedded default config (`../.forge.toml`) as a source. + pub fn read_defaults(self) -> Self { + let defaults = include_str!("../.forge.toml"); + + self.read_toml(defaults) + } + + /// Adds `FORGE_`-prefixed environment variables as a config source. + pub fn read_env(mut self) -> Self { + self.builder = self.builder.add_source( + config::Environment::with_prefix("FORGE") + .prefix_separator("_") + .separator("__") + .try_parsing(true) + .list_separator(",") + .with_list_parse_key("retry.status_codes") + .with_list_parse_key("http.root_cert_paths"), + ); + + self + } + + /// Builds and deserializes all accumulated sources into a [`ForgeConfig`]. + /// + /// Triggers `.env` file loading (at most once per process) by walking up + /// the directory tree from the current working directory, with closer + /// directories taking priority. + /// + /// # Errors + /// + /// Returns an error if the configuration cannot be built or deserialized. + pub fn build(self) -> crate::Result { + *LOAD_DOT_ENV; + let config = self.builder.build()?; + Ok(config.try_deserialize::()?) + } + + /// Adds `~/.forge/.forge.toml` as a config source, silently skipping if + /// absent. + pub fn read_global(mut self) -> Self { + let path = Self::config_path(); + self.builder = self.builder.add_source(config::File::from(path)); + self + } + + /// Reads `~/.forge/.config.json` (legacy format) and adds it as a source, + /// silently skipping errors. + pub fn read_legacy(self) -> Self { + let content = LegacyConfig::read(&Self::config_legacy_path()); + if let Ok(content) = content { + self.read_toml(&content) + } else { + self + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::{Mutex, MutexGuard}; + + use pretty_assertions::assert_eq; + + use super::*; + use crate::ModelConfig; + + /// Serializes tests that mutate environment variables to prevent races. + static ENV_MUTEX: Mutex<()> = Mutex::new(()); + + /// Holds env vars set for a test's duration and removes them on drop, while + /// holding [`ENV_MUTEX`]. + struct EnvGuard { + keys: Vec<&'static str>, + _lock: MutexGuard<'static, ()>, + } + + impl EnvGuard { + /// Sets each `(key, value)` pair in the environment, returning a guard + /// that cleans them up on drop. + #[must_use] + fn set(pairs: &[(&'static str, &str)]) -> Self { + let lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner()); + let keys = pairs.iter().map(|(k, _)| *k).collect(); + for (key, value) in pairs { + unsafe { std::env::set_var(key, value) }; + } + Self { keys, _lock: lock } + } + } + + impl Drop for EnvGuard { + fn drop(&mut self) { + for key in &self.keys { + unsafe { std::env::remove_var(key) }; + } + } + } + + #[test] + fn test_read_parses_without_error() { + let actual = ConfigReader::default().read_defaults().build(); + assert!(actual.is_ok(), "read() failed: {:?}", actual.err()); + } + + #[test] + fn test_read_session_from_env_vars() { + let _guard = EnvGuard::set(&[ + ("FORGE_SESSION__PROVIDER_ID", "fake-provider"), + ("FORGE_SESSION__MODEL_ID", "fake-model"), + ]); + + let actual = ConfigReader::default() + .read_defaults() + .read_env() + .build() + .unwrap(); + + let expected = Some(ModelConfig { + provider_id: Some("fake-provider".to_string()), + model_id: Some("fake-model".to_string()), + }); + assert_eq!(actual.session, expected); + } +} diff --git a/crates/forge_config/src/retry.rs b/crates/forge_config/src/retry.rs new file mode 100644 index 0000000000..c5fc8fa6b6 --- /dev/null +++ b/crates/forge_config/src/retry.rs @@ -0,0 +1,43 @@ +use serde::{Deserialize, Serialize}; + +/// Configuration for retry mechanism. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, fake::Dummy)] +#[serde(rename_all = "snake_case")] +pub struct RetryConfig { + /// Initial backoff delay in milliseconds for retry operations + pub initial_backoff_ms: u64, + /// Minimum delay in milliseconds between retry attempts + pub min_delay_ms: u64, + /// Backoff multiplication factor for each retry attempt + pub backoff_factor: u64, + /// Maximum number of retry attempts + pub max_attempts: usize, + /// HTTP status codes that should trigger retries + pub status_codes: Vec, + /// Maximum delay between retries in seconds + pub max_delay_secs: Option, + /// Whether to suppress retry error logging and events + pub suppress_errors: bool, +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn test_retry_config_fields() { + let config = RetryConfig { + initial_backoff_ms: 200, + min_delay_ms: 1000, + backoff_factor: 2, + max_attempts: 8, + status_codes: vec![429, 500, 502, 503, 504, 408, 522, 520, 529], + max_delay_secs: None, + suppress_errors: false, + }; + assert_eq!(config.initial_backoff_ms, 200); + assert_eq!(config.suppress_errors, false); + } +} diff --git a/crates/forge_config/src/writer.rs b/crates/forge_config/src/writer.rs new file mode 100644 index 0000000000..e02ba43795 --- /dev/null +++ b/crates/forge_config/src/writer.rs @@ -0,0 +1,34 @@ +use std::path::Path; + +use crate::ForgeConfig; + +/// Writes a [`ForgeConfig`] to the user configuration file on disk. +pub struct ConfigWriter { + config: ForgeConfig, +} + +impl ConfigWriter { + /// Creates a new `ConfigWriter` for the given configuration. + pub fn new(config: ForgeConfig) -> Self { + Self { config } + } + + /// Serializes and writes the configuration to `path`, creating all parent + /// directories recursively if they do not already exist. + /// + /// # Errors + /// + /// Returns an error if the configuration cannot be serialized or the file + /// cannot be written. + pub fn write(&self, path: &Path) -> crate::Result<()> { + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent)?; + } + + let contents = toml_edit::ser::to_string_pretty(&self.config)?; + + std::fs::write(path, contents)?; + + Ok(()) + } +} diff --git a/crates/forge_domain/src/app_config.rs b/crates/forge_domain/src/app_config.rs index 9f1fc3894c..886df2730c 100644 --- a/crates/forge_domain/src/app_config.rs +++ b/crates/forge_domain/src/app_config.rs @@ -13,17 +13,12 @@ pub struct InitAuth { pub token: String, } -#[derive(Default, Clone, Serialize, Deserialize, Debug, PartialEq)] -#[serde(rename_all = "camelCase")] +#[derive(Default, Clone, Debug, PartialEq)] pub struct AppConfig { pub key_info: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] pub provider: Option, - #[serde(default, skip_serializing_if = "HashMap::is_empty")] pub model: HashMap, - #[serde(default, skip_serializing_if = "Option::is_none")] pub commit: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] pub suggest: Option, } @@ -40,3 +35,22 @@ pub struct LoginInfo { #[serde(default, skip_serializing_if = "Option::is_none")] pub auth_provider_id: Option, } + +/// All discrete mutations that can be applied to an [`AppConfig`]. +/// +/// Instead of replacing the entire config, callers describe exactly which field +/// they want to change. Implementations receive a list of operations, apply +/// each in order, and persist the result atomically. +#[derive(Debug, Clone, PartialEq)] +pub enum AppConfigOperation { + /// Set or clear the authentication token. + KeyInfo(Option), + /// Set the active provider. + SetProvider(ProviderId), + /// Set the model for the given provider. + SetModel(ProviderId, ModelId), + /// Set the commit-message generation configuration. + SetCommitConfig(CommitConfig), + /// Set the shell-command suggestion configuration. + SetSuggestConfig(SuggestConfig), +} diff --git a/crates/forge_domain/src/repo.rs b/crates/forge_domain/src/repo.rs index e3602f71ce..2f68b00beb 100644 --- a/crates/forge_domain/src/repo.rs +++ b/crates/forge_domain/src/repo.rs @@ -4,9 +4,9 @@ use anyhow::Result; use url::Url; use crate::{ - AnyProvider, AppConfig, AuthCredential, ChatCompletionMessage, Context, Conversation, - ConversationId, MigrationResult, Model, ModelId, Provider, ProviderId, ProviderTemplate, - ResultStream, SearchMatch, Skill, Snapshot, WorkspaceAuth, WorkspaceId, + AnyProvider, AppConfig, AppConfigOperation, AuthCredential, ChatCompletionMessage, Context, + Conversation, ConversationId, MigrationResult, Model, ModelId, Provider, ProviderId, + ProviderTemplate, ResultStream, SearchMatch, Skill, Snapshot, WorkspaceAuth, WorkspaceId, }; /// Repository for managing file snapshots @@ -91,8 +91,21 @@ pub trait ConversationRepository: Send + Sync { #[async_trait::async_trait] pub trait AppConfigRepository: Send + Sync { + /// Retrieves the current application configuration. + /// + /// # Errors + /// Returns an error if the configuration cannot be read. async fn get_app_config(&self) -> anyhow::Result; - async fn set_app_config(&self, config: &AppConfig) -> anyhow::Result<()>; + + /// Applies a list of configuration operations to the persisted config. + /// + /// Implementations should load the current config, apply each operation in + /// order via [`AppConfigOperation::apply`], and persist the result + /// atomically. + /// + /// # Errors + /// Returns an error if the configuration cannot be read or written. + async fn update_app_config(&self, ops: Vec) -> anyhow::Result<()>; } #[async_trait::async_trait] diff --git a/crates/forge_main/src/built_in_commands.json b/crates/forge_main/src/built_in_commands.json index 8584b2cd85..b57f5b0dc8 100644 --- a/crates/forge_main/src/built_in_commands.json +++ b/crates/forge_main/src/built_in_commands.json @@ -35,6 +35,10 @@ "command": "config", "description": "List current configuration values" }, + { + "command": "config-edit", + "description": "Open the global forge config file (~/forge/.forge.toml) in an editor [alias: ce]" + }, { "command": "new", "description": "Start new conversation [alias: n]" diff --git a/crates/forge_main/src/cli.rs b/crates/forge_main/src/cli.rs index c3f93a1db0..13d6845afc 100644 --- a/crates/forge_main/src/cli.rs +++ b/crates/forge_main/src/cli.rs @@ -63,22 +63,6 @@ pub struct Cli { #[arg(long, alias = "aid")] pub agent: Option, - /// Override the model to use for this session. - /// - /// When provided, uses this model instead of the configured default. - /// This is a runtime override and does not change the permanent - /// configuration. - #[arg(long)] - pub model: Option, - - /// Override the provider to use for this session. - /// - /// When provided, uses this provider instead of the configured default. - /// This is a runtime override and does not change the permanent - /// configuration. - #[arg(long)] - pub provider: Option, - /// Top-level subcommands. #[command(subcommand)] pub subcommands: Option, diff --git a/crates/forge_main/src/main.rs b/crates/forge_main/src/main.rs index c9b14c9d1c..02967526ee 100644 --- a/crates/forge_main/src/main.rs +++ b/crates/forge_main/src/main.rs @@ -60,16 +60,7 @@ async fn main() -> Result<()> { // Initialize the ForgeAPI with the restricted mode if specified let restricted = cli.restricted; - let cli_model = cli.model.clone(); - let cli_provider = cli.provider.clone(); - let mut ui = UI::init(cli, move || { - ForgeAPI::init( - restricted, - cwd.clone(), - cli_model.clone(), - cli_provider.clone(), - ) - })?; + let mut ui = UI::init(cli, move || ForgeAPI::init(restricted, cwd.clone()))?; ui.run().await; Ok(()) diff --git a/crates/forge_repo/Cargo.toml b/crates/forge_repo/Cargo.toml index ed9cc726a9..0e38fd986b 100644 --- a/crates/forge_repo/Cargo.toml +++ b/crates/forge_repo/Cargo.toml @@ -6,6 +6,7 @@ rust-version.workspace = true [dependencies] forge_app.workspace = true +forge_config.workspace = true forge_domain.workspace = true forge_infra.workspace = true forge_snaps.workspace = true diff --git a/crates/forge_repo/src/app_config.rs b/crates/forge_repo/src/app_config.rs index 65e9367a9a..44b67ba017 100644 --- a/crates/forge_repo/src/app_config.rs +++ b/crates/forge_repo/src/app_config.rs @@ -1,152 +1,181 @@ use std::sync::Arc; -use anyhow::bail; -use bytes::Bytes; -use forge_app::{EnvironmentInfra, FileReaderInfra, FileWriterInfra}; -use forge_domain::{AppConfig, AppConfigRepository, ModelId, ProviderId}; +use forge_config::{ConfigReader, ForgeConfig, ModelConfig}; +use forge_domain::{ + AppConfig, AppConfigOperation, AppConfigRepository, CommitConfig, LoginInfo, ModelId, + ProviderId, SuggestConfig, +}; use tokio::sync::Mutex; +use tracing::{debug, error}; -/// Repository for managing application configuration with caching support. +/// Converts a [`ForgeConfig`] into an [`AppConfig`]. /// -/// This repository uses infrastructure traits for file I/O operations and -/// maintains an in-memory cache to reduce file system access. The configuration -/// file path is automatically inferred from the environment. -#[derive(derive_setters::Setters)] -#[setters(into)] -pub struct AppConfigRepositoryImpl { - infra: Arc, - cache: Arc>>, - override_model: Option, - override_provider: Option, +/// `ForgeConfig` flattens login info as top-level fields and represents the +/// active model as a single [`ModelConfig`]. This conversion reconstructs the +/// nested [`LoginInfo`] and per-provider model map used by the domain. +fn forge_config_to_app_config(fc: ForgeConfig) -> AppConfig { + let key_info = fc.api_key.map(|api_key| LoginInfo { + api_key, + api_key_name: fc.api_key_name.unwrap_or_default(), + api_key_masked: fc.api_key_masked.unwrap_or_default(), + email: fc.email, + name: fc.name, + auth_provider_id: fc.auth_provider_id, + }); + + let (provider, model) = match fc.session { + Some(mc) => { + let provider_id = mc.provider_id.map(ProviderId::from); + let mut map = std::collections::HashMap::new(); + if let (Some(ref pid), Some(mid)) = (provider_id.clone(), mc.model_id.map(ModelId::new)) + { + map.insert(pid.clone(), mid); + } + (provider_id, map) + } + None => (None, std::collections::HashMap::new()), + }; + + let commit = fc.commit.map(|mc| CommitConfig { + provider: mc.provider_id.map(ProviderId::from), + model: mc.model_id.map(ModelId::new), + }); + + let suggest = fc.suggest.and_then(|mc| { + mc.provider_id + .zip(mc.model_id) + .map(|(pid, mid)| SuggestConfig { + provider: ProviderId::from(pid), + model: ModelId::new(mid), + }) + }); + + AppConfig { key_info, provider, model, commit, suggest } } -impl AppConfigRepositoryImpl { - pub fn new(infra: Arc) -> Self { - Self { - infra, - cache: Arc::new(Mutex::new(None)), - override_model: None, - override_provider: None, +/// Applies a single [`AppConfigOperation`] directly onto a [`ForgeConfig`] +/// in-place, bypassing the intermediate [`AppConfig`] representation. +fn apply_op(op: AppConfigOperation, fc: &mut ForgeConfig) { + match op { + AppConfigOperation::KeyInfo(Some(info)) => { + fc.api_key = Some(info.api_key); + fc.api_key_name = Some(info.api_key_name); + fc.api_key_masked = Some(info.api_key_masked); + fc.email = info.email; + fc.name = info.name; + fc.auth_provider_id = info.auth_provider_id; + } + AppConfigOperation::KeyInfo(None) => { + fc.api_key = None; + fc.api_key_name = None; + fc.api_key_masked = None; + fc.email = None; + fc.name = None; + fc.auth_provider_id = None; + } + AppConfigOperation::SetProvider(provider_id) => { + let pid = provider_id.as_ref().to_string(); + fc.session = Some(match fc.session.take() { + Some(mc) => mc.provider_id(pid), + None => ModelConfig::default().provider_id(pid), + }); + } + AppConfigOperation::SetModel(provider_id, model_id) => { + let pid = provider_id.as_ref().to_string(); + let mid = model_id.to_string(); + fc.session = Some(match fc.session.take() { + Some(mc) if mc.provider_id.as_deref() == Some(&pid) => mc.model_id(mid), + _ => ModelConfig::default().provider_id(pid).model_id(mid), + }); + } + AppConfigOperation::SetCommitConfig(commit) => { + fc.commit = commit + .provider + .as_ref() + .zip(commit.model.as_ref()) + .map(|(pid, mid)| { + ModelConfig::default() + .provider_id(pid.as_ref().to_string()) + .model_id(mid.to_string()) + }); + } + AppConfigOperation::SetSuggestConfig(suggest) => { + fc.suggest = Some( + ModelConfig::default() + .provider_id(suggest.provider.as_ref().to_string()) + .model_id(suggest.model.to_string()), + ); } } } -impl AppConfigRepositoryImpl { - /// Reads configuration from the JSON file with fallback strategies: - async fn read(&self) -> AppConfig { - let path = self.infra.get_environment().app_config(); - let content = match self.infra.read_utf8(&path).await { - Ok(content) => content, - Err(e) => { - tracing::error!( - path = %path.display(), - error = %e, - "Failed to read config file. Using default config." - ); - return AppConfig::default(); - } - }; - - // Strategy 1: Try normal parsing - serde_json::from_str::(&content) - .or_else(|_| { - // Strategy 2: Try JSON repair for syntactically broken JSON - tracing::warn!(path = %path.display(), "Failed to parse config file, attempting repair..."); - forge_json_repair::json_repair::(&content).inspect(|_| { - tracing::info!(path = %path.display(), "Successfully repaired config file"); - }) - }) - .inspect_err(|e| { - tracing::error!( - path = %path.display(), - error = %e, - "Failed to repair config file. Using default config." - ); - }) - .unwrap_or_default() - } - - async fn write(&self, config: &AppConfig) -> anyhow::Result<()> { - let path = self.infra.get_environment().app_config(); - let content = serde_json::to_string_pretty(config)?; - self.infra.write(&path, Bytes::from(content)).await?; - Ok(()) - } +/// Repository for managing application configuration with caching support. +/// +/// Uses [`ForgeConfig::read`] and [`ForgeConfig::write`] for all file I/O and +/// maintains an in-memory cache to reduce disk access. +pub struct ForgeConfigRepository { + cache: Arc>>, +} - fn get_overrides(&self) -> (Option, Option) { - (self.override_model.clone(), self.override_provider.clone()) +impl ForgeConfigRepository { + pub fn new() -> Self { + Self { cache: Arc::new(Mutex::new(None)) } } - fn apply_overrides(&self, mut config: AppConfig) -> AppConfig { - let (model, provider) = self.get_overrides(); - - // Override the default provider first - if let Some(ref provider_id) = provider { - config.provider = Some(provider_id.clone()); + /// Reads [`AppConfig`] from disk via [`ForgeConfig::read`]. + async fn read(&self) -> ForgeConfig { + let config = ForgeConfig::read(); - // If we have both provider and model overrides, ensure the model is set for - // this provider - if let Some(ref model_id) = model { - config.model.insert(provider_id.clone(), model_id.clone()); + match config { + Ok(config) => { + debug!(config = ?config, "read .forge.toml"); + config } - } - - // If only model override (no provider override), update existing provider - // models - if provider.is_none() - && let Some(model_id) = model - { - if config.model.is_empty() { - // If no models configured but we have a default provider, set the model for it - if let Some(ref default_provider) = config.provider { - config.model.insert(default_provider.clone(), model_id); - } - } else { - // Update all existing provider models - for (_, mut_model_id) in config.model.iter_mut() { - *mut_model_id = model_id.clone(); - } + Err(e) => { + // NOTE: This should never-happen + error!(error = ?e, "Failed to read config file. Using default config."); + Default::default() } } - - config } } #[async_trait::async_trait] -impl AppConfigRepository - for AppConfigRepositoryImpl -{ +impl AppConfigRepository for ForgeConfigRepository { async fn get_app_config(&self) -> anyhow::Result { // Check cache first let cache = self.cache.lock().await; - if let Some(ref cached_config) = *cache { - // Apply overrides even to cached config since overrides can change via env vars - return Ok(self.apply_overrides(cached_config.clone())); + if let Some(ref config) = *cache { + return Ok(forge_config_to_app_config(config.clone())); } drop(cache); // Cache miss, read from file let config = self.read().await; - // Update cache with the newly read config (without overrides) let mut cache = self.cache.lock().await; *cache = Some(config.clone()); - // Apply overrides to the config before returning - Ok(self.apply_overrides(config)) + Ok(forge_config_to_app_config(config)) } - async fn set_app_config(&self, config: &AppConfig) -> anyhow::Result<()> { - let (model, provider) = self.get_overrides(); + async fn update_app_config(&self, ops: Vec) -> anyhow::Result<()> { + // Load the global config + let mut fc = ConfigReader::default().read_global().build()?; - if model.is_some() || provider.is_some() { - bail!("Could not save configuration: Model or Provider was overridden") + debug!(config = ?fc, "loaded config for update"); + + // Apply each operation directly onto ForgeConfig + debug!(?ops, "applying app config operations"); + for op in ops { + apply_op(op, &mut fc); } - self.write(config).await?; + // Persist + fc.write()?; + debug!(config = ?fc, "written .forge.toml"); - // Bust the cache after successful write + // Reset cache let mut cache = self.cache.lock().await; *cache = None; @@ -156,416 +185,311 @@ impl AppC #[cfg(test)] mod tests { + use std::collections::HashMap; - use std::collections::{BTreeMap, HashMap}; - use std::path::{Path, PathBuf}; - use std::str::FromStr; - use std::sync::Mutex; - - use bytes::Bytes; - use forge_app::{EnvironmentInfra, FileReaderInfra, FileWriterInfra}; - use forge_domain::{AppConfig, Environment, ProviderId}; + use forge_config::{ForgeConfig, ModelConfig}; + use forge_domain::{ + AppConfig, AppConfigOperation, CommitConfig, LoginInfo, ModelId, ProviderId, SuggestConfig, + }; use pretty_assertions::assert_eq; - use tempfile::TempDir; - - use super::*; - - /// Mock infrastructure for testing that stores files in memory - #[derive(Clone)] - struct MockInfra { - files: Arc>>, - config_path: PathBuf, - } - impl MockInfra { - fn new(config_path: PathBuf) -> Self { - Self { files: Arc::new(Mutex::new(HashMap::new())), config_path } - } - } - - impl EnvironmentInfra for MockInfra { - fn get_environment(&self) -> Environment { - use fake::{Fake, Faker}; - let mut env: Environment = Faker.fake(); - env = env.base_path(self.config_path.parent().unwrap().to_path_buf()); - env - } + use super::{apply_op, forge_config_to_app_config}; - fn get_env_var(&self, _key: &str) -> Option { - None - } - - fn get_env_vars(&self) -> BTreeMap { - BTreeMap::new() - } - - fn is_restricted(&self) -> bool { - false - } - } + // ── forge_config_to_app_config ──────────────────────────────────────────── - #[async_trait::async_trait] - impl FileReaderInfra for MockInfra { - async fn read_utf8(&self, path: &Path) -> anyhow::Result { - self.files - .lock() - .unwrap() - .get(path) - .cloned() - .ok_or_else(|| anyhow::anyhow!("File not found")) - } - - fn read_batch_utf8( - &self, - _batch_size: usize, - _paths: Vec, - ) -> impl futures::Stream)> + Send { - futures::stream::empty() - } - - async fn read(&self, _path: &Path) -> anyhow::Result> { - unimplemented!() - } - - async fn range_read_utf8( - &self, - _path: &Path, - _start_line: u64, - _end_line: u64, - ) -> anyhow::Result<(String, forge_domain::FileInfo)> { - unimplemented!() - } - } - - #[async_trait::async_trait] - impl FileWriterInfra for MockInfra { - async fn write(&self, path: &Path, contents: Bytes) -> anyhow::Result<()> { - let content = String::from_utf8(contents.to_vec())?; - self.files - .lock() - .unwrap() - .insert(path.to_path_buf(), content); - Ok(()) - } - - async fn write_temp(&self, _: &str, _: &str, _: &str) -> anyhow::Result { - unimplemented!() - } - } - - fn repository_fixture() -> (AppConfigRepositoryImpl, TempDir) { - let temp_dir = tempfile::tempdir().unwrap(); - let config_path = temp_dir.path().join(".config.json"); - let infra = Arc::new(MockInfra::new(config_path)); - (AppConfigRepositoryImpl::new(infra), temp_dir) - } - - fn repository_with_config_fixture() -> (AppConfigRepositoryImpl, TempDir) { - let temp_dir = tempfile::tempdir().unwrap(); - let config_path = temp_dir.path().join(".config.json"); - - // Create a config file with default config - let config = AppConfig::default(); - let content = serde_json::to_string_pretty(&config).unwrap(); - - let infra = Arc::new(MockInfra::new(config_path.clone())); - infra.files.lock().unwrap().insert(config_path, content); - - (AppConfigRepositoryImpl::new(infra), temp_dir) - } - - #[tokio::test] - async fn test_get_app_config_exists() { + #[test] + fn test_empty_forge_config_produces_empty_app_config() { + let fixture = ForgeConfig::default(); + let actual = forge_config_to_app_config(fixture); let expected = AppConfig::default(); - let (repo, _temp_dir) = repository_with_config_fixture(); - - let actual = repo.get_app_config().await.unwrap(); - assert_eq!(actual, expected); } - #[tokio::test] - async fn test_get_app_config_not_exists() { - let (repo, _temp_dir) = repository_fixture(); - - let actual = repo.get_app_config().await.unwrap(); - - // Should return default config when file doesn't exist - let expected = AppConfig::default(); + #[test] + fn test_full_login_info_is_mapped() { + let fixture = ForgeConfig::default() + .api_key("key-abc".to_string()) + .api_key_name("My Key".to_string()) + .api_key_masked("key-***".to_string()) + .email("user@example.com".to_string()) + .name("Alice".to_string()) + .auth_provider_id("github".to_string()); + let actual = forge_config_to_app_config(fixture); + let expected = AppConfig { + key_info: Some(LoginInfo { + api_key: "key-abc".to_string(), + api_key_name: "My Key".to_string(), + api_key_masked: "key-***".to_string(), + email: Some("user@example.com".to_string()), + name: Some("Alice".to_string()), + auth_provider_id: Some("github".to_string()), + }), + ..Default::default() + }; assert_eq!(actual, expected); } - #[tokio::test] - async fn test_set_app_config() { - let fixture = AppConfig::default(); - let (repo, _temp_dir) = repository_fixture(); - - let actual = repo.set_app_config(&fixture).await; - - assert!(actual.is_ok()); - - // Verify the config was actually written by reading it back - let read_config = repo.get_app_config().await.unwrap(); - assert_eq!(read_config, fixture); - } - - #[tokio::test] - async fn test_cache_behavior() { - let (repo, _temp_dir) = repository_with_config_fixture(); - - // First read should populate cache - let first_read = repo.get_app_config().await.unwrap(); - - // Second read should use cache (no file system access) - let second_read = repo.get_app_config().await.unwrap(); - assert_eq!(first_read, second_read); - - // Write new config should bust cache - let new_config = AppConfig::default(); - repo.set_app_config(&new_config).await.unwrap(); - - // Next read should get fresh data - let third_read = repo.get_app_config().await.unwrap(); - assert_eq!(third_read, new_config); + #[test] + fn test_session_with_provider_and_model() { + let fixture = ForgeConfig { + session: Some( + ModelConfig::default() + .provider_id("anthropic".to_string()) + .model_id("claude-3".to_string()), + ), + ..Default::default() + }; + let actual = forge_config_to_app_config(fixture); + let provider = ProviderId::from("anthropic".to_string()); + let expected = AppConfig { + provider: Some(provider.clone()), + model: HashMap::from([(provider, ModelId::new("claude-3"))]), + ..Default::default() + }; + assert_eq!(actual, expected); } - #[tokio::test] - async fn test_read_handles_custom_provider() { - let fixture = r#"{ - "provider": "xyz", - "model": {} - }"#; - let temp_dir = tempfile::tempdir().unwrap(); - let config_path = temp_dir.path().join(".config.json"); - - let infra = Arc::new(MockInfra::new(config_path.clone())); - infra - .files - .lock() - .unwrap() - .insert(config_path, fixture.to_string()); - - let repo = AppConfigRepositoryImpl::new(infra); - - let actual = repo.get_app_config().await.unwrap(); - + #[test] + fn test_session_with_only_provider_leaves_model_map_empty() { + let fixture = ForgeConfig { + session: Some(ModelConfig::default().provider_id("openai".to_string())), + ..Default::default() + }; + let actual = forge_config_to_app_config(fixture); let expected = AppConfig { - provider: Some(ProviderId::from_str("xyz").unwrap()), + provider: Some(ProviderId::from("openai".to_string())), + model: HashMap::new(), ..Default::default() }; assert_eq!(actual, expected); } - #[tokio::test] - async fn test_read_returns_default_if_not_exists() { - let (repo, _temp_dir) = repository_fixture(); - - let config = repo.get_app_config().await.unwrap(); - - // Config should be the default - assert_eq!(config, AppConfig::default()); + #[test] + fn test_commit_config_is_mapped() { + let fixture = ForgeConfig { + commit: Some( + ModelConfig::default() + .provider_id("openai".to_string()) + .model_id("gpt-4o".to_string()), + ), + ..Default::default() + }; + let actual = forge_config_to_app_config(fixture); + let expected = AppConfig { + commit: Some(CommitConfig { + provider: Some(ProviderId::from("openai".to_string())), + model: Some(ModelId::new("gpt-4o")), + }), + ..Default::default() + }; + assert_eq!(actual, expected); } - #[tokio::test] - async fn test_override_model() { - let temp_dir = tempfile::tempdir().unwrap(); - let config_path = temp_dir.path().join(".config.json"); - - // Set up a config with a specific model - let mut config = AppConfig::default(); - config.model.insert( - ProviderId::ANTHROPIC, - ModelId::new("claude-3-5-sonnet-20241022"), - ); - let content = serde_json::to_string_pretty(&config).unwrap(); - - let infra = Arc::new(MockInfra::new(config_path.clone())); - infra.files.lock().unwrap().insert(config_path, content); - - let repo = - AppConfigRepositoryImpl::new(infra).override_model(ModelId::new("override-model")); - let actual = repo.get_app_config().await.unwrap(); - - // The override model should be applied to all providers + #[test] + fn test_suggest_config_requires_both_provider_and_model() { + let fixture_provider_only = ForgeConfig { + suggest: Some(ModelConfig::default().provider_id("openai".to_string())), + ..Default::default() + }; assert_eq!( - actual.model.get(&ProviderId::ANTHROPIC), - Some(&ModelId::new("override-model")) + forge_config_to_app_config(fixture_provider_only).suggest, + None ); - } - - #[tokio::test] - async fn test_override_provider() { - let temp_dir = tempfile::tempdir().unwrap(); - let config_path = temp_dir.path().join(".config.json"); - - // Set up a config with a specific provider - let config = AppConfig { provider: Some(ProviderId::ANTHROPIC), ..Default::default() }; - let content = serde_json::to_string_pretty(&config).unwrap(); - - let infra = Arc::new(MockInfra::new(config_path.clone())); - infra.files.lock().unwrap().insert(config_path, content); - - let repo = AppConfigRepositoryImpl::new(infra).override_provider(ProviderId::OPENAI); - let actual = repo.get_app_config().await.unwrap(); - // The override provider should be applied - assert_eq!(actual.provider, Some(ProviderId::OPENAI)); + let fixture_model_only = ForgeConfig { + suggest: Some(ModelConfig { model_id: Some("gpt-4o".to_string()), provider_id: None }), + ..Default::default() + }; + assert_eq!(forge_config_to_app_config(fixture_model_only).suggest, None); } - #[tokio::test] - async fn test_override_prevents_config_write() { - let temp_dir = tempfile::tempdir().unwrap(); - let config_path = temp_dir.path().join(".config.json"); - - let infra = Arc::new(MockInfra::new(config_path)); - let repo = - AppConfigRepositoryImpl::new(infra).override_model(ModelId::new("override-model")); - - // Attempting to write config when override is set should fail - let config = AppConfig::default(); - let actual = repo.set_app_config(&config).await; - - assert!(actual.is_err()); - assert!( - actual - .unwrap_err() - .to_string() - .contains("Model or Provider was overridden") - ); + #[test] + fn test_suggest_config_with_both_fields_is_mapped() { + let fixture = ForgeConfig { + suggest: Some( + ModelConfig::default() + .provider_id("openai".to_string()) + .model_id("gpt-4o-mini".to_string()), + ), + ..Default::default() + }; + let actual = forge_config_to_app_config(fixture); + let expected = AppConfig { + suggest: Some(SuggestConfig { + provider: ProviderId::from("openai".to_string()), + model: ModelId::new("gpt-4o-mini"), + }), + ..Default::default() + }; + assert_eq!(actual, expected); } - #[tokio::test] - async fn test_provider_override_applied_with_no_config() { - let temp_dir = tempfile::tempdir().unwrap(); - let config_path = temp_dir.path().join(".config.json"); - let expected = ProviderId::from_str("open_router").unwrap(); - - let infra = Arc::new(MockInfra::new(config_path)); - let repo = AppConfigRepositoryImpl::new(infra) - .override_provider(expected.clone()) - .override_model(ModelId::new("test-model")); - - let actual = repo.get_app_config().await.unwrap(); - - assert_eq!(actual.provider, Some(expected)); + // ── apply_op ────────────────────────────────────────────────────────────── + + #[test] + fn test_apply_op_key_info_some_sets_all_fields() { + let mut fixture = ForgeConfig::default(); + let login = LoginInfo { + api_key: "key-123".to_string(), + api_key_name: "prod".to_string(), + api_key_masked: "key-***".to_string(), + email: Some("dev@forge.dev".to_string()), + name: Some("Bob".to_string()), + auth_provider_id: Some("google".to_string()), + }; + apply_op(AppConfigOperation::KeyInfo(Some(login)), &mut fixture); + let expected = ForgeConfig::default() + .api_key("key-123".to_string()) + .api_key_name("prod".to_string()) + .api_key_masked("key-***".to_string()) + .email("dev@forge.dev".to_string()) + .name("Bob".to_string()) + .auth_provider_id("google".to_string()); + assert_eq!(fixture, expected); } - #[tokio::test] - async fn test_model_override_applied_with_no_config() { - let temp_dir = tempfile::tempdir().unwrap(); - let config_path = temp_dir.path().join(".config.json"); - let provider = ProviderId::OPENAI; - let expected = ModelId::new("gpt-4-test"); - - let infra = Arc::new(MockInfra::new(config_path)); - let repo = AppConfigRepositoryImpl::new(infra) - .override_provider(provider.clone()) - .override_model(expected.clone()); - - let actual = repo.get_app_config().await.unwrap(); - - assert_eq!(actual.model.get(&provider), Some(&expected)); + #[test] + fn test_apply_op_key_info_none_clears_all_fields() { + let mut fixture = ForgeConfig::default() + .api_key("key-abc".to_string()) + .api_key_name("old".to_string()) + .api_key_masked("old-***".to_string()) + .email("old@example.com".to_string()) + .name("Old Name".to_string()) + .auth_provider_id("github".to_string()); + apply_op(AppConfigOperation::KeyInfo(None), &mut fixture); + assert_eq!(fixture, ForgeConfig::default()); } - #[tokio::test] - async fn test_provider_override_on_cached_config() { - let temp_dir = tempfile::tempdir().unwrap(); - let config_path = temp_dir.path().join(".config.json"); - let expected = ProviderId::ANTHROPIC; - - let infra = Arc::new(MockInfra::new(config_path)); - let repo = AppConfigRepositoryImpl::new(infra) - .override_provider(expected.clone()) - .override_model(ModelId::new("test-model")); - - // First call populates cache - repo.get_app_config().await.unwrap(); - - // Second call should still apply override to cached config - let actual = repo.get_app_config().await.unwrap(); - - assert_eq!(actual.provider, Some(expected)); + #[test] + fn test_apply_op_set_provider_creates_session_when_absent() { + let mut fixture = ForgeConfig::default(); + apply_op( + AppConfigOperation::SetProvider(ProviderId::from("anthropic".to_string())), + &mut fixture, + ); + let expected = ForgeConfig { + session: Some(ModelConfig::default().provider_id("anthropic".to_string())), + ..Default::default() + }; + assert_eq!(fixture, expected); } - #[tokio::test] - async fn test_model_override_on_cached_config() { - let temp_dir = tempfile::tempdir().unwrap(); - let config_path = temp_dir.path().join(".config.json"); - let provider = ProviderId::OPENAI; - let expected = ModelId::new("gpt-4-cached"); - - let infra = Arc::new(MockInfra::new(config_path)); - let repo = AppConfigRepositoryImpl::new(infra) - .override_provider(provider.clone()) - .override_model(expected.clone()); - - // First call populates cache - repo.get_app_config().await.unwrap(); - - // Second call should still apply override to cached config - let actual = repo.get_app_config().await.unwrap(); - - assert_eq!(actual.model.get(&provider), Some(&expected)); + #[test] + fn test_apply_op_set_provider_updates_existing_session_keeping_model() { + let mut fixture = ForgeConfig { + session: Some( + ModelConfig::default() + .provider_id("openai".to_string()) + .model_id("gpt-4".to_string()), + ), + ..Default::default() + }; + apply_op( + AppConfigOperation::SetProvider(ProviderId::from("anthropic".to_string())), + &mut fixture, + ); + let expected = ForgeConfig { + session: Some( + ModelConfig::default() + .provider_id("anthropic".to_string()) + .model_id("gpt-4".to_string()), + ), + ..Default::default() + }; + assert_eq!(fixture, expected); } - #[tokio::test] - async fn test_model_override_with_existing_provider() { - let temp_dir = tempfile::tempdir().unwrap(); - let config_path = temp_dir.path().join(".config.json"); - let expected = ModelId::new("override-model"); - - // Set up config with provider but no model - let config = AppConfig { provider: Some(ProviderId::ANTHROPIC), ..Default::default() }; - let content = serde_json::to_string_pretty(&config).unwrap(); - - let infra = Arc::new(MockInfra::new(config_path.clone())); - infra.files.lock().unwrap().insert(config_path, content); - - let repo = AppConfigRepositoryImpl::new(infra).override_model(expected.clone()); - let actual = repo.get_app_config().await.unwrap(); - - assert_eq!(actual.model.get(&ProviderId::ANTHROPIC), Some(&expected)); + #[test] + fn test_apply_op_set_model_for_matching_provider_updates_model() { + let mut fixture = ForgeConfig { + session: Some( + ModelConfig::default() + .provider_id("openai".to_string()) + .model_id("gpt-3.5".to_string()), + ), + ..Default::default() + }; + apply_op( + AppConfigOperation::SetModel( + ProviderId::from("openai".to_string()), + ModelId::new("gpt-4"), + ), + &mut fixture, + ); + let expected = ForgeConfig { + session: Some( + ModelConfig::default() + .provider_id("openai".to_string()) + .model_id("gpt-4".to_string()), + ), + ..Default::default() + }; + assert_eq!(fixture, expected); } - #[tokio::test] - async fn test_read_repairs_invalid_json() { - let temp_dir = tempfile::tempdir().unwrap(); - let config_path = temp_dir.path().join(".config.json"); - - // Invalid JSON with trailing comma - let json = r#"{"provider": "openai",}"#; - - let infra = Arc::new(MockInfra::new(config_path.clone())); - infra - .files - .lock() - .unwrap() - .insert(config_path, json.to_string()); - - let repo = AppConfigRepositoryImpl::new(infra); - let actual = repo.get_app_config().await.unwrap(); - - assert_eq!(actual.provider, Some(ProviderId::OPENAI)); + #[test] + fn test_apply_op_set_model_for_different_provider_replaces_session() { + let mut fixture = ForgeConfig { + session: Some( + ModelConfig::default() + .provider_id("openai".to_string()) + .model_id("gpt-4".to_string()), + ), + ..Default::default() + }; + apply_op( + AppConfigOperation::SetModel( + ProviderId::from("anthropic".to_string()), + ModelId::new("claude-3"), + ), + &mut fixture, + ); + let expected = ForgeConfig { + session: Some( + ModelConfig::default() + .provider_id("anthropic".to_string()) + .model_id("claude-3".to_string()), + ), + ..Default::default() + }; + assert_eq!(fixture, expected); } - #[tokio::test] - async fn test_read_returns_default_on_unrepairable_json() { - let temp_dir = tempfile::tempdir().unwrap(); - let config_path = temp_dir.path().join(".config.json"); - - // JSON that can't be repaired to AppConfig - let json = r#"["this", "is", "an", "array"]"#; - - let infra = Arc::new(MockInfra::new(config_path.clone())); - infra - .files - .lock() - .unwrap() - .insert(config_path, json.to_string()); - - let repo = AppConfigRepositoryImpl::new(infra); - let actual = repo.get_app_config().await.unwrap(); + #[test] + fn test_apply_op_set_commit_config() { + let mut fixture = ForgeConfig::default(); + let commit = CommitConfig::default() + .provider(ProviderId::from("openai".to_string())) + .model(ModelId::new("gpt-4o")); + apply_op(AppConfigOperation::SetCommitConfig(commit), &mut fixture); + let expected = ForgeConfig { + commit: Some( + ModelConfig::default() + .provider_id("openai".to_string()) + .model_id("gpt-4o".to_string()), + ), + ..Default::default() + }; + assert_eq!(fixture, expected); + } - assert_eq!(actual, AppConfig::default()); + #[test] + fn test_apply_op_set_suggest_config() { + let mut fixture = ForgeConfig::default(); + let suggest = SuggestConfig { + provider: ProviderId::from("anthropic".to_string()), + model: ModelId::new("claude-3-haiku"), + }; + apply_op(AppConfigOperation::SetSuggestConfig(suggest), &mut fixture); + let expected = ForgeConfig { + suggest: Some( + ModelConfig::default() + .provider_id("anthropic".to_string()) + .model_id("claude-3-haiku".to_string()), + ), + ..Default::default() + }; + assert_eq!(fixture, expected); } } diff --git a/crates/forge_repo/src/forge_repo.rs b/crates/forge_repo/src/forge_repo.rs index dde20149e0..c0b101c14f 100644 --- a/crates/forge_repo/src/forge_repo.rs +++ b/crates/forge_repo/src/forge_repo.rs @@ -9,11 +9,11 @@ use forge_app::{ KVStore, McpServerInfra, StrategyFactory, UserInfra, WalkedFile, Walker, WalkerInfra, }; use forge_domain::{ - AnyProvider, AppConfig, AppConfigRepository, AuthCredential, ChatCompletionMessage, - ChatRepository, CommandOutput, Context, Conversation, ConversationId, ConversationRepository, - Environment, FileInfo, FuzzySearchRepository, McpServerConfig, MigrationResult, Model, ModelId, - Provider, ProviderId, ProviderRepository, ResultStream, SearchMatch, Skill, SkillRepository, - Snapshot, SnapshotRepository, + AnyProvider, AppConfig, AppConfigOperation, AppConfigRepository, AuthCredential, + ChatCompletionMessage, ChatRepository, CommandOutput, Context, Conversation, ConversationId, + ConversationRepository, Environment, FileInfo, FuzzySearchRepository, McpServerConfig, + MigrationResult, Model, ModelId, Provider, ProviderId, ProviderRepository, ResultStream, + SearchMatch, Skill, SkillRepository, Snapshot, SnapshotRepository, }; // Re-export CacacheStorage from forge_infra pub use forge_infra::CacacheStorage; @@ -23,7 +23,7 @@ use reqwest_eventsource::EventSource; use url::Url; use crate::agent::ForgeAgentRepository; -use crate::app_config::AppConfigRepositoryImpl; +use crate::app_config::ForgeConfigRepository; use crate::context_engine::ForgeContextEngineRepository; use crate::conversation::ConversationRepositoryImpl; use crate::database::{DatabasePool, PoolConfig}; @@ -42,7 +42,7 @@ pub struct ForgeRepo { infra: Arc, file_snapshot_service: Arc, conversation_repository: Arc, - app_config_repository: Arc>, + config_repository: Arc, mcp_cache_repository: Arc, provider_repository: Arc>, chat_repository: Arc>, @@ -54,11 +54,7 @@ pub struct ForgeRepo { } impl ForgeRepo { - pub fn new( - infra: Arc, - override_model: Option, - override_provider: Option, - ) -> Self { + pub fn new(infra: Arc) -> Self { let env = infra.get_environment(); let file_snapshot_service = Arc::new(ForgeFileSnapshotService::new(env.clone())); let db_pool = @@ -68,11 +64,7 @@ impl AppConfigRepository - for ForgeRepo -{ +impl AppConfigRepository for ForgeRepo { async fn get_app_config(&self) -> anyhow::Result { - self.app_config_repository.get_app_config().await + self.config_repository.get_app_config().await } - async fn set_app_config(&self, config: &AppConfig) -> anyhow::Result<()> { - self.app_config_repository.set_app_config(config).await + async fn update_app_config(&self, ops: Vec) -> anyhow::Result<()> { + self.config_repository.update_app_config(ops).await } } diff --git a/crates/forge_services/src/app_config.rs b/crates/forge_services/src/app_config.rs index 7e7e851c1f..d4737c27e6 100644 --- a/crates/forge_services/src/app_config.rs +++ b/crates/forge_services/src/app_config.rs @@ -1,7 +1,10 @@ use std::sync::Arc; use forge_app::AppConfigService; -use forge_domain::{AppConfig, AppConfigRepository, ModelId, ProviderId, ProviderRepository}; +use forge_domain::{ + AppConfigOperation, AppConfigRepository, ModelId, ProviderId, ProviderRepository, +}; +use tracing::debug; /// Service for managing user preferences for default providers and models. pub struct ForgeAppConfigService { @@ -16,15 +19,10 @@ impl ForgeAppConfigService { } impl ForgeAppConfigService { - /// Helper method to update app configuration atomically. - async fn update(&self, updater: U) -> anyhow::Result<()> - where - U: FnOnce(&mut AppConfig), - { - let mut config = self.infra.get_app_config().await?; - updater(&mut config); - self.infra.set_app_config(&config).await?; - Ok(()) + /// Helper method to apply a config operation atomically. + async fn update(&self, op: AppConfigOperation) -> anyhow::Result<()> { + debug!(op = ?op, "Updating app config"); + self.infra.update_app_config(vec![op]).await } } @@ -40,10 +38,8 @@ impl AppConfigService } async fn set_default_provider(&self, provider_id: ProviderId) -> anyhow::Result<()> { - self.update(|config| { - config.provider = Some(provider_id); - }) - .await + self.update(AppConfigOperation::SetProvider(provider_id)) + .await } async fn get_provider_model( @@ -75,10 +71,8 @@ impl AppConfigService .provider .ok_or(forge_domain::Error::NoDefaultProvider)?; - self.update(|config| { - config.model.insert(provider_id, model.clone()); - }) - .await + self.update(AppConfigOperation::SetModel(provider_id, model)) + .await } async fn get_commit_config(&self) -> anyhow::Result> { @@ -90,10 +84,8 @@ impl AppConfigService &self, commit_config: forge_domain::CommitConfig, ) -> anyhow::Result<()> { - self.update(|config| { - config.commit = Some(commit_config); - }) - .await + self.update(AppConfigOperation::SetCommitConfig(commit_config)) + .await } async fn get_suggest_config(&self) -> anyhow::Result> { @@ -105,10 +97,8 @@ impl AppConfigService &self, suggest_config: forge_domain::SuggestConfig, ) -> anyhow::Result<()> { - self.update(|config| { - config.suggest = Some(suggest_config); - }) - .await + self.update(AppConfigOperation::SetSuggestConfig(suggest_config)) + .await } } @@ -118,8 +108,8 @@ mod tests { use std::sync::Mutex; use forge_domain::{ - AnyProvider, AppConfig, ChatRepository, InputModality, MigrationResult, Model, ModelSource, - Provider, ProviderId, ProviderResponse, ProviderTemplate, + AnyProvider, AppConfig, AppConfigOperation, ChatRepository, InputModality, MigrationResult, + Model, ModelSource, Provider, ProviderId, ProviderResponse, ProviderTemplate, }; use pretty_assertions::assert_eq; use url::Url; @@ -200,8 +190,19 @@ mod tests { Ok(self.app_config.lock().unwrap().clone()) } - async fn set_app_config(&self, config: &AppConfig) -> anyhow::Result<()> { - *self.app_config.lock().unwrap() = config.clone(); + async fn update_app_config(&self, ops: Vec) -> anyhow::Result<()> { + let mut config = self.app_config.lock().unwrap(); + for op in ops { + match op { + AppConfigOperation::KeyInfo(info) => config.key_info = info, + AppConfigOperation::SetProvider(pid) => config.provider = Some(pid), + AppConfigOperation::SetModel(pid, mid) => { + config.model.insert(pid, mid); + } + AppConfigOperation::SetCommitConfig(commit) => config.commit = Some(commit), + AppConfigOperation::SetSuggestConfig(suggest) => config.suggest = Some(suggest), + } + } Ok(()) } } diff --git a/crates/forge_services/src/auth.rs b/crates/forge_services/src/auth.rs index 93a775ac62..6d0dee37bf 100644 --- a/crates/forge_services/src/auth.rs +++ b/crates/forge_services/src/auth.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use anyhow::bail; use bytes::Bytes; use forge_app::{AuthService, EnvironmentInfra, Error, HttpInfra, User, UserUsage}; -use forge_domain::{AppConfigRepository, InitAuth, LoginInfo}; +use forge_domain::{AppConfigOperation, AppConfigRepository, InitAuth, LoginInfo}; use reqwest::Url; use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue}; @@ -102,10 +102,9 @@ impl ForgeAuthService } async fn set_auth_token(&self, login: Option) -> anyhow::Result<()> { - let mut config = self.infra.get_app_config().await?; - config.key_info = login; - self.infra.set_app_config(&config).await?; - Ok(()) + self.infra + .update_app_config(vec![AppConfigOperation::KeyInfo(login)]) + .await } } diff --git a/crates/forge_tracker/src/log.rs b/crates/forge_tracker/src/log.rs index a2e97c1ec7..df4cda283d 100644 --- a/crates/forge_tracker/src/log.rs +++ b/crates/forge_tracker/src/log.rs @@ -33,7 +33,6 @@ pub fn init_tracing(log_path: PathBuf, tracker: Tracker) -> anyhow::Result EDITOR > nano + local editor_cmd="${FORGE_EDITOR:-${EDITOR:-nano}}" + + # Validate editor exists + if ! command -v "${editor_cmd%% *}" &>/dev/null; then + _forge_log error "Editor not found: $editor_cmd (set FORGE_EDITOR or EDITOR)" + return 1 + fi + + local config_file="${HOME}/forge/.forge.toml" + + # Ensure the config directory exists + if [[ ! -d "${HOME}/forge" ]]; then + mkdir -p "${HOME}/forge" || { + _forge_log error "Failed to create ~/forge directory" + return 1 + } + fi + + # Create the config file if it does not yet exist + if [[ ! -f "$config_file" ]]; then + touch "$config_file" || { + _forge_log error "Failed to create $config_file" + return 1 + } + fi + + # Open editor with its own TTY session + (eval "$editor_cmd '$config_file'" /dev/tty 2>&1) + local exit_code=$? + + if [[ $exit_code -ne 0 ]]; then + _forge_log error "Editor exited with error code $exit_code" + fi + + _forge_reset +} + # Action handler: Show tools function _forge_action_tools() { echo diff --git a/shell-plugin/lib/dispatcher.zsh b/shell-plugin/lib/dispatcher.zsh index 1f4b539b45..e162f28dcf 100644 --- a/shell-plugin/lib/dispatcher.zsh +++ b/shell-plugin/lib/dispatcher.zsh @@ -190,6 +190,9 @@ function forge-accept-line() { config) _forge_action_config ;; + config-edit|ce) + _forge_action_config_edit + ;; skill) _forge_action_skill ;; diff --git a/shell-plugin/lib/helpers.zsh b/shell-plugin/lib/helpers.zsh index e0a017282e..03bcece244 100644 --- a/shell-plugin/lib/helpers.zsh +++ b/shell-plugin/lib/helpers.zsh @@ -22,9 +22,9 @@ function _forge_exec() { local agent_id="${_FORGE_ACTIVE_AGENT:-forge}" local -a cmd cmd=($_FORGE_BIN --agent "$agent_id") - [[ -n "$_FORGE_SESSION_MODEL" ]] && cmd+=(--model "$_FORGE_SESSION_MODEL") - [[ -n "$_FORGE_SESSION_PROVIDER" ]] && cmd+=(--provider "$_FORGE_SESSION_PROVIDER") cmd+=("$@") + [[ -n "$_FORGE_SESSION_MODEL" ]] && local -x FORGE_SESSION__MODEL_ID="$_FORGE_SESSION_MODEL" + [[ -n "$_FORGE_SESSION_PROVIDER" ]] && local -x FORGE_SESSION__PROVIDER_ID="$_FORGE_SESSION_PROVIDER" "${cmd[@]}" } @@ -38,9 +38,9 @@ function _forge_exec_interactive() { local agent_id="${_FORGE_ACTIVE_AGENT:-forge}" local -a cmd cmd=($_FORGE_BIN --agent "$agent_id") - [[ -n "$_FORGE_SESSION_MODEL" ]] && cmd+=(--model "$_FORGE_SESSION_MODEL") - [[ -n "$_FORGE_SESSION_PROVIDER" ]] && cmd+=(--provider "$_FORGE_SESSION_PROVIDER") cmd+=("$@") + [[ -n "$_FORGE_SESSION_MODEL" ]] && local -x FORGE_SESSION__MODEL_ID="$_FORGE_SESSION_MODEL" + [[ -n "$_FORGE_SESSION_PROVIDER" ]] && local -x FORGE_SESSION__PROVIDER_ID="$_FORGE_SESSION_PROVIDER" "${cmd[@]}" /dev/tty }