diff --git a/crates/bashkit/src/scripted_tool/mod.rs b/crates/bashkit/src/scripted_tool/mod.rs index 7907fecb..e99f2a81 100644 --- a/crates/bashkit/src/scripted_tool/mod.rs +++ b/crates/bashkit/src/scripted_tool/mod.rs @@ -48,6 +48,73 @@ //! assert_eq!(resp.stdout.trim(), "hello Alice"); //! # }); //! ``` +//! +//! # Shared context across callbacks +//! +//! When multiple tool callbacks need shared resources (HTTP clients, auth tokens, +//! config), use the standard Rust closure-capture pattern with `Arc`: +//! +//! ```rust +//! use bashkit::{ScriptedTool, ToolArgs, ToolDef}; +//! use std::sync::Arc; +//! +//! let api_key = Arc::new("sk-secret-key".to_string()); +//! let base_url = Arc::new("https://api.example.com".to_string()); +//! +//! let k = api_key.clone(); +//! let u = base_url.clone(); +//! let mut builder = ScriptedTool::builder("api"); +//! builder = builder.tool( +//! ToolDef::new("get_user", "Fetch user by ID"), +//! move |args: &ToolArgs| { +//! let _key = &*k; // shared API key +//! let _url = &*u; // shared base URL +//! Ok(format!("{{\"id\":1}}\n")) +//! }, +//! ); +//! +//! let k2 = api_key.clone(); +//! let u2 = base_url.clone(); +//! builder = builder.tool( +//! ToolDef::new("list_orders", "List orders"), +//! move |_args: &ToolArgs| { +//! let _key = &*k2; +//! let _url = &*u2; +//! Ok(format!("[]\n")) +//! }, +//! ); +//! let _tool = builder.build(); +//! ``` +//! +//! For mutable shared state, use `Arc>`: +//! +//! ```rust +//! use bashkit::{ScriptedTool, ToolArgs, ToolDef}; +//! use std::sync::{Arc, Mutex}; +//! +//! let call_count = Arc::new(Mutex::new(0u64)); +//! let c = call_count.clone(); +//! let tool = ScriptedTool::builder("api") +//! .tool( +//! ToolDef::new("tracked", "Counted call"), +//! move |_args: &ToolArgs| { +//! let mut count = c.lock().unwrap(); +//! *count += 1; +//! Ok(format!("call #{count}\n")) +//! }, +//! ) +//! .build(); +//! ``` +//! +//! # State across execute() calls +//! +//! Each `execute()` creates a fresh Bash interpreter — no state carries over. +//! This is a security feature (clean sandbox per call). The LLM carries state +//! between calls via its context window: it sees stdout from each call and can +//! pass relevant data from one call's output into the next call's script. +//! +//! For persistent state across calls via callbacks, use `Arc` in closures — +//! the same `Arc` instances are reused across `execute()` calls. mod execute; @@ -682,4 +749,140 @@ mod tests { assert_eq!(parsed["name"], "Alice"); assert_eq!(parsed["count"], "3"); // string, not int — no schema } + + // -- Shared context tests (#522) -- + + #[tokio::test] + async fn test_shared_arc_across_callbacks() { + use std::sync::{Arc, Mutex}; + + let shared = Arc::new("shared-token".to_string()); + let call_log = Arc::new(Mutex::new(Vec::::new())); + + let s1 = shared.clone(); + let log1 = call_log.clone(); + let s2 = shared.clone(); + let log2 = call_log.clone(); + + let mut tool = ScriptedTool::builder("ctx_test") + .tool( + ToolDef::new("tool_a", "First tool"), + move |_args: &ToolArgs| { + log1.lock().expect("lock").push(format!("a:{}", *s1)); + Ok("a\n".to_string()) + }, + ) + .tool( + ToolDef::new("tool_b", "Second tool"), + move |_args: &ToolArgs| { + log2.lock().expect("lock").push(format!("b:{}", *s2)); + Ok("b\n".to_string()) + }, + ) + .build(); + + let resp = tool + .execute(ToolRequest { + commands: "tool_a && tool_b".to_string(), + timeout_ms: None, + }) + .await; + assert_eq!(resp.exit_code, 0); + let log = call_log.lock().expect("lock"); + assert_eq!(*log, vec!["a:shared-token", "b:shared-token"]); + } + + #[tokio::test] + async fn test_mutable_shared_state_across_callbacks() { + use std::sync::{Arc, Mutex}; + + let counter = Arc::new(Mutex::new(0u64)); + let c = counter.clone(); + + let mut tool = ScriptedTool::builder("mut_test") + .tool( + ToolDef::new("increment", "Bump counter"), + move |_args: &ToolArgs| { + let mut count = c.lock().expect("lock"); + *count += 1; + Ok(format!("{count}\n")) + }, + ) + .build(); + + let resp = tool + .execute(ToolRequest { + commands: "increment; increment; increment".to_string(), + timeout_ms: None, + }) + .await; + assert_eq!(resp.exit_code, 0); + assert_eq!(*counter.lock().expect("lock"), 3); + } + + // -- Fresh interpreter isolation test (#524) -- + + #[tokio::test] + async fn test_fresh_interpreter_per_execute() { + let mut tool = ScriptedTool::builder("isolation_test") + .tool(ToolDef::new("noop", "No-op"), |_args: &ToolArgs| { + Ok("ok\n".to_string()) + }) + .build(); + + // Set a variable in call 1 + let resp1 = tool + .execute(ToolRequest { + commands: "export MY_VAR=hello; echo $MY_VAR".to_string(), + timeout_ms: None, + }) + .await; + assert_eq!(resp1.stdout.trim(), "hello"); + + // Variable should NOT persist to call 2 + let resp2 = tool + .execute(ToolRequest { + commands: "echo \">${MY_VAR}<\"".to_string(), + timeout_ms: None, + }) + .await; + assert_eq!(resp2.stdout.trim(), "><"); + } + + #[tokio::test] + async fn test_arc_callback_persists_across_execute_calls() { + use std::sync::{Arc, Mutex}; + + let counter = Arc::new(Mutex::new(0u64)); + let c = counter.clone(); + + let mut tool = ScriptedTool::builder("persist_test") + .tool( + ToolDef::new("count", "Count calls"), + move |_args: &ToolArgs| { + let mut n = c.lock().expect("lock"); + *n += 1; + Ok(format!("{n}\n")) + }, + ) + .build(); + + // Call 1 + let resp1 = tool + .execute(ToolRequest { + commands: "count".to_string(), + timeout_ms: None, + }) + .await; + assert_eq!(resp1.stdout.trim(), "1"); + + // Call 2 — counter persists via Arc + let resp2 = tool + .execute(ToolRequest { + commands: "count".to_string(), + timeout_ms: None, + }) + .await; + assert_eq!(resp2.stdout.trim(), "2"); + } } diff --git a/specs/014-scripted-tool-orchestration.md b/specs/014-scripted-tool-orchestration.md index a1ff395a..3e251ec3 100644 --- a/specs/014-scripted-tool-orchestration.md +++ b/specs/014-scripted-tool-orchestration.md @@ -139,6 +139,30 @@ Output: {stdout, stderr, exit_code} - Use variables to pass data between tool calls ``` +### Shared context across callbacks + +Use the standard Rust closure-capture pattern with `Arc` to share resources: + +```rust +let client = Arc::new(build_authenticated_client()); +let c = client.clone(); +builder.tool(ToolDef::new("get_user", "..."), move |args| { + let resp = c.get(&format!("/users/{}", args.param_i64("id").unwrap())); + Ok(resp.text()?) +}); +``` + +For mutable state, use `Arc>`. No API change needed — closures handle it naturally. + +### State across execute() calls + +Each `execute()` creates a fresh Bash interpreter (security: clean sandbox per call). +The LLM carries state via its context window — it sees stdout from each call and passes +relevant data into the next script. + +For callback-level persistence, `Arc` state in closures persists across `execute()` calls +since the same `Arc` instances are reused. + ## Module location `crates/bashkit/src/scripted_tool/` @@ -160,7 +184,7 @@ Run: `cargo run --example scripted_tool --features scripted_tool` ## Test coverage -31 unit tests covering: +35 unit tests covering: - Builder configuration (name, description, defaults) - Introspection (help, system_prompt, schemas, schema rendering) - Flag parsing (`--key value`, `--key=value`, boolean flags, type coercion) @@ -173,6 +197,8 @@ Run: `cargo run --example scripted_tool --features scripted_tool` - Environment variables - Status callbacks - Multiple sequential `execute()` calls (Arc reuse) +- Shared context: Arc across callbacks, mutable Arc> +- Interpreter isolation: fresh per execute(), Arc callback persistence ## Security