diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index 90c6116e0..26f1868a6 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -60,6 +60,17 @@ import type { Transport, TransportSendOptions } from './transport.js'; */ export type ProgressCallback = (progress: Progress) => void; +/** + * A JSON-RPC message that can be transformed by middleware. + */ +export type JSONRPCMessageLike = JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCResultResponse | JSONRPCErrorResponse; + +/** + * Middleware function for transforming JSON-RPC messages. + * Can be sync (returns message) or async (returns Promise). + */ +export type MessageMiddleware = (message: JSONRPCMessageLike) => JSONRPCMessageLike | Promise; + /** * Additional initialization options. */ @@ -102,6 +113,16 @@ export type ProtocolOptions = { * appropriately (e.g., by failing the task, dropping messages, etc.). */ maxTaskQueueSize?: number; + /** + * Middleware functions to apply to outgoing messages before sending. + * Middleware is applied in order, with each function receiving the output of the previous. + */ + sendMiddleware?: MessageMiddleware[]; + /** + * Middleware functions to apply to incoming messages after receiving. + * Middleware is applied in order, with each function receiving the output of the previous. + */ + receiveMiddleware?: MessageMiddleware[]; }; /** @@ -603,6 +624,22 @@ export abstract class Protocol(message: T, middlewareList: MessageMiddleware[] | undefined): Promise { + if (!middlewareList || middlewareList.length === 0) { + return message; + } + + let result: JSONRPCMessageLike = message; + for (const middleware of middlewareList) { + result = await middleware(result); + } + return result as T; + } + /** * Attaches to the given transport, starts it, and starts listening for messages. * @@ -625,14 +662,26 @@ export abstract class Protocol { _onmessage?.(message, extra); - if (isJSONRPCResultResponse(message) || isJSONRPCErrorResponse(message)) { - this._onresponse(message); - } else if (isJSONRPCRequest(message)) { - this._onrequest(message, extra); - } else if (isJSONRPCNotification(message)) { - this._onnotification(message); + + // Route the message, applying receive middleware if configured + const routeMessage = (msg: JSONRPCMessageLike) => { + if (isJSONRPCResultResponse(msg) || isJSONRPCErrorResponse(msg)) { + this._onresponse(msg); + } else if (isJSONRPCRequest(msg)) { + this._onrequest(msg, extra); + } else if (isJSONRPCNotification(msg)) { + this._onnotification(msg); + } else { + this._onerror(new Error(`Unknown message type: ${JSON.stringify(msg)}`)); + } + }; + + if (this._options?.receiveMiddleware && this._options.receiveMiddleware.length > 0) { + this._applyMiddleware(message, this._options.receiveMiddleware) + .then(routeMessage) + .catch(error => this._onerror(new Error(`Receive middleware error: ${error}`))); } else { - this._onerror(new Error(`Unknown message type: ${JSON.stringify(message)}`)); + routeMessage(message); } }; @@ -1074,9 +1123,52 @@ export abstract class Protocol(request: SendRequestT, resultSchema: T, options?: RequestOptions): Promise> { + async request(request: SendRequestT, resultSchema: T, options?: RequestOptions): Promise> { const { relatedRequestId, resumptionToken, onresumptiontoken, task, relatedTask } = options ?? {}; + // Build the JSON-RPC request + const messageId = this._requestMessageId++; + let jsonrpcRequest: JSONRPCRequest = { + ...request, + jsonrpc: '2.0', + id: messageId + }; + + if (options?.onprogress) { + this._progressHandlers.set(messageId, options.onprogress); + jsonrpcRequest.params = { + ...request.params, + _meta: { + ...(request.params?._meta || {}), + progressToken: messageId + } + }; + } + + // Augment with task creation parameters if provided + if (task) { + jsonrpcRequest.params = { + ...jsonrpcRequest.params, + task: task + }; + } + + // Augment with related task metadata if relatedTask is provided + if (relatedTask) { + jsonrpcRequest.params = { + ...jsonrpcRequest.params, + _meta: { + ...(jsonrpcRequest.params?._meta || {}), + [RELATED_TASK_META_KEY]: relatedTask + } + }; + } + + // Apply send middleware before sending (only await if there's middleware) + if (this._options?.sendMiddleware && this._options.sendMiddleware.length > 0) { + jsonrpcRequest = (await this._applyMiddleware(jsonrpcRequest, this._options.sendMiddleware)) as JSONRPCRequest; + } + // Send the request return new Promise>((resolve, reject) => { const earlyReject = (error: unknown) => { @@ -1104,43 +1196,6 @@ export abstract class Protocol { this._responseHandlers.delete(messageId); this._progressHandlers.delete(messageId); @@ -1224,7 +1279,7 @@ export abstract class Protocol { + this._transport!.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch(error => { this._cleanupTimeout(messageId); reject(error); }); @@ -1302,9 +1357,18 @@ export abstract class Protocol 0) { + notificationToQueue = (await this._applyMiddleware( + jsonrpcNotification, + this._options.sendMiddleware + )) as JSONRPCNotification; + } + await this._enqueueTaskMessage(relatedTaskId, { type: 'notification', - message: jsonrpcNotification, + message: notificationToQueue, timestamp: Date.now() }); @@ -1358,9 +1422,19 @@ export abstract class Protocol this._onerror(error)); + if (this._options?.sendMiddleware && this._options.sendMiddleware.length > 0) { + this._applyMiddleware(jsonrpcNotification, this._options.sendMiddleware) + .then(transformedNotification => { + this._transport + ?.send(transformedNotification as JSONRPCNotification, options) + .catch(error => this._onerror(error)); + }) + .catch(error => this._onerror(error)); + } else { + this._transport?.send(jsonrpcNotification, options).catch(error => this._onerror(error)); + } }); // Return immediately. @@ -1386,7 +1460,16 @@ export abstract class Protocol 0) { + const transformedNotification = (await this._applyMiddleware( + jsonrpcNotification, + this._options.sendMiddleware + )) as JSONRPCNotification; + await this._transport.send(transformedNotification, options); + } else { + await this._transport.send(jsonrpcNotification, options); + } } /**