From 51239a48fab00d0fa2c2763b57fab71b6054e9b9 Mon Sep 17 00:00:00 2001 From: icebear0828 Date: Wed, 18 Mar 2026 01:39:26 -0500 Subject: [PATCH] fix: retry callTool on transient network errors + code quality fixes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Network retry: - Retry callTool once on transient network errors for idempotent (read) ops - Use RETRYABLE_TOOLS allowlist — unknown tools default to NOT retrying - Close old transport before reconnecting to prevent resource leaks - Preserve original error when retry also fails Type safety (remove all `as any` from src/ and test/): - Import CallToolResult/CompatibilityCallToolResult from MCP SDK - Add asCallToolResult() to handle legacy protocol format safely - Replace Record with Record - Replace `{} as any` Proxy target with concrete type in singleton.ts - Type all test mocks properly (MockFetch, ProxyContext helpers, etc.) Proxy improvements: - Replace Date.now() JSON-RPC ID with auto-incrementing counter - Await notifications/initialized before refreshTools (fix race) - Add OAuth support (accessToken + projectId in config and headers) - Add config validation: require either apiKey or accessToken - Extract buildProxyAuthHeaders() for consistent auth header construction Other: - Add @ai-sdk/google devDependency (fixes tsc error in model-helpers.ts) - Add missing protocolVersion in capture-tools.ts - Remove duplicate `dist` entry in .gitignore --- .gitignore | 1 - package-lock.json | 92 +++++++-- package.json | 3 +- packages/sdk/src/client.ts | 101 ++++++++-- packages/sdk/src/proxy/client.ts | 62 ++++-- packages/sdk/src/proxy/core.ts | 6 +- packages/sdk/src/singleton.ts | 2 +- packages/sdk/src/spec/proxy.ts | 15 +- packages/sdk/src/tools-adapter.ts | 2 +- packages/sdk/test/proxy.test.ts | 202 +++++++++++-------- packages/sdk/test/unit/client.test.ts | 148 +++++++++++++- packages/sdk/test/unit/sdk.test.ts | 2 +- packages/sdk/test/unit/tools-adapter.test.ts | 11 +- scripts/capture-tools.ts | 1 + 14 files changed, 487 insertions(+), 161 deletions(-) diff --git a/.gitignore b/.gitignore index 3984b1a..58e4be4 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,6 @@ coverage .DS_Store .env tsconfig.tsbuildinfo -dist .firebaserc .firebase .vscode diff --git a/package-lock.json b/package-lock.json index 8c767e1..fee31ae 100644 --- a/package-lock.json +++ b/package-lock.json @@ -11,6 +11,7 @@ "packages/*" ], "devDependencies": { + "@ai-sdk/google": "^3.0.43", "@types/bun": "^1.3.10", "@types/react": "^18.3.0", "firebase-tools": "^14.26.0", @@ -1651,7 +1652,6 @@ "os": [ "darwin" ], - "peer": true, "engines": { "node": ">=10" } @@ -1669,7 +1669,6 @@ "os": [ "darwin" ], - "peer": true, "engines": { "node": ">=10" } @@ -1687,7 +1686,6 @@ "os": [ "linux" ], - "peer": true, "engines": { "node": ">=10" } @@ -1705,7 +1703,6 @@ "os": [ "linux" ], - "peer": true, "engines": { "node": ">=10" } @@ -1723,7 +1720,6 @@ "os": [ "linux" ], - "peer": true, "engines": { "node": ">=10" } @@ -1741,7 +1737,6 @@ "os": [ "linux" ], - "peer": true, "engines": { "node": ">=10" } @@ -1759,7 +1754,6 @@ "os": [ "linux" ], - "peer": true, "engines": { "node": ">=10" } @@ -1777,7 +1771,6 @@ "os": [ "win32" ], - "peer": true, "engines": { "node": ">=10" } @@ -1795,7 +1788,6 @@ "os": [ "win32" ], - "peer": true, "engines": { "node": ">=10" } @@ -1813,7 +1805,6 @@ "os": [ "win32" ], - "peer": true, "engines": { "node": ">=10" } @@ -4150,13 +4141,6 @@ "bare-events": "^2.7.0" } }, - "node_modules/eventsource": { - "version": "2.0.2", - "license": "MIT", - "engines": { - "node": ">=12.0.0" - } - }, "node_modules/eventsource-parser": { "version": "3.0.6", "license": "MIT", @@ -9012,6 +8996,20 @@ "turbo-windows-arm64": "2.8.0" } }, + "node_modules/turbo-darwin-64": { + "version": "2.8.0", + "resolved": "https://registry.npmjs.org/turbo-darwin-64/-/turbo-darwin-64-2.8.0.tgz", + "integrity": "sha512-N7f4PYqz25yk8c5kituk09bJ89tE4wPPqKXgYccT6nbEQnGnrdvlyCHLyqViNObTgjjrddqjb1hmDkv7VcxE0g==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ] + }, "node_modules/turbo-darwin-arm64": { "version": "2.8.0", "cpu": [ @@ -9024,6 +9022,62 @@ "darwin" ] }, + "node_modules/turbo-linux-64": { + "version": "2.8.0", + "resolved": "https://registry.npmjs.org/turbo-linux-64/-/turbo-linux-64-2.8.0.tgz", + "integrity": "sha512-ILR45zviYae3icf4cmUISdj8X17ybNcMh3Ms4cRdJF5sS50qDDTv8qeWqO/lPeHsu6r43gVWDofbDZYVuXYL7Q==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/turbo-linux-arm64": { + "version": "2.8.0", + "resolved": "https://registry.npmjs.org/turbo-linux-arm64/-/turbo-linux-arm64-2.8.0.tgz", + "integrity": "sha512-z9pUa8ENFuHmadPfjEYMRWlXO82t1F/XBDx2XTg+cWWRZHf85FnEB6D4ForJn/GoKEEvwdPhFLzvvhOssom2ug==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/turbo-windows-64": { + "version": "2.8.0", + "resolved": "https://registry.npmjs.org/turbo-windows-64/-/turbo-windows-64-2.8.0.tgz", + "integrity": "sha512-J6juRSRjmSErEqJCv7nVIq2DgZ2NHXqyeV8NQTFSyIvrThKiWe7FDOO6oYpuR06+C1NW82aoN4qQt4/gYvz25w==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/turbo-windows-arm64": { + "version": "2.8.0", + "resolved": "https://registry.npmjs.org/turbo-windows-arm64/-/turbo-windows-arm64-2.8.0.tgz", + "integrity": "sha512-qarBZvCu6uka35739TS+y/3CBU3zScrVAfohAkKHG+So+93Wn+5tKArs8HrO2fuTaGou8fMIeTV7V5NgzCVkSQ==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ] + }, "node_modules/type-detect": { "version": "4.1.0", "dev": true, @@ -9873,15 +9927,13 @@ }, "packages/sdk": { "name": "@google/stitch-sdk", - "version": "0.0.1", + "version": "0.0.3", "license": "Apache-2.0", "dependencies": { "@modelcontextprotocol/sdk": "^1.23.0", - "eventsource": "^2.0.2", "zod": "^4.3.5" }, "devDependencies": { - "@ai-sdk/google": "^3.0.43", "@swc/core": "^1.15.18", "@types/node": "^20.0.0", "ai": "^6.0.116", diff --git a/package.json b/package.json index d1c9d5a..0f2d460 100644 --- a/package.json +++ b/package.json @@ -28,6 +28,7 @@ "validate:release": "bun scripts/validate-release.ts" }, "devDependencies": { + "@ai-sdk/google": "^3.0.43", "@types/bun": "^1.3.10", "@types/react": "^18.3.0", "firebase-tools": "^14.26.0", @@ -39,4 +40,4 @@ "typescript": "^5.5.0" }, "packageManager": "bun@1.3.1" -} \ No newline at end of file +} diff --git a/packages/sdk/src/client.ts b/packages/sdk/src/client.ts index e40e5d1..4cb3d26 100644 --- a/packages/sdk/src/client.ts +++ b/packages/sdk/src/client.ts @@ -14,6 +14,7 @@ import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; +import type { CallToolResult, CompatibilityCallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { StitchConfigSchema, StitchConfig, @@ -77,10 +78,21 @@ export class StitchToolClient implements StitchToolClientSpec { return headers; } - private parseToolResponse(result: any, name: string): T { + private asCallToolResult(result: CompatibilityCallToolResult): CallToolResult { + if (Array.isArray((result as CallToolResult).content)) { + return result as CallToolResult; + } + // Legacy protocol format — wrap toolResult as text content + const legacy = result as { toolResult: unknown }; + return { + content: [{ type: "text", text: JSON.stringify(legacy.toolResult ?? null) }], + }; + } + + private parseToolResponse(result: CallToolResult, name: string): T { if (result.isError) { - const errorText = (result.content as any[]) - .map((c: any) => (c.type === "text" ? c.text : "")) + const errorText = result.content + .map((c) => (c.type === "text" ? c.text : "")) .join(""); let code: StitchErrorCode = "UNKNOWN_ERROR"; @@ -118,12 +130,9 @@ export class StitchToolClient implements StitchToolClientSpec { } // Stitch specific parsing: Check structuredContent first, then JSON in text - const anyResult = result as any; - if (anyResult.structuredContent) return anyResult.structuredContent as T; + if (result.structuredContent) return result.structuredContent as T; - const textContent = (result.content as any[]).find( - (c: any) => c.type === "text", - ); + const textContent = result.content.find((c) => c.type === "text"); if (textContent && textContent.type === "text") { try { return JSON.parse(textContent.text) as T; @@ -132,7 +141,7 @@ export class StitchToolClient implements StitchToolClientSpec { } } - return anyResult as T; + return result as unknown as T; } async connect() { @@ -148,6 +157,11 @@ export class StitchToolClient implements StitchToolClientSpec { } private async doConnect() { + // Close existing transport before creating a new one to prevent resource leaks + if (this.transport) { + await this.transport.close().catch(() => {}); + } + // Create transport with auth headers injected per-instance (no global fetch mutation) this.transport = new StreamableHTTPClientTransport( new URL(this.config.baseUrl), @@ -167,19 +181,74 @@ export class StitchToolClient implements StitchToolClientSpec { this.isConnected = true; } + /** + * Tools that are safe to retry on network errors (idempotent read operations). + * Unknown tools default to NOT retrying — safer than the reverse. + */ + private static readonly RETRYABLE_TOOLS = new Set([ + "list_projects", + "get_project", + "list_screens", + "get_screen", + ]); + + /** + * Check if an error is a transient network failure (not an application error). + * Uses message substring matching — fragile across Node/Bun versions, + * but no reliable error codes exist for these transport-level failures. + */ + private isNetworkError(error: unknown): boolean { + if (error instanceof StitchError) return false; + const msg = + error instanceof Error ? error.message.toLowerCase() : String(error); + return ( + msg.includes("fetch failed") || + msg.includes("econnrefused") || + msg.includes("econnreset") || + msg.includes("etimedout") || + msg.includes("socket hang up") || + msg.includes("other side closed") + ); + } + /** * Generic tool caller with type support and error parsing. + * Retries once on transient network errors for idempotent (read) operations. + * Non-idempotent tools (generate, edit, create) are not retried. */ - async callTool(name: string, args: Record): Promise { + async callTool(name: string, args: Record): Promise { if (!this.isConnected) await this.connect(); - const result = await this.client.callTool( - { name, arguments: args }, - undefined, - { timeout: this.config.timeout }, - ); + try { + const result = await this.client.callTool( + { name, arguments: args }, + undefined, + { timeout: this.config.timeout }, + ); + return this.parseToolResponse(this.asCallToolResult(result), name); + } catch (error) { + if ( + !this.isNetworkError(error) || + !StitchToolClient.RETRYABLE_TOOLS.has(name) + ) { + throw error; + } + + // Reconnect and retry once for idempotent operations + this.isConnected = false; + await this.connect(); - return this.parseToolResponse(result, name); + try { + const result = await this.client.callTool( + { name, arguments: args }, + undefined, + { timeout: this.config.timeout }, + ); + return this.parseToolResponse(this.asCallToolResult(result), name); + } catch (_retryError: unknown) { + throw error; // throw the original error, not the retry error + } + } } async listTools() { diff --git a/packages/sdk/src/proxy/client.ts b/packages/sdk/src/proxy/client.ts index 70ab0ff..9ec8d4b 100644 --- a/packages/sdk/src/proxy/client.ts +++ b/packages/sdk/src/proxy/client.ts @@ -23,6 +23,29 @@ export interface ProxyContext { remoteTools: Tool[]; } +let nextRequestId = 1; + +/** + * Build auth headers based on proxy config (API key or OAuth). + */ +function buildProxyAuthHeaders(config: StitchProxyConfig): Record { + const headers: Record = { + 'Content-Type': 'application/json', + Accept: 'application/json', + }; + + if (config.apiKey) { + headers['X-Goog-Api-Key'] = config.apiKey; + } else if (config.accessToken) { + headers['Authorization'] = `Bearer ${config.accessToken}`; + if (config.projectId) { + headers['X-Goog-User-Project'] = config.projectId; + } + } + + return headers; +} + /** * Forward a JSON-RPC request to Stitch. */ @@ -35,22 +58,19 @@ export async function forwardToStitch( jsonrpc: '2.0', method, params: params ?? {}, - id: Date.now(), + id: nextRequestId++, }; let response: Response; try { response = await fetch(config.url, { method: 'POST', - headers: { - 'Content-Type': 'application/json', - Accept: 'application/json', - 'X-Goog-Api-Key': config.apiKey!, - }, + headers: buildProxyAuthHeaders(config), body: JSON.stringify(request), }); - } catch (err: any) { - throw new Error(`Network failure connecting to Stitch API: ${err.message}`); + } catch (err: unknown) { + const msg = err instanceof Error ? err.message : String(err); + throw new Error(`Network failure connecting to Stitch API: ${msg}`); } if (!response.ok) { @@ -86,21 +106,19 @@ export async function initializeStitchConnection( }, }); - // Send initialized notification (fire and forget) - fetch(ctx.config.url, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - Accept: 'application/json', - 'X-Goog-Api-Key': ctx.config.apiKey!, - }, - body: JSON.stringify({ - jsonrpc: '2.0', - method: 'notifications/initialized', - }), - }).catch((err) => { + // Send initialized notification and await it before proceeding + try { + await fetch(ctx.config.url, { + method: 'POST', + headers: buildProxyAuthHeaders(ctx.config), + body: JSON.stringify({ + jsonrpc: '2.0', + method: 'notifications/initialized', + }), + }); + } catch (err) { console.error('[stitch-proxy] Failed to send initialized notification:', err); - }); + } await refreshTools(ctx); console.error( diff --git a/packages/sdk/src/proxy/core.ts b/packages/sdk/src/proxy/core.ts index 72633be..a47fed7 100644 --- a/packages/sdk/src/proxy/core.ts +++ b/packages/sdk/src/proxy/core.ts @@ -32,6 +32,8 @@ export class StitchProxy implements StitchProxySpec { constructor(inputConfig?: Partial) { const rawConfig = { apiKey: inputConfig?.apiKey || process.env.STITCH_API_KEY, + accessToken: inputConfig?.accessToken || process.env.STITCH_ACCESS_TOKEN, + projectId: inputConfig?.projectId || process.env.GOOGLE_CLOUD_PROJECT, url: inputConfig?.url || process.env.STITCH_MCP_URL, name: inputConfig?.name, version: inputConfig?.version, @@ -40,10 +42,6 @@ export class StitchProxy implements StitchProxySpec { // Validate config this.config = StitchProxyConfigSchema.parse(rawConfig); - if (!this.config.apiKey) { - throw new Error('StitchProxy requires an API key (STITCH_API_KEY)'); - } - this.server = new McpServer( { name: this.config.name, diff --git a/packages/sdk/src/singleton.ts b/packages/sdk/src/singleton.ts index f292378..51ac8b4 100644 --- a/packages/sdk/src/singleton.ts +++ b/packages/sdk/src/singleton.ts @@ -75,7 +75,7 @@ const CLIENT_METHODS = new Set(["listTools", "callTool", "close"]); */ export const stitch = new Proxy< Stitch & Pick ->({} as any, { +>({} as Stitch & Pick, { get(_target, prop: string | symbol) { // Client methods → delegate to StitchToolClient if (typeof prop === "string" && CLIENT_METHODS.has(prop)) { diff --git a/packages/sdk/src/spec/proxy.ts b/packages/sdk/src/spec/proxy.ts index 7c1dbe3..d5b43b6 100644 --- a/packages/sdk/src/spec/proxy.ts +++ b/packages/sdk/src/spec/proxy.ts @@ -23,6 +23,12 @@ export const StitchProxyConfigSchema = z.object({ /** API key for Stitch authentication. Falls back to STITCH_API_KEY. */ apiKey: z.string().optional(), + /** OAuth access token for user-authenticated requests. Falls back to STITCH_ACCESS_TOKEN. */ + accessToken: z.string().optional(), + + /** Google Cloud project ID. Required for OAuth, optional for API Key. Falls back to GOOGLE_CLOUD_PROJECT. */ + projectId: z.string().optional(), + /** Target Stitch MCP URL. Default: https://stitch.googleapis.com/mcp */ url: z.string().default(DEFAULT_STITCH_API_URL), @@ -34,7 +40,14 @@ export const StitchProxyConfigSchema = z.object({ /** Protocol version to use for Stitch JSON-RPC connection. Default: '2024-11-05' */ protocolVersion: z.string().default('2024-11-05'), -}); +}).refine( + (data) => { + const hasApiKey = !!data.apiKey; + const hasOAuth = !!data.accessToken && !!data.projectId; + return hasApiKey || hasOAuth; + }, + { message: "Provide either 'apiKey' OR ('accessToken' + 'projectId') for authentication." } +); export type StitchProxyConfig = z.infer; diff --git a/packages/sdk/src/tools-adapter.ts b/packages/sdk/src/tools-adapter.ts index 46eb6f5..515c3ec 100644 --- a/packages/sdk/src/tools-adapter.ts +++ b/packages/sdk/src/tools-adapter.ts @@ -70,7 +70,7 @@ export function stitchTools(options?: { get jsonSchema() { return t.inputSchema; }, }, execute: async (args: unknown) => - client.callTool(t.name, args as Record), + client.callTool(t.name, args as Record), } as unknown as Tool, ]) ); diff --git a/packages/sdk/test/proxy.test.ts b/packages/sdk/test/proxy.test.ts index 7114ec3..64054d0 100644 --- a/packages/sdk/test/proxy.test.ts +++ b/packages/sdk/test/proxy.test.ts @@ -14,22 +14,44 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import { StitchProxy } from '../src/proxy/index.js'; -import { Transport } from "@modelcontextprotocol/sdk/shared/transport.js"; -import { forwardToStitch, initializeStitchConnection } from '../src/proxy/client.js'; +import type { Transport } from "@modelcontextprotocol/sdk/shared/transport.js"; +import type { Server } from "@modelcontextprotocol/sdk/server/index.js"; +import { forwardToStitch, initializeStitchConnection, type ProxyContext } from '../src/proxy/client.js'; import { registerListToolsHandler } from '../src/proxy/handlers/listTools.js'; import { registerCallToolHandler } from '../src/proxy/handlers/callTool.js'; import { ListToolsRequestSchema, CallToolRequestSchema } from '@modelcontextprotocol/sdk/types.js'; - - +import type { StitchProxyConfig } from '../src/spec/proxy.js'; // Mock fetch const globalFetch = global.fetch; +type MockFetch = ReturnType & typeof fetch; + +function makeProxyConfig(overrides?: Partial): StitchProxyConfig { + return { url: 'http://test', apiKey: 'test-key', name: 'stitch-proxy', version: '1.0.0', protocolVersion: '2024-11-05', ...overrides }; +} + +function makeProxyContext(overrides?: Partial): ProxyContext { + return { config: makeProxyConfig(), remoteTools: [], ...overrides }; +} + +type HandlerMap = Map Promise>; + +function makeMockServer() { + const handlers: HandlerMap = new Map(); + return { + handlers, + setRequestHandler(schema: unknown, handler: (...args: unknown[]) => Promise) { + handlers.set(schema, handler); + }, + }; +} + describe('StitchProxy', () => { - let mockFetch: any; + let mockFetch: MockFetch; beforeEach(() => { - mockFetch = vi.fn(); + mockFetch = vi.fn() as MockFetch; global.fetch = mockFetch; }); @@ -43,9 +65,22 @@ describe('StitchProxy', () => { expect(proxy).toBeDefined(); }); - it('should throw if no API key is provided', () => { + it('should throw if no auth is provided', () => { delete process.env.STITCH_API_KEY; - expect(() => new StitchProxy({})).toThrow("StitchProxy requires an API key"); + delete process.env.STITCH_ACCESS_TOKEN; + expect(() => new StitchProxy({})).toThrow(/apiKey.*OR.*accessToken.*projectId/i); + }); + + it('should throw if accessToken provided without projectId', () => { + delete process.env.STITCH_API_KEY; + delete process.env.GOOGLE_CLOUD_PROJECT; + expect(() => new StitchProxy({ accessToken: 'token-only' })).toThrow(/apiKey.*OR.*accessToken.*projectId/i); + }); + + it('should initialize with OAuth credentials', () => { + delete process.env.STITCH_API_KEY; + const proxy = new StitchProxy({ accessToken: 'ya29.token', projectId: 'my-project' }); + expect(proxy).toBeDefined(); }); it('should connect to stitch and fetch tools on start', async () => { @@ -60,36 +95,31 @@ describe('StitchProxy', () => { mockFetch.mockResolvedValueOnce({ ok: true, json: async () => ({}) - } as Response); // notifications/initialized (fire and forget, might not be awaited immediately but mocked anyway if called) + } as Response); // notifications/initialized mockFetch.mockResolvedValueOnce({ ok: true, json: async () => ({ result: { tools: [{ name: 'test-tool' }] } }) } as Response); // tools/list - const mockTransport = { + const mockTransport: Pick & Partial = { start: vi.fn().mockResolvedValue(undefined), close: vi.fn().mockResolvedValue(undefined), - onmessage: undefined, - onclose: undefined, - onerror: undefined, - send: vi.fn().mockResolvedValue(undefined) - } as unknown as Transport; + send: vi.fn().mockResolvedValue(undefined), + }; - await proxy.start(mockTransport); + await proxy.start(mockTransport as Transport); - // Expect 3 calls: initialize, notifications/initialized (which might complete quickly), and tools/list - // Since notifications/initialized is fire-and-forget but we mock fetch, it counts if called. expect(mockFetch).toHaveBeenCalledTimes(3); expect(mockTransport.start).toHaveBeenCalled(); }); }); describe('Proxy Client Error Handling', () => { - let mockFetch: any; + let mockFetch: MockFetch; beforeEach(() => { - mockFetch = vi.fn(); + mockFetch = vi.fn() as MockFetch; global.fetch = mockFetch; vi.spyOn(console, 'error').mockImplementation(() => {}); }); @@ -99,6 +129,47 @@ describe('Proxy Client Error Handling', () => { vi.clearAllMocks(); }); + it('forwardToStitch should send X-Goog-Api-Key header for API key auth', async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ result: {} }) + } as Response); + + await forwardToStitch(makeProxyConfig({ apiKey: 'my-key' }), 'test'); + + const headers = mockFetch.mock.calls[0][1]?.headers as Record; + expect(headers['X-Goog-Api-Key']).toBe('my-key'); + expect(headers['Authorization']).toBeUndefined(); + }); + + it('forwardToStitch should send Bearer header for OAuth auth', async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ result: {} }) + } as Response); + + await forwardToStitch(makeProxyConfig({ apiKey: undefined, accessToken: 'ya29.tok', projectId: 'proj-1' }), 'test'); + + const headers = mockFetch.mock.calls[0][1]?.headers as Record; + expect(headers['Authorization']).toBe('Bearer ya29.tok'); + expect(headers['X-Goog-User-Project']).toBe('proj-1'); + expect(headers['X-Goog-Api-Key']).toBeUndefined(); + }); + + it('forwardToStitch should use auto-incrementing IDs', async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ result: {} }) + } as Response); + + await forwardToStitch(makeProxyConfig(), 'method1'); + await forwardToStitch(makeProxyConfig(), 'method2'); + + const body1 = JSON.parse(mockFetch.mock.calls[0][1]?.body as string); + const body2 = JSON.parse(mockFetch.mock.calls[1][1]?.body as string); + expect(body2.id).toBeGreaterThan(body1.id); + }); + it('forwardToStitch should throw Stitch API error on non-ok response', async () => { mockFetch.mockResolvedValue({ ok: false, @@ -106,7 +177,7 @@ describe('Proxy Client Error Handling', () => { text: async () => 'Internal Server Error' } as Response); - await expect(forwardToStitch({ url: 'http://test', apiKey: 'test-key' } as any, 'testMethod')).rejects.toThrow('Stitch API error (500): Internal Server Error'); + await expect(forwardToStitch(makeProxyConfig(), 'testMethod')).rejects.toThrow('Stitch API error (500): Internal Server Error'); }); it('forwardToStitch should throw Stitch RPC error on JSON-RPC error payload', async () => { @@ -115,7 +186,7 @@ describe('Proxy Client Error Handling', () => { json: async () => ({ error: { message: 'Method not found' } }) } as Response); - await expect(forwardToStitch({ url: 'http://test', apiKey: 'test-key' } as any, 'testMethod')).rejects.toThrow('Stitch RPC error: Method not found'); + await expect(forwardToStitch(makeProxyConfig(), 'testMethod')).rejects.toThrow('Stitch RPC error: Method not found'); }); it('initializeStitchConnection should catch and log rejected fetch on notifications/initialized', async () => { @@ -134,16 +205,9 @@ describe('Proxy Client Error Handling', () => { json: async () => ({ result: { tools: [] } }) } as Response); - const ctx = { - config: { url: 'http://test', apiKey: 'test-key', name: 'test', version: '1.0' }, - remoteTools: [] - } as any; - + const ctx = makeProxyContext(); await expect(initializeStitchConnection(ctx)).resolves.not.toThrow(); - // allow the fire-and-forget promise to settle - await new Promise(resolve => setTimeout(resolve, 0)); - expect(console.error).toHaveBeenCalledWith( '[stitch-proxy] Failed to send initialized notification:', expect.any(Error) @@ -152,21 +216,14 @@ describe('Proxy Client Error Handling', () => { }); describe('Proxy Handlers', () => { - let mockFetch: any; - let mockServer: any; + let mockFetch: MockFetch; + let mockServer: ReturnType; beforeEach(() => { - mockFetch = vi.fn(); + mockFetch = vi.fn() as MockFetch; global.fetch = mockFetch; vi.spyOn(console, 'error').mockImplementation(() => {}); - - // Mock for Server.setRequestHandler - mockServer = { - handlers: new Map(), - setRequestHandler(schema: any, handler: any) { - this.handlers.set(schema, handler); - } - }; + mockServer = makeMockServer(); }); afterEach(() => { @@ -175,14 +232,10 @@ describe('Proxy Handlers', () => { }); it('registerListToolsHandler should invoke refreshTools and return cached tools', async () => { - const ctx = { - config: { url: 'http://test', apiKey: 'test-key' }, - remoteTools: [] - } as any; - - registerListToolsHandler(mockServer as any, ctx); + const ctx = makeProxyContext(); + registerListToolsHandler(mockServer as unknown as Server, ctx); - const handler = mockServer.handlers.get(ListToolsRequestSchema); + const handler = mockServer.handlers.get(ListToolsRequestSchema)!; expect(handler).toBeDefined(); mockFetch.mockResolvedValueOnce({ @@ -190,29 +243,22 @@ describe('Proxy Handlers', () => { json: async () => ({ result: { tools: [{ name: 'refreshed-tool' }] } }) } as Response); - const result = await handler({} as any, {} as any); - + const result = await handler({}, {}); expect(result).toEqual({ tools: [{ name: 'refreshed-tool' }] }); expect(ctx.remoteTools).toEqual([{ name: 'refreshed-tool' }]); }); it('registerListToolsHandler should handle fetch error gracefully', async () => { - const ctx = { - config: { url: 'http://test', apiKey: 'test-key' }, - remoteTools: [{ name: 'existing-tool' }] - } as any; - - registerListToolsHandler(mockServer as any, ctx); + const ctx = makeProxyContext({ remoteTools: [{ name: 'existing-tool', inputSchema: { type: 'object' as const } }] }); + registerListToolsHandler(mockServer as unknown as Server, ctx); - const handler = mockServer.handlers.get(ListToolsRequestSchema); + const handler = mockServer.handlers.get(ListToolsRequestSchema)!; expect(handler).toBeDefined(); mockFetch.mockRejectedValueOnce(new Error('Network failure')); - const result = await handler({} as any, {} as any); - - // Should return existing tools if refresh fails - expect(result).toEqual({ tools: [{ name: 'existing-tool' }] }); + const result = await handler({}, {}); + expect(result).toEqual({ tools: [{ name: 'existing-tool', inputSchema: { type: 'object' } }] }); expect(console.error).toHaveBeenCalledWith( '[stitch-proxy] Failed to refresh tools:', expect.any(Error) @@ -220,14 +266,10 @@ describe('Proxy Handlers', () => { }); it('registerCallToolHandler should invoke forwardToStitch and return result', async () => { - const ctx = { - config: { url: 'http://test', apiKey: 'test-key' }, - remoteTools: [] - } as any; + const ctx = makeProxyContext(); + registerCallToolHandler(mockServer as unknown as Server, ctx); - registerCallToolHandler(mockServer as any, ctx); - - const handler = mockServer.handlers.get(CallToolRequestSchema); + const handler = mockServer.handlers.get(CallToolRequestSchema)!; expect(handler).toBeDefined(); mockFetch.mockResolvedValueOnce({ @@ -235,34 +277,24 @@ describe('Proxy Handlers', () => { json: async () => ({ result: { content: [{ type: 'text', text: 'success' }] } }) } as Response); - const request = { - params: { name: 'test_tool', arguments: { arg1: 'value1' } } - }; - - const result = await handler(request as any, {} as any); + const request = { params: { name: 'test_tool', arguments: { arg1: 'value1' } } }; + const result = await handler(request, {}); expect(result).toEqual({ content: [{ type: 'text', text: 'success' }] }); expect(console.error).toHaveBeenCalledWith('[stitch-proxy] Calling tool: test_tool'); }); it('registerCallToolHandler should return isError: true on failure', async () => { - const ctx = { - config: { url: 'http://test', apiKey: 'test-key' }, - remoteTools: [] - } as any; + const ctx = makeProxyContext(); + registerCallToolHandler(mockServer as unknown as Server, ctx); - registerCallToolHandler(mockServer as any, ctx); - - const handler = mockServer.handlers.get(CallToolRequestSchema); + const handler = mockServer.handlers.get(CallToolRequestSchema)!; expect(handler).toBeDefined(); mockFetch.mockRejectedValueOnce(new Error('RPC failed')); - const request = { - params: { name: 'test_tool', arguments: { arg1: 'value1' } } - }; - - const result = await handler(request as any, {} as any); + const request = { params: { name: 'test_tool', arguments: { arg1: 'value1' } } }; + const result = await handler(request, {}) as { isError: boolean; content: Array<{ type: string; text: string }> }; expect(result.isError).toBe(true); expect(result.content[0].type).toBe('text'); diff --git a/packages/sdk/test/unit/client.test.ts b/packages/sdk/test/unit/client.test.ts index fa34b0f..1e25a5c 100644 --- a/packages/sdk/test/unit/client.test.ts +++ b/packages/sdk/test/unit/client.test.ts @@ -14,6 +14,7 @@ import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; import { StitchToolClient } from "../../src/client.js"; +import type { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; import { ZodError } from "zod"; // Mock child_process for gcloud calls @@ -32,9 +33,7 @@ describe("StitchToolClient", () => { }); afterEach(() => { - // Restore original state globalThis.fetch = originalFetch; - delete (globalThis.fetch as any).__stitchPatched; process.env = originalEnv; }); @@ -221,6 +220,24 @@ describe("StitchToolClient", () => { expect(result).toEqual({ id: "123" }); }); + it("should handle legacy protocol format (toolResult instead of content)", async () => { + const client = createConnectedClient(); + client["client"].callTool = vi.fn().mockResolvedValue({ + toolResult: { projects: ["p1"] }, + }); + const result = await client.callTool("list_projects", {}); + expect(result).toEqual({ projects: ["p1"] }); + }); + + it("should handle legacy protocol format with undefined toolResult", async () => { + const client = createConnectedClient(); + client["client"].callTool = vi.fn().mockResolvedValue({ + toolResult: undefined, + }); + const result = await client.callTool("some_tool", {}); + expect(result).toBeNull(); + }); + it("should return raw text when JSON parse fails", async () => { const client = createConnectedClient(); client["client"].callTool = vi.fn().mockResolvedValue({ @@ -262,4 +279,131 @@ describe("StitchToolClient", () => { expect(connectCount).toBe(1); }); }); + + // ─── Cycle 5: network retry on transient failure ──────────── + describe("network retry", () => { + /** Helper: create a client with mocked doConnect for retry tests. */ + function createRetryableClient() { + const client = new StitchToolClient({ apiKey: "k" }); + client["isConnected"] = true; + client["doConnect"] = vi.fn(async () => { + client["isConnected"] = true; + }); + return client; + } + + it("should retry once after a network error on an idempotent tool", async () => { + const client = createRetryableClient(); + + let callCount = 0; + client["client"].callTool = vi.fn().mockImplementation(async () => { + callCount++; + if (callCount === 1) { + throw new TypeError("fetch failed"); + } + return { + isError: false, + content: [{ type: "text", text: '{"ok":true}' }], + }; + }); + + const result = await client.callTool("list_projects", {}); + expect(result).toEqual({ ok: true }); + expect(callCount).toBe(2); + expect(client["doConnect"]).toHaveBeenCalledTimes(1); + }); + + it("should not retry on StitchError (application-level error)", async () => { + const client = createRetryableClient(); + + client["client"].callTool = vi.fn().mockResolvedValue({ + isError: true, + content: [{ type: "text", text: "project not found" }], + }); + + await expect(client.callTool("list_projects", {})).rejects.toMatchObject({ + code: "NOT_FOUND", + }); + expect(client["client"].callTool).toHaveBeenCalledTimes(1); + }); + + it.each([ + "generate_screen_from_text", + "edit_screens", + "generate_variants", + "create_project", + ])("should not retry non-idempotent tool: %s", async (toolName) => { + const client = createRetryableClient(); + + client["client"].callTool = vi.fn().mockImplementation(async () => { + throw new TypeError("fetch failed"); + }); + + await expect(client.callTool(toolName, {})).rejects.toThrow( + "fetch failed", + ); + expect(client["client"].callTool).toHaveBeenCalledTimes(1); + }); + + it("should not retry unknown tools (safe default)", async () => { + const client = createRetryableClient(); + + client["client"].callTool = vi.fn().mockImplementation(async () => { + throw new TypeError("fetch failed"); + }); + + await expect(client.callTool("some_future_tool", {})).rejects.toThrow( + "fetch failed", + ); + expect(client["client"].callTool).toHaveBeenCalledTimes(1); + }); + + it("should close old transport before reconnecting", async () => { + const client = new StitchToolClient({ apiKey: "k" }); + client["isConnected"] = true; + + // Plant a fake old transport with a close spy + const oldTransportClose = vi + .fn<[], Promise>() + .mockResolvedValue(undefined); + client["transport"] = { close: oldTransportClose } as Pick< + StreamableHTTPClientTransport, + "close" + > as StreamableHTTPClientTransport; + + // Let real doConnect run so we test actual close logic, + // but stub client.connect to skip the network round-trip + client["client"].connect = vi.fn().mockResolvedValue(undefined); + + let callCount = 0; + client["client"].callTool = vi.fn().mockImplementation(async () => { + callCount++; + if (callCount === 1) throw new TypeError("fetch failed"); + return { + isError: false, + content: [{ type: "text", text: '{"ok":true}' }], + }; + }); + + await client.callTool("get_screen", {}); + expect(oldTransportClose).toHaveBeenCalledTimes(1); + }); + + it("should throw original error when retry also fails", async () => { + const client = createRetryableClient(); + const originalError = new TypeError("fetch failed: original"); + + let callCount = 0; + client["client"].callTool = vi.fn().mockImplementation(async () => { + callCount++; + if (callCount === 1) throw originalError; + throw new TypeError("fetch failed: retry"); + }); + + await expect(client.callTool("get_screen", {})).rejects.toBe( + originalError, + ); + expect(client["client"].callTool).toHaveBeenCalledTimes(2); + }); + }); }); diff --git a/packages/sdk/test/unit/sdk.test.ts b/packages/sdk/test/unit/sdk.test.ts index 4c50fab..e9c3373 100644 --- a/packages/sdk/test/unit/sdk.test.ts +++ b/packages/sdk/test/unit/sdk.test.ts @@ -169,7 +169,7 @@ describe("SDK Unit Tests", () => { describe("Stitch Class (Identity Map)", () => { it("should not have a getProject method — use project(id) instead", () => { const sdk = new Stitch(mockClient); - expect(typeof (sdk as any).getProject).toBe("undefined"); + expect(typeof (sdk as unknown as Record).getProject).toBe("undefined"); }); it("createProject should call create_project and return a Project", async () => { diff --git a/packages/sdk/test/unit/tools-adapter.test.ts b/packages/sdk/test/unit/tools-adapter.test.ts index f2e3c24..efdc8c6 100644 --- a/packages/sdk/test/unit/tools-adapter.test.ts +++ b/packages/sdk/test/unit/tools-adapter.test.ts @@ -48,19 +48,18 @@ describe("stitchTools()", () => { const tools = stitchTools(); for (const [, def] of Object.entries(tools)) { - const d = def as any; - expect(typeof d.description).toBe("string"); - expect(d.inputSchema).toBeDefined(); - expect(typeof d.execute).toBe("function"); + expect(typeof def.description).toBe("string"); + expect("inputSchema" in def).toBe(true); + expect(typeof def.execute).toBe("function"); } }); it("execute() delegates to callTool", async () => { const { stitchTools } = await import("../../src/tools-adapter.js"); const tools = stitchTools(); - const createProject = tools["create_project"] as any; + const createProject = tools["create_project"]; - await createProject.execute({ title: "Test Project" }); + await createProject.execute!({ title: "Test Project" }, { messages: [], toolCallId: "test" }); expect(mockCallTool).toHaveBeenCalledWith("create_project", { title: "Test Project" }); }); diff --git a/scripts/capture-tools.ts b/scripts/capture-tools.ts index 61b3166..868b11f 100644 --- a/scripts/capture-tools.ts +++ b/scripts/capture-tools.ts @@ -51,6 +51,7 @@ async function main() { url: baseUrl, name: "stitch-sdk-capture", version: "1.0.0", + protocolVersion: "2024-11-05", }, remoteTools: [], };