From a6757577ad3c0e9e582efb0b70f55f1013e7ff4a Mon Sep 17 00:00:00 2001 From: EMSHVAC Date: Fri, 22 Nov 2024 15:16:58 -0600 Subject: [PATCH] feat: implement prompt caching for direct Anthropic API integration --- extension/src/api/anthropic-direct.ts | 151 +++++++++++++------------- 1 file changed, 75 insertions(+), 76 deletions(-) diff --git a/extension/src/api/anthropic-direct.ts b/extension/src/api/anthropic-direct.ts index 7289077e..42516187 100644 --- a/extension/src/api/anthropic-direct.ts +++ b/extension/src/api/anthropic-direct.ts @@ -16,7 +16,10 @@ export class AnthropicDirectHandler implements ApiHandler { throw new Error("Anthropic API key is required") } this.client = new Anthropic({ - apiKey: options.apiKey + apiKey: options.apiKey, + defaultHeaders: { + 'anthropic-beta': 'prompt-caching-2024-07-31' + } }) } @@ -52,41 +55,80 @@ export class AnthropicDirectHandler implements ApiHandler { userMemory?: string, environmentDetails?: string ): AsyncIterableIterator { - // Create a new AbortController this.abortController = new AbortController() try { - // Build system prompt - const system: string[] = [] - system.push(systemPrompt.trim()) + // Build system content blocks with cache control + const systemBlocks: Anthropic.Beta.PromptCaching.Messages.PromptCachingBetaTextBlockParam[] = [] + + // Add system prompt + systemBlocks.push({ + type: "text", + text: systemPrompt.trim() + }) + + // Add custom instructions if (customInstructions?.trim()) { - system.push(customInstructions.trim()) + systemBlocks.push({ + type: "text", + text: customInstructions.trim() + }) } + + // Mark the last system block with cache_control + if (systemBlocks.length > 0) { + systemBlocks[systemBlocks.length - 1].cache_control = { type: "ephemeral" } + } + + // Add environment details with ephemeral cache control if (environmentDetails?.trim()) { - system.push(environmentDetails.trim()) + systemBlocks.push({ + type: "text", + text: environmentDetails.trim(), + cache_control: { type: "ephemeral" } + }) } - const systemPromptCombined = system.join("\n\n") - // Convert messages to Anthropic format - const anthropicMessages: Anthropic.Messages.MessageParam[] = messages.map(msg => { + // Convert messages to Anthropic format with cache control + const userMsgIndices = messages.reduce( + (acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc), + [] as number[] + ) + const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1 + const secondLastMsgUserIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1 + + const anthropicMessages = messages.map((msg, index) => { const { ts, ...message } = msg - if (typeof message.content === 'string') { - return { - ...message, - content: [{ type: 'text' as const, text: message.content }] - } - } + const isLastOrSecondLastUser = index === lastUserMsgIndex || index === secondLastMsgUserIndex + return { ...message, - content: message.content.map(block => { - if (typeof block === 'string') { - return { type: 'text' as const, text: block } - } - if ('type' in block && block.type === 'text') { - return block as Anthropic.TextBlockParam - } - return { type: 'text' as const, text: JSON.stringify(block) } - }) + content: typeof message.content === 'string' + ? [{ + type: 'text' as const, + text: message.content, + ...(isLastOrSecondLastUser && { cache_control: { type: "ephemeral" } }) + }] + : message.content.map((block, blockIndex) => { + if (typeof block === 'string') { + return { + type: 'text' as const, + text: block, + ...(isLastOrSecondLastUser && blockIndex === message.content.length - 1 && { cache_control: { type: "ephemeral" } }) + } + } + if ('type' in block && block.type === 'text') { + return { + ...block, + ...(isLastOrSecondLastUser && blockIndex === message.content.length - 1 && { cache_control: { type: "ephemeral" } }) + } as Anthropic.TextBlockParam + } + return { + type: 'text' as const, + text: JSON.stringify(block), + ...(isLastOrSecondLastUser && blockIndex === message.content.length - 1 && { cache_control: { type: "ephemeral" } }) + } + }) } }) @@ -101,19 +143,22 @@ export class AnthropicDirectHandler implements ApiHandler { // Start stream yield { code: 0, body: undefined } - // Create stream + // Create stream with prompt caching enabled const stream = await this.client.messages.create( { model: this.getModel().id, max_tokens: this.getModel().info.maxTokens, - system: systemPromptCombined, + system: systemBlocks, messages: anthropicMessages, temperature, top_p, stream: true }, { - signal: this.abortController.signal + signal: this.abortController.signal, + headers: { + 'anthropic-beta': 'prompt-caching-2024-07-31' + } } ) @@ -125,7 +170,6 @@ export class AnthropicDirectHandler implements ApiHandler { for await (const chunk of stream) { if (chunk.type === 'message_start') { - // Get initial token counts from message_start if (chunk.message?.usage) { usage.input_tokens = chunk.message.usage.input_tokens usage.output_tokens = chunk.message.usage.output_tokens @@ -145,23 +189,19 @@ export class AnthropicDirectHandler implements ApiHandler { .join('') yield { code: 2, body: { text } } } - // Update output tokens from message_delta if ('usage' in chunk && chunk.usage?.output_tokens) { usage.output_tokens = chunk.usage.output_tokens } } else if (chunk.type === 'message_stop') { if ('content' in chunk && Array.isArray(chunk.content)) { - // Calculate cost based on model pricing const model = this.getModel() const inputCost = (model.info.inputPrice / 1_000_000) * usage.input_tokens const outputCost = (model.info.outputPrice / 1_000_000) * usage.output_tokens const totalCost = inputCost + outputCost - // Check if response was cached using metadata if available const metadata = (chunk as any).metadata const isCached = metadata?.cached === true - // Set cache metrics based on caching status const cacheCreationInputTokens = isCached ? 0 : usage.input_tokens const cacheReadInputTokens = isCached ? usage.input_tokens : 0 @@ -188,8 +228,8 @@ export class AnthropicDirectHandler implements ApiHandler { userCredits: 0, inputTokens: usage.input_tokens, outputTokens: usage.output_tokens, - cacheCreationInputTokens: cacheCreationInputTokens, - cacheReadInputTokens: cacheReadInputTokens + cacheCreationInputTokens, + cacheReadInputTokens } } } @@ -198,56 +238,15 @@ export class AnthropicDirectHandler implements ApiHandler { } } - // Handle case where stream ends without a message_stop - if (content.length > 0) { - // Calculate cost based on model pricing - const model = this.getModel() - const inputCost = (model.info.inputPrice / 1_000_000) * usage.input_tokens - const outputCost = (model.info.outputPrice / 1_000_000) * usage.output_tokens - const totalCost = inputCost + outputCost - - yield { - code: 1, - body: { - anthropic: { - id: `stream_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`, - type: 'message', - role: 'assistant', - content, - model: this.getModel().id, - stop_reason: 'end_turn', - stop_sequence: null, - usage: { - input_tokens: usage.input_tokens, - output_tokens: usage.output_tokens, - cache_creation_input_tokens: usage.input_tokens, - cache_read_input_tokens: 0 - } - }, - internal: { - cost: totalCost, - userCredits: 0, - inputTokens: usage.input_tokens, - outputTokens: usage.output_tokens, - cacheCreationInputTokens: usage.input_tokens, - cacheReadInputTokens: 0 - } - } - } - return - } - throw new KoduError({ code: KODU_ERROR_CODES.NETWORK_REFUSED_TO_CONNECT }) } catch (error) { - // Don't throw errors on abort if (error instanceof Error && error.message === "aborted") { return } - // Handle other errors if (error instanceof Error) { if (error.message.includes("prompt is too long")) { yield {