diff --git a/crates/pctx_code_execution_runtime/src/callback_registry.rs b/crates/pctx_code_execution_runtime/src/callback_registry.rs index 8bef6e7..fbeecd4 100644 --- a/crates/pctx_code_execution_runtime/src/callback_registry.rs +++ b/crates/pctx_code_execution_runtime/src/callback_registry.rs @@ -1,9 +1,11 @@ +use serde_json::json; use std::{ collections::HashMap, future::Future, pin::Pin, sync::{Arc, RwLock}, }; +use tracing::instrument; use crate::error::McpError; @@ -103,6 +105,13 @@ impl CallbackRegistry { /// /// This function will return an error if a callback by the provided id doesn't exist /// or if the callback itself fails + #[instrument( + name = "invoke_callback_tool", + skip_all, + fields(id=id, args = json!(args).to_string()), + ret(Display), + err + )] pub async fn invoke( &self, id: &str, diff --git a/crates/pctx_code_execution_runtime/src/mcp_registry.rs b/crates/pctx_code_execution_runtime/src/mcp_registry.rs index 4cf2476..869f2de 100644 --- a/crates/pctx_code_execution_runtime/src/mcp_registry.rs +++ b/crates/pctx_code_execution_runtime/src/mcp_registry.rs @@ -1,9 +1,10 @@ use crate::error::McpError; use pctx_config::server::ServerConfig; use rmcp::model::{CallToolRequestParam, JsonObject, RawContent}; +use serde_json::json; use std::collections::HashMap; use std::sync::{Arc, RwLock}; -use tracing::warn; +use tracing::{info, instrument, warn}; /// Singleton registry for MCP server configurations #[derive(Clone)] @@ -87,6 +88,13 @@ impl Default for MCPRegistry { } /// Call an MCP tool on a registered server +#[instrument( + name = "invoke_mcp_tool", + skip_all, + fields(id=format!("{server_name}.{tool_name}"), args = json!(args).to_string()), + ret(Display), + err +)] pub(crate) async fn call_mcp_tool( registry: &MCPRegistry, server_name: &str, @@ -132,22 +140,22 @@ pub(crate) async fn call_mcp_tool( } // Prefer structuredContent if available, otherwise use content array - if let Some(structured) = tool_result.structured_content { - return Ok(structured); - } - - // Convert content to JSON value - // For simplicity, we'll extract text content and try to parse as JSON - if let Some(RawContent::Text(text_content)) = tool_result.content.first().map(|a| &**a) { + let has_structured = tool_result.structured_content.is_some(); + let val = if let Some(structured) = tool_result.structured_content { + structured + } else if let Some(RawContent::Text(text_content)) = tool_result.content.first().map(|a| &**a) { // Try to parse as JSON, fallback to string value serde_json::from_str(&text_content.text) .or_else(|_| Ok(serde_json::Value::String(text_content.text.clone()))) .map_err(|e: serde_json::Error| { McpError::ToolCall(format!("Failed to parse content: {e}")) - }) + })? } else { // Return the whole content array as JSON - serde_json::to_value(&tool_result.content) - .map_err(|e| McpError::ToolCall(format!("Failed to serialize content: {e}"))) - } + json!(tool_result.content) + }; + + info!(structured_content = has_structured, result =? &val, "Tool result"); + + Ok(val) } diff --git a/crates/pctx_code_mode/src/code_mode.rs b/crates/pctx_code_mode/src/code_mode.rs index 156c909..16641e8 100644 --- a/crates/pctx_code_mode/src/code_mode.rs +++ b/crates/pctx_code_mode/src/code_mode.rs @@ -5,7 +5,7 @@ use pctx_code_execution_runtime::CallbackRegistry; use pctx_config::server::ServerConfig; use serde::{Deserialize, Serialize}; use serde_json::json; -use tracing::{debug, warn}; +use tracing::{debug, instrument, warn}; use crate::{ Error, Result, @@ -15,7 +15,7 @@ use crate::{ }, }; -#[derive(Clone, Default, Serialize, Deserialize)] +#[derive(Clone, Default, Debug, Serialize, Deserialize)] pub struct CodeMode { // Codegen interfaces pub tool_sets: Vec, @@ -105,6 +105,7 @@ impl CodeMode { GetFunctionDetailsOutput { code, functions } } + #[instrument(skip(self, callback_registry), ret(Display), err)] pub async fn execute( &self, code: &str, @@ -160,7 +161,7 @@ impl CodeMode { namespaces = namespaces.join("\n\n"), ); - debug!("Executing code in sandbox"); + debug!(to_execute = %to_execute, "Executing code in sandbox"); let options = pctx_executor::ExecuteOptions::new() .with_allowed_hosts(self.allowed_hosts().into_iter().collect()) diff --git a/crates/pctx_code_mode/src/model.rs b/crates/pctx_code_mode/src/model.rs index 50e16e3..852d609 100644 --- a/crates/pctx_code_mode/src/model.rs +++ b/crates/pctx_code_mode/src/model.rs @@ -163,6 +163,11 @@ impl ExecuteOutput { ) } } +impl Display for ExecuteOutput { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", json!(&self)) + } +} // -------------- Callbacks -------------- diff --git a/crates/pctx_executor/src/lib.rs b/crates/pctx_executor/src/lib.rs index de6dc95..aa8b7c6 100644 --- a/crates/pctx_executor/src/lib.rs +++ b/crates/pctx_executor/src/lib.rs @@ -131,15 +131,16 @@ pub async fn execute(code: &str, options: ExecuteOptions) -> Result, } /// Response to registering tools -#[derive(Debug, Serialize, Deserialize, ToSchema)] +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] pub struct RegisterToolsResponse { pub registered: usize, } /// Request to register MCP servers -#[derive(Debug, Deserialize, ToSchema)] +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] pub struct RegisterMcpServersRequest { pub servers: Vec, } // TODO: de-dup with pctx_config -#[derive(Debug, Deserialize, Clone, ToSchema)] +#[derive(Debug, Serialize, Deserialize, Clone, ToSchema)] #[serde(untagged)] pub enum McpServerConfig { Http { @@ -114,20 +114,20 @@ pub enum McpServerConfig { } /// Response after registering MCP servers -#[derive(Debug, Serialize, Deserialize, ToSchema)] +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] pub struct RegisterMcpServersResponse { pub registered: usize, pub failed: Vec, } /// Response after creating a new `CodeMode` session -#[derive(Debug, Serialize, Deserialize, ToSchema)] +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] pub struct CreateSessionResponse { #[schema(value_type = String)] pub session_id: Uuid, } /// Response after closing a `CodeMode` session -#[derive(Debug, Serialize, Deserialize, ToSchema)] +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] pub struct CloseSessionResponse { pub success: bool, } diff --git a/crates/pctx_session_server/src/routes.rs b/crates/pctx_session_server/src/routes.rs index d004966..ae23e8f 100644 --- a/crates/pctx_session_server/src/routes.rs +++ b/crates/pctx_session_server/src/routes.rs @@ -3,9 +3,11 @@ use axum::{Json, extract::State, http::StatusCode}; use pctx_code_mode::{ CodeMode, - model::{GetFunctionDetailsInput, GetFunctionDetailsOutput, ListFunctionsOutput}, + model::{ + CallbackConfig, GetFunctionDetailsInput, GetFunctionDetailsOutput, ListFunctionsOutput, + }, }; -use tracing::info; +use tracing::{debug, info}; use uuid::Uuid; use crate::extractors::CodeModeSession; @@ -46,7 +48,10 @@ pub(crate) async fn create_session( State(state): State>, ) -> ApiResult> { let session_id = Uuid::new_v4(); - info!("Creating new CodeMode session: {session_id}"); + info!( + session_id =? session_id, + "Creating new CodeMode session" + ); let code_mode = CodeMode::default(); state @@ -55,7 +60,10 @@ pub(crate) async fn create_session( .await .context("Failed inserting code mode session into backend")?; - info!("Created CodeMode session: {session_id}"); + info!( + session_id =? session_id, + "Created CodeMode session" + ); Ok(Json(CreateSessionResponse { session_id })) } @@ -78,7 +86,7 @@ pub(crate) async fn close_session( State(state): State>, CodeModeSession(session_id): CodeModeSession, ) -> ApiResult> { - info!("Closing CodeMode session: {session_id}"); + info!(session_id =? session_id, "Closing CodeMode session"); let existed = state .backend @@ -97,7 +105,7 @@ pub(crate) async fn close_session( )); } - info!("Closed CodeMode session: {session_id}"); + info!(session_id =? session_id, "Closed CodeMode session"); Ok(Json(CloseSessionResponse { success: true })) } @@ -119,7 +127,7 @@ pub(crate) async fn list_functions( State(state): State>, CodeModeSession(session_id): CodeModeSession, ) -> ApiResult> { - info!(session_id =? session_id, "Listing tools"); + info!(session_id =? session_id, "Listing functions"); let code_mode = state .backend @@ -168,7 +176,8 @@ pub(crate) async fn get_function_details( .join(", "); info!( session_id =? session_id, - "Getting function details for {requested_functions}" + functions =? requested_functions, + "Getting function details", ); let code_mode = state.backend.get(session_id).await?.ok_or(ApiError::new( @@ -205,9 +214,15 @@ pub(crate) async fn register_tools( CodeModeSession(session_id): CodeModeSession, Json(request): Json, ) -> ApiResult> { + let tool_ids = request + .tools + .iter() + .map(CallbackConfig::id) + .collect::>(); info!( - "Registering {} tools for session {session_id}", - request.tools.len(), + session_id =? session_id, + tools =? &tool_ids, + "Registering tools...", ); let mut code_mode = state @@ -224,19 +239,25 @@ pub(crate) async fn register_tools( }, ))?; - let mut registered = 0; for tool in &request.tools { + debug!(tool =? tool.id(), "Adding callback tool {}", tool.id()); code_mode .add_callback(tool) .context("Failed adding callback")?; - - registered += 1; } // Update the backend with the modified CodeMode state.backend.update(session_id, code_mode).await?; - Ok(Json(RegisterToolsResponse { registered })) + info!( + session_id =? session_id, + tools =? &tool_ids, + "Registered tools", + ); + + Ok(Json(RegisterToolsResponse { + registered: request.tools.len(), + })) } /// Register MCP servers dynamically at runtime @@ -303,6 +324,13 @@ pub(crate) async fn register_servers( .await .context("Failed updating code mode session in backend")?; + info!( + session_id =% session_id, + registered =% registered, + failed =? failed, + "Registered MCP servers", + ); + Ok(Json(RegisterMcpServersResponse { registered, failed })) } diff --git a/crates/pctx_session_server/src/state/backend.rs b/crates/pctx_session_server/src/state/backend.rs index 0631c4a..e84de88 100644 --- a/crates/pctx_session_server/src/state/backend.rs +++ b/crates/pctx_session_server/src/state/backend.rs @@ -2,7 +2,10 @@ use std::{collections::HashMap, sync::Arc}; use anyhow::{Context, Result}; use async_trait::async_trait; -use pctx_code_mode::CodeMode; +use pctx_code_mode::{ + CodeMode, + model::{ExecuteInput, ExecuteOutput}, +}; use tokio::sync::RwLock; use uuid::Uuid; @@ -30,6 +33,18 @@ pub trait PctxSessionBackend: Clone + Send + Sync + 'static { /// Returns a full list of active `CodeMode` sessions in the backend. async fn list_sessions(&self) -> Result>; + + /// Hook called after every code mode execution websocket event + async fn post_execution( + &self, + _session_id: Uuid, + _execution_id: Uuid, + _code_mode: CodeMode, + _execution_req: ExecuteInput, + _execution_res: Result, + ) -> Result<()> { + Ok(()) + } } /// Manages `CodeMode` sessions locally using thread-safe diff --git a/crates/pctx_session_server/src/websocket/handler.rs b/crates/pctx_session_server/src/websocket/handler.rs index 4a0ae90..11567cb 100644 --- a/crates/pctx_session_server/src/websocket/handler.rs +++ b/crates/pctx_session_server/src/websocket/handler.rs @@ -9,6 +9,7 @@ use crate::{ }, state::ws_manager::WsSession, }; +use anyhow::anyhow; use axum::{ extract::{ State, @@ -22,6 +23,7 @@ use futures::{ stream::{SplitSink, SplitStream}, }; use pctx_code_execution_runtime::{CallbackFn, CallbackRegistry}; +use pctx_code_mode::model::ExecuteInput; use rmcp::{ ErrorData, model::{ErrorCode, JsonRpcMessage, RequestId}, @@ -82,9 +84,7 @@ async fn handle_socket( state: AppState, code_mode_session: Uuid, ) { - info!("New WebSocket connection with code_mode_session: {code_mode_session}"); - - info!("Verified code mode session {code_mode_session} exists, proceeding with WebSocket setup"); + info!(session_id =? code_mode_session, "New WebSocket connection"); // Split socket into sender and receiver let (sender, receiver) = socket.split(); @@ -96,7 +96,9 @@ async fn handle_socket( let session = WsSession::new(tx.clone(), code_mode_session); let ws_session = session.id; - info!( + debug!( + session_id =? code_mode_session, + ws_session =? ws_session, "Created session {ws_session} connected to code mode session {}", session.code_mode_session_id ); @@ -168,7 +170,7 @@ async fn handle_execute_code_request( req_id: RequestId, params: ExecuteCodeParams, ws_session: Uuid, - state: &AppState, + state: AppState, ) -> Result<(), String> { // Save the WebSocket session for later response let ws_session_lock = state @@ -201,8 +203,9 @@ async fn handle_execute_code_request( debug!("Found CodeMode session with ID: {code_mode_session_id}"); - let callback_registry = CallbackRegistry::default(); + let execution_id = Uuid::new_v4(); + let callback_registry = CallbackRegistry::default(); for callback_cfg in &code_mode.callbacks { let ws_session_lock_clone = ws_session_lock.clone(); let cfg = callback_cfg.clone(); @@ -244,10 +247,18 @@ async fn handle_execute_code_request( } } + let execution_span = tracing::info_span!( + "execute_code_in_session", + session_id = %code_mode_session_id, + execution_id = %execution_id, + ); + tokio::spawn(async move { - let current_span = tracing::Span::current(); + let code_mode_clone = code_mode.clone(); + let code_clone = params.code.clone(); + let output = tokio::task::spawn_blocking(move || -> Result<_, anyhow::Error> { - let _guard = current_span.enter(); + let _guard = execution_span.enter(); let rt = tokio::runtime::Builder::new_current_thread() .enable_all() .build() @@ -256,36 +267,59 @@ async fn handle_execute_code_request( // create callback registry to execute callback requests over the same ws which // initiated the request rt.block_on(async { - code_mode - .execute(¶ms.code, Some(callback_registry)) + code_mode_clone + .execute(&code_clone, Some(callback_registry)) .await .map_err(|e| anyhow::anyhow!("Execution error: {e}")) }) }) .await; - let msg = match output { - Ok(Ok(exec_output)) => { - WsJsonRpcMessage::response(PctxJsonRpcResponse::ExecuteCode(exec_output), req_id) - } - Ok(Err(e)) => WsJsonRpcMessage::error( - ErrorData { - code: ErrorCode::INTERNAL_ERROR, - message: format!("Execution failed: {e}").into(), - data: None, - }, - req_id, + let (msg, execution_res) = match output { + Ok(Ok(exec_output)) => ( + WsJsonRpcMessage::response( + PctxJsonRpcResponse::ExecuteCode(exec_output.clone()), + req_id, + ), + Ok(exec_output), ), - Err(e) => WsJsonRpcMessage::error( - ErrorData { - code: ErrorCode::INTERNAL_ERROR, - message: format!("Task join failed: {e}").into(), - data: None, - }, - req_id, + Ok(Err(e)) => ( + WsJsonRpcMessage::error( + ErrorData { + code: ErrorCode::INTERNAL_ERROR, + message: format!("Execution failed: {e}").into(), + data: None, + }, + req_id, + ), + Err(anyhow!(e)), + ), + Err(e) => ( + WsJsonRpcMessage::error( + ErrorData { + code: ErrorCode::INTERNAL_ERROR, + message: format!("Task join failed: {e}").into(), + data: None, + }, + req_id, + ), + Err(anyhow!(e)), ), }; + if let Err(e) = state + .backend + .post_execution( + code_mode_session_id, + execution_id, + code_mode, + ExecuteInput { code: params.code }, + execution_res, + ) + .await + { + error!("Failed to post_execution hook: {e}"); + } if let Err(e) = sender.send(msg) { error!("Failed to send execute_code response: {e}"); } @@ -312,7 +346,7 @@ async fn handle_message( JsonRpcMessage::Request(req) => match req.request { PctxJsonRpcRequest::ExecuteCode { params } => { debug!("Executing code..."); - handle_execute_code_request(req.id, params, ws_session, state).await + handle_execute_code_request(req.id, params, ws_session, state.clone()).await } PctxJsonRpcRequest::ExecuteTool { .. } => { // the server is only responsible for servicing execute_code requests, execute_tool @@ -349,7 +383,6 @@ async fn handle_message( } Message::Close(_) => { info!("Received close message for session {ws_session}"); - println!("CLOSING...."); Ok(()) } Message::Ping(_) | Message::Pong(_) => Ok(()), diff --git a/pctx-py/src/pctx_client/_client.py b/pctx-py/src/pctx_client/_client.py index 373042a..0437a73 100644 --- a/pctx-py/src/pctx_client/_client.py +++ b/pctx-py/src/pctx_client/_client.py @@ -57,6 +57,7 @@ def __init__( tools: list[Tool | AsyncTool] | None = None, servers: list[ServerConfig] | None = None, url: str = "http://localhost:8080", + api_key: str | None = None, execute_timeout: float = 30.0, ): """ @@ -90,9 +91,15 @@ def __init__( ws_scheme = "wss" if http_scheme == "https" else "ws" - self._ws_client = WebSocketClient(url=f"{ws_scheme}://{host}/ws", tools=tools) - self._client = AsyncClient(base_url=f"{http_scheme}://{host}") + self._ws_client = WebSocketClient( + url=f"{ws_scheme}://{host}{parsed.path}/ws", api_key=api_key, tools=tools + ) + self._client = AsyncClient( + base_url=f"{http_scheme}://{host}{parsed.path}", + headers={"x-pctx-api-key": api_key or ""}, + ) self._session_id: str | None = None + self._api_key = api_key self._tools = tools or [] self._servers = servers or [] @@ -139,8 +146,7 @@ async def connect(self): f"Received invalid response from PCTX server at {self._client.base_url}. " "The server may be running but not responding correctly." ) from e - - self._client.headers = {"x-code-mode-session": self._session_id or ""} + self._client.headers.update({"x-code-mode-session": self._session_id or ""}) # Register all local tools & MCP servers configs: list[ToolConfig] = [ @@ -154,8 +160,10 @@ async def connect(self): for t in self._tools ] - await self._register_tools(configs) - await self._register_servers(self._servers) + if len(configs) > 0: + await self._register_tools(configs) + if len(self._servers) > 0: + await self._register_servers(self._servers) # reset search to re-index self._search_retriever = None diff --git a/pctx-py/src/pctx_client/_websocket_client.py b/pctx-py/src/pctx_client/_websocket_client.py index 04c4346..d03be82 100644 --- a/pctx-py/src/pctx_client/_websocket_client.py +++ b/pctx-py/src/pctx_client/_websocket_client.py @@ -47,7 +47,12 @@ class WebSocketClient: receive and handle tool execution requests from the server """ - def __init__(self, url: str, tools: list[Tool | AsyncTool] | None = None): + def __init__( + self, + url: str, + api_key: str | None = None, + tools: list[Tool | AsyncTool] | None = None, + ): """ Initialize the WebSocket client. @@ -57,6 +62,7 @@ def __init__(self, url: str, tools: list[Tool | AsyncTool] | None = None): self.url = url self.ws: ClientConnection | None = None self.tools = tools or [] + self._api_key = api_key self._pending_executions: dict[str | int, asyncio.Future] = {} self._request_counter = 0 @@ -68,7 +74,10 @@ async def _connect(self, code_mode_session: str): ConnectionError: If connection fails """ try: - headers = {"x-code-mode-session": code_mode_session} + headers = { + "x-code-mode-session": code_mode_session, + "x-pctx-api-key": self._api_key, + } self.ws = await websockets.connect(self.url, additional_headers=headers) except Exception as e: raise ConnectionError(f"Failed to connect to {self.url}: {e}") from e diff --git a/pctx-py/tests/scripts/manual_code_mode.py b/pctx-py/tests/scripts/manual_code_mode.py index bc2b0ec..ade49ae 100755 --- a/pctx-py/tests/scripts/manual_code_mode.py +++ b/pctx-py/tests/scripts/manual_code_mode.py @@ -30,7 +30,9 @@ def multiply(a: float, b: float) -> MultiplyOutput: async def main(): - p = Pctx( + async with Pctx( + url="http://localhost:8080/some-org/some-server", + api_key="asdlkfjasldf", tools=[add, subtract, multiply], # servers=[ # { @@ -42,32 +44,37 @@ async def main(): # }, # } # ], - ) - print("connecting....") - await p.connect() + ) as p: + print("+++++++++++ LIST +++++++++++\n") + print((await p.list_functions()).code) - print("+++++++++++ LIST +++++++++++\n") - print((await p.list_functions()).code) + print("\n\n+++++++++++ DETAILS +++++++++++\n") + print((await p.get_function_details(["MyMath.add"])).code) - print("\n\n+++++++++++ DETAILS +++++++++++\n") - print((await p.get_function_details(["MyMath.add"])).code) + code = """ + async function run() { + let addval = await MyMath.add({a: 40, b: 2}); + let subval = await MyMath.subtract({a: addval, b: 2}); + let multval = await MyMath.multiply({a: subval, b: 2}); - code = """ -async function run() { - let addval = await MyMath.add({a: 40, b: 2}); - let subval = await MyMath.subtract({a: addval, b: 2}); - let multval = await MyMath.multiply({a: subval, b: 2}); + return multval; + } + """ + output = await p.execute(code) + pprint.pprint(output) - return multval; -} -""" - print(code) - output = await p.execute(code) - pprint.pprint(output) + invalid_code = """ + async function run() { + let addval = await MyMath.add({a: "40", b: 2}); // invalid because `a` must be a number - print("disconnecting....") - await p.disconnect() + return addval; + } + """ + invalid_output = await p.execute(invalid_code) + pprint.pprint(invalid_output) + + print(p._session_id) if __name__ == "__main__":