Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 133 additions & 50 deletions packages/core/src/shared/protocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@ import type { Transport, TransportSendOptions } from './transport.js';
*/
export type ProgressCallback = (progress: Progress) => void;

/**
* A JSON-RPC message that can be transformed by middleware.
*/
export type JSONRPCMessageLike = JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCResultResponse | JSONRPCErrorResponse;

/**
* Middleware function for transforming JSON-RPC messages.
* Can be sync (returns message) or async (returns Promise<message>).
*/
export type MessageMiddleware = (message: JSONRPCMessageLike) => JSONRPCMessageLike | Promise<JSONRPCMessageLike>;

/**
* Additional initialization options.
*/
Expand Down Expand Up @@ -102,6 +113,16 @@ export type ProtocolOptions = {
* appropriately (e.g., by failing the task, dropping messages, etc.).
*/
maxTaskQueueSize?: number;
/**
* Middleware functions to apply to outgoing messages before sending.
* Middleware is applied in order, with each function receiving the output of the previous.
*/
sendMiddleware?: MessageMiddleware[];
/**
* Middleware functions to apply to incoming messages after receiving.
* Middleware is applied in order, with each function receiving the output of the previous.
*/
receiveMiddleware?: MessageMiddleware[];
};

/**
Expand Down Expand Up @@ -603,6 +624,22 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
}
}

/**
* Applies a list of middleware functions to a message.
* Middleware is applied in order, with each function receiving the output of the previous.
*/
private async _applyMiddleware<T extends JSONRPCMessageLike>(message: T, middlewareList: MessageMiddleware[] | undefined): Promise<T> {
if (!middlewareList || middlewareList.length === 0) {
return message;
}

let result: JSONRPCMessageLike = message;
for (const middleware of middlewareList) {
result = await middleware(result);
}
return result as T;
}

/**
* Attaches to the given transport, starts it, and starts listening for messages.
*
Expand All @@ -625,14 +662,26 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
const _onmessage = this._transport?.onmessage;
this._transport.onmessage = (message, extra) => {
_onmessage?.(message, extra);
if (isJSONRPCResultResponse(message) || isJSONRPCErrorResponse(message)) {
this._onresponse(message);
} else if (isJSONRPCRequest(message)) {
this._onrequest(message, extra);
} else if (isJSONRPCNotification(message)) {
this._onnotification(message);

// Route the message, applying receive middleware if configured
const routeMessage = (msg: JSONRPCMessageLike) => {
if (isJSONRPCResultResponse(msg) || isJSONRPCErrorResponse(msg)) {
this._onresponse(msg);
} else if (isJSONRPCRequest(msg)) {
this._onrequest(msg, extra);
} else if (isJSONRPCNotification(msg)) {
this._onnotification(msg);
} else {
this._onerror(new Error(`Unknown message type: ${JSON.stringify(msg)}`));
}
};

if (this._options?.receiveMiddleware && this._options.receiveMiddleware.length > 0) {
this._applyMiddleware(message, this._options.receiveMiddleware)
.then(routeMessage)
.catch(error => this._onerror(new Error(`Receive middleware error: ${error}`)));
} else {
this._onerror(new Error(`Unknown message type: ${JSON.stringify(message)}`));
routeMessage(message);
}
};

Expand Down Expand Up @@ -1074,9 +1123,52 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
*
* Do not use this method to emit notifications! Use notification() instead.
*/
request<T extends AnySchema>(request: SendRequestT, resultSchema: T, options?: RequestOptions): Promise<SchemaOutput<T>> {
async request<T extends AnySchema>(request: SendRequestT, resultSchema: T, options?: RequestOptions): Promise<SchemaOutput<T>> {
const { relatedRequestId, resumptionToken, onresumptiontoken, task, relatedTask } = options ?? {};

// Build the JSON-RPC request
const messageId = this._requestMessageId++;
let jsonrpcRequest: JSONRPCRequest = {
...request,
jsonrpc: '2.0',
id: messageId
};

if (options?.onprogress) {
this._progressHandlers.set(messageId, options.onprogress);
jsonrpcRequest.params = {
...request.params,
_meta: {
...(request.params?._meta || {}),
progressToken: messageId
}
};
}

// Augment with task creation parameters if provided
if (task) {
jsonrpcRequest.params = {
...jsonrpcRequest.params,
task: task
};
}

// Augment with related task metadata if relatedTask is provided
if (relatedTask) {
jsonrpcRequest.params = {
...jsonrpcRequest.params,
_meta: {
...(jsonrpcRequest.params?._meta || {}),
[RELATED_TASK_META_KEY]: relatedTask
}
};
}

// Apply send middleware before sending (only await if there's middleware)
if (this._options?.sendMiddleware && this._options.sendMiddleware.length > 0) {
jsonrpcRequest = (await this._applyMiddleware(jsonrpcRequest, this._options.sendMiddleware)) as JSONRPCRequest;
}

// Send the request
return new Promise<SchemaOutput<T>>((resolve, reject) => {
const earlyReject = (error: unknown) => {
Expand Down Expand Up @@ -1104,43 +1196,6 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e

options?.signal?.throwIfAborted();

const messageId = this._requestMessageId++;
const jsonrpcRequest: JSONRPCRequest = {
...request,
jsonrpc: '2.0',
id: messageId
};

if (options?.onprogress) {
this._progressHandlers.set(messageId, options.onprogress);
jsonrpcRequest.params = {
...request.params,
_meta: {
...(request.params?._meta || {}),
progressToken: messageId
}
};
}

// Augment with task creation parameters if provided
if (task) {
jsonrpcRequest.params = {
...jsonrpcRequest.params,
task: task
};
}

// Augment with related task metadata if relatedTask is provided
if (relatedTask) {
jsonrpcRequest.params = {
...jsonrpcRequest.params,
_meta: {
...(jsonrpcRequest.params?._meta || {}),
[RELATED_TASK_META_KEY]: relatedTask
}
};
}

const cancel = (reason: unknown) => {
this._responseHandlers.delete(messageId);
this._progressHandlers.delete(messageId);
Expand Down Expand Up @@ -1224,7 +1279,7 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
// This prevents duplicate delivery for bidirectional transports
} else {
// No related task - send through transport normally
this._transport.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch(error => {
this._transport!.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch(error => {
this._cleanupTimeout(messageId);
reject(error);
});
Expand Down Expand Up @@ -1302,9 +1357,18 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
}
};

// Apply send middleware before queuing (only if there's middleware)
let notificationToQueue: JSONRPCNotification = jsonrpcNotification;
if (this._options?.sendMiddleware && this._options.sendMiddleware.length > 0) {
notificationToQueue = (await this._applyMiddleware(
jsonrpcNotification,
this._options.sendMiddleware
)) as JSONRPCNotification;
}

await this._enqueueTaskMessage(relatedTaskId, {
type: 'notification',
message: jsonrpcNotification,
message: notificationToQueue,
timestamp: Date.now()
});

Expand Down Expand Up @@ -1358,9 +1422,19 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
};
}

// Send the notification, but don't await it here to avoid blocking.
// Apply send middleware and send the notification.
// Handle potential errors with a .catch().
this._transport?.send(jsonrpcNotification, options).catch(error => this._onerror(error));
if (this._options?.sendMiddleware && this._options.sendMiddleware.length > 0) {
this._applyMiddleware(jsonrpcNotification, this._options.sendMiddleware)
.then(transformedNotification => {
this._transport
?.send(transformedNotification as JSONRPCNotification, options)
.catch(error => this._onerror(error));
})
.catch(error => this._onerror(error));
} else {
this._transport?.send(jsonrpcNotification, options).catch(error => this._onerror(error));
}
});

// Return immediately.
Expand All @@ -1386,7 +1460,16 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
};
}

await this._transport.send(jsonrpcNotification, options);
// Apply send middleware before sending (only if there's middleware)
if (this._options?.sendMiddleware && this._options.sendMiddleware.length > 0) {
const transformedNotification = (await this._applyMiddleware(
jsonrpcNotification,
this._options.sendMiddleware
)) as JSONRPCNotification;
await this._transport.send(transformedNotification, options);
} else {
await this._transport.send(jsonrpcNotification, options);
}
}

/**
Expand Down
Loading