Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/some-buses-run.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"agents": patch
---

Add generic type parameters to MCP handler functions for better type safety
21 changes: 11 additions & 10 deletions packages/agents/src/mcp/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,20 @@ export interface CreateMcpHandlerOptions extends WorkerTransportOptions {
transport?: WorkerTransport;
}

export function createMcpHandler(
export function createMcpHandler<Env = unknown, Props = unknown>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm I preferred this to default to Cloudflare.Env, why was it necessary to change it back to unknown? trouble with typescript?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went with unknown thinking it would be better for type safety and was trying to make it more flexible. So you didn’t need to use Cloudflare.Env and was thinking they’d be unknown until the user specifies. What is your preference? Looking back now I can see why you’d specify in a cloudflare repo.

server: McpServer | Server,
options: CreateMcpHandlerOptions = {}
): (
request: Request,
env: unknown,
ctx: ExecutionContext
env: Env,
ctx: ExecutionContext<Props>
) => Promise<Response> {
const route = options.route ?? "/mcp";

return async (
request: Request,
_env: unknown,
ctx: ExecutionContext
_env: Env,
ctx: ExecutionContext<Props>
): Promise<Response> => {
const url = new URL(request.url);
if (route && url.pathname !== route) {
Expand All @@ -60,9 +60,10 @@ export function createMcpHandler(
return options.authContext;
}

if (ctx.props && Object.keys(ctx.props).length > 0) {
const props = ctx.props as Record<string, unknown> | undefined;
if (props && Object.keys(props).length > 0) {
return {
props: ctx.props as Record<string, unknown>
props
};
}

Expand Down Expand Up @@ -109,13 +110,13 @@ let didWarnAboutExperimentalCreateMcpHandler = false;
/**
* @deprecated This has been renamed to createMcpHandler, and experimental_createMcpHandler will be removed in the next major version
*/
export function experimental_createMcpHandler(
export function experimental_createMcpHandler<Env = unknown, Props = unknown>(
server: McpServer | Server,
options: CreateMcpHandlerOptions = {}
): (
request: Request,
env: unknown,
ctx: ExecutionContext
env: Env,
ctx: ExecutionContext<Props>
) => Promise<Response> {
if (!didWarnAboutExperimentalCreateMcpHandler) {
didWarnAboutExperimentalCreateMcpHandler = true;
Expand Down
6 changes: 3 additions & 3 deletions packages/agents/src/mcp/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ export abstract class McpAgent<
/** Return a handler for the given path for this MCP.
* Defaults to Streamable HTTP transport.
*/
static serve(
static serve<Env = unknown, Props = unknown>(
path: string,
{
binding = "MCP_OBJECT",
Expand All @@ -371,11 +371,11 @@ export abstract class McpAgent<
}: ServeOptions = {}
) {
return {
async fetch<Env>(
async fetch(
this: void,
request: Request,
env: Env,
ctx: ExecutionContext
ctx: ExecutionContext<Props>
): Promise<Response> {
// Handle CORS preflight
const corsResponse = handleCORS(request, corsOptions);
Expand Down
29 changes: 16 additions & 13 deletions packages/agents/src/tests/mcp/handler.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ declare module "cloudflare:test" {
interface ProvidedEnv {}
}

const createTestExecutionContext = () =>
createExecutionContext() as ExecutionContext<Record<string, unknown>>;

/**
* Tests for createMcpHandler
* The handler primarily passes options to WorkerTransport and handles routing
Expand Down Expand Up @@ -42,7 +45,7 @@ describe("createMcpHandler", () => {
route: "/custom-mcp"
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();

// Request to non-matching route
const wrongRequest = new Request("http://example.com/mcp", {
Expand All @@ -63,7 +66,7 @@ describe("createMcpHandler", () => {
const server = createTestServer();
const handler = createMcpHandler(server);

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "OPTIONS"
});
Expand All @@ -85,7 +88,7 @@ describe("createMcpHandler", () => {
}
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "OPTIONS"
});
Expand All @@ -109,7 +112,7 @@ describe("createMcpHandler", () => {
route: "/mcp"
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "POST",
headers: {
Expand Down Expand Up @@ -149,7 +152,7 @@ describe("createMcpHandler", () => {
transport: customTransport
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "OPTIONS"
});
Expand All @@ -176,7 +179,7 @@ describe("createMcpHandler", () => {
transport: customTransport
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "OPTIONS"
});
Expand All @@ -203,7 +206,7 @@ describe("createMcpHandler", () => {
sessionIdGenerator: customSessionIdGenerator
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "POST",
headers: {
Expand Down Expand Up @@ -241,7 +244,7 @@ describe("createMcpHandler", () => {
}
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "POST",
headers: {
Expand Down Expand Up @@ -274,7 +277,7 @@ describe("createMcpHandler", () => {
enableJsonResponse: true
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "POST",
headers: {
Expand Down Expand Up @@ -311,7 +314,7 @@ describe("createMcpHandler", () => {
storage: mockStorage
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "POST",
headers: {
Expand Down Expand Up @@ -343,7 +346,7 @@ describe("createMcpHandler", () => {
corsOptions: { origin: "https://example.com" }
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/custom-route", {
method: "OPTIONS"
});
Expand Down Expand Up @@ -373,7 +376,7 @@ describe("createMcpHandler", () => {
transport: errorTransport
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "POST",
headers: {
Expand Down Expand Up @@ -418,7 +421,7 @@ describe("createMcpHandler", () => {
transport: errorTransport
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "POST",
headers: {
Expand Down
17 changes: 10 additions & 7 deletions packages/agents/src/tests/mcp/jurisdiction.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ declare module "cloudflare:test" {
}
}

const createTestExecutionContext = () =>
createExecutionContext() as ExecutionContext<Record<string, unknown>>;

/**
* Tests for jurisdiction option in McpAgent.serve()
*
Expand Down Expand Up @@ -61,7 +64,7 @@ describe("McpAgent jurisdiction option", () => {
transport: "streamable-http"
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "POST",
headers: {
Expand Down Expand Up @@ -95,7 +98,7 @@ describe("McpAgent jurisdiction option", () => {
transport: "sse"
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "GET"
});
Expand All @@ -115,7 +118,7 @@ describe("McpAgent jurisdiction option", () => {
transport: "streamable-http"
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "POST",
headers: {
Expand Down Expand Up @@ -147,7 +150,7 @@ describe("McpAgent jurisdiction option", () => {
transport: "streamable-http"
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "POST",
headers: {
Expand Down Expand Up @@ -183,7 +186,7 @@ describe("McpAgent jurisdiction option", () => {
transport: "streamable-http"
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();

// First request (initialization)
const initRequest = new Request("http://example.com/mcp", {
Expand Down Expand Up @@ -241,7 +244,7 @@ describe("McpAgent jurisdiction option", () => {
transport: "sse"
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();

// First, establish SSE connection
const sseRequest = new Request(
Expand Down Expand Up @@ -305,7 +308,7 @@ describe("McpAgent jurisdiction option", () => {
transport: "streamable-http"
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "POST",
headers: {
Expand Down
6 changes: 4 additions & 2 deletions packages/agents/src/tests/worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1419,12 +1419,14 @@ export default {
testValue: "123"
};

const typedCtx = ctx as ExecutionContext<Record<string, unknown>>;

if (url.pathname === "/sse" || url.pathname === "/sse/message") {
return TestMcpAgent.serveSSE("/sse").fetch(request, env, ctx);
return TestMcpAgent.serveSSE("/sse").fetch(request, env, typedCtx);
}

if (url.pathname === "/mcp") {
return TestMcpAgent.serve("/mcp").fetch(request, env, ctx);
return TestMcpAgent.serve("/mcp").fetch(request, env, typedCtx);
}

if (url.pathname === "/500") {
Expand Down
Loading