151 changes: 75 additions & 76 deletions extension/src/api/anthropic-direct.ts
@@ -16,7 +16,10 @@ export class AnthropicDirectHandler implements ApiHandler {
throw new Error("Anthropic API key is required")
}
this.client = new Anthropic({
apiKey: options.apiKey
apiKey: options.apiKey,
defaultHeaders: {
'anthropic-beta': 'prompt-caching-2024-07-31'
}
})
}

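Note: the constructor change above opts every request from this client into Anthropic's prompt-caching beta via a default header. A minimal standalone sketch of the same setup, assuming the key comes from an environment variable (a placeholder, not part of this PR):

```ts
import Anthropic from "@anthropic-ai/sdk"

// Sketch: a client that sends the prompt-caching beta header on every
// request, so cache_control markers in system/message blocks take effect.
const client = new Anthropic({
  apiKey: process.env.ANTHROPIC_API_KEY, // placeholder key source
  defaultHeaders: {
    "anthropic-beta": "prompt-caching-2024-07-31",
  },
})
```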
@@ -52,41 +55,80 @@ export class AnthropicDirectHandler implements ApiHandler {
userMemory?: string,
environmentDetails?: string
): AsyncIterableIterator<koduSSEResponse> {
// Create a new AbortController
this.abortController = new AbortController()

try {
// Build system prompt
const system: string[] = []
system.push(systemPrompt.trim())
// Build system content blocks with cache control
const systemBlocks: Anthropic.Beta.PromptCaching.Messages.PromptCachingBetaTextBlockParam[] = []

// Add system prompt
systemBlocks.push({
type: "text",
text: systemPrompt.trim()
})

// Add custom instructions
if (customInstructions?.trim()) {
system.push(customInstructions.trim())
systemBlocks.push({
type: "text",
text: customInstructions.trim()
})
}

// Mark the last system block with cache_control
if (systemBlocks.length > 0) {
systemBlocks[systemBlocks.length - 1].cache_control = { type: "ephemeral" }
}

// Add environment details with ephemeral cache control
if (environmentDetails?.trim()) {
system.push(environmentDetails.trim())
systemBlocks.push({
type: "text",
text: environmentDetails.trim(),
cache_control: { type: "ephemeral" }
})
}
const systemPromptCombined = system.join("\n\n")

// Convert messages to Anthropic format
const anthropicMessages: Anthropic.Messages.MessageParam[] = messages.map(msg => {
// Convert messages to Anthropic format with cache control
const userMsgIndices = messages.reduce(
(acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc),
[] as number[]
)
const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
const secondLastMsgUserIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1

const anthropicMessages = messages.map((msg, index) => {
const { ts, ...message } = msg
if (typeof message.content === 'string') {
return {
...message,
content: [{ type: 'text' as const, text: message.content }]
}
}
const isLastOrSecondLastUser = index === lastUserMsgIndex || index === secondLastMsgUserIndex

return {
...message,
content: message.content.map(block => {
if (typeof block === 'string') {
return { type: 'text' as const, text: block }
}
if ('type' in block && block.type === 'text') {
return block as Anthropic.TextBlockParam
}
return { type: 'text' as const, text: JSON.stringify(block) }
})
content: typeof message.content === 'string'
? [{
type: 'text' as const,
text: message.content,
...(isLastOrSecondLastUser && { cache_control: { type: "ephemeral" } })
}]
: message.content.map((block, blockIndex) => {
if (typeof block === 'string') {
return {
type: 'text' as const,
text: block,
...(isLastOrSecondLastUser && blockIndex === message.content.length - 1 && { cache_control: { type: "ephemeral" } })
}
}
if ('type' in block && block.type === 'text') {
return {
...block,
...(isLastOrSecondLastUser && blockIndex === message.content.length - 1 && { cache_control: { type: "ephemeral" } })
} as Anthropic.TextBlockParam
}
return {
type: 'text' as const,
text: JSON.stringify(block),
...(isLastOrSecondLastUser && blockIndex === message.content.length - 1 && { cache_control: { type: "ephemeral" } })
}
})
}
})

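Note on the message-mapping hunk above: a `cache_control: { type: "ephemeral" }` marker sets a cache breakpoint, and the API reuses the longest cached prefix ending at such a marker. Marking the last two user messages lets the previous turn's prefix be read from cache while the current turn writes a new cache entry. The index selection in isolation, as a sketch (the helper name and `Msg` type are mine, not from the PR):

```ts
// Sketch: find the last and second-to-last user-message indices, the two
// positions that receive ephemeral cache_control markers in the handler.
type Msg = { role: "user" | "assistant"; content: unknown }

function lastTwoUserIndices(messages: Msg[]): { last: number; secondLast: number } {
  const userIdx = messages.reduce<number[]>(
    (acc, msg, i) => (msg.role === "user" ? [...acc, i] : acc),
    []
  )
  return {
    last: userIdx[userIdx.length - 1] ?? -1,
    secondLast: userIdx[userIdx.length - 2] ?? -1,
  }
}
```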
@@ -101,19 +143,22 @@ export class AnthropicDirectHandler implements ApiHandler {
// Start stream
yield { code: 0, body: undefined }

// Create stream
// Create stream with prompt caching enabled
const stream = await this.client.messages.create(
{
model: this.getModel().id,
max_tokens: this.getModel().info.maxTokens,
system: systemPromptCombined,
system: systemBlocks,
messages: anthropicMessages,
temperature,
top_p,
stream: true
},
{
signal: this.abortController.signal
signal: this.abortController.signal,
headers: {
'anthropic-beta': 'prompt-caching-2024-07-31'
}
}
)

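Note: passing the beta header here as well as through `defaultHeaders` is redundant but harmless; either one alone enables caching. The same options object also threads the `AbortController` signal through. A sketch of how cancellation behaves with that wiring (model id and prompt are illustrative):

```ts
// Sketch: aborting an in-flight streaming request. The SDK rejects the
// pending call with an abort error, which the catch block at the end of
// this method deliberately swallows.
const controller = new AbortController()
const pending = client.messages.create(
  {
    model: "claude-3-5-sonnet-20241022", // illustrative model id
    max_tokens: 1024,
    messages: [{ role: "user", content: "Hello" }],
    stream: true,
  },
  { signal: controller.signal }
)
pending.catch(() => {}) // swallow the abort rejection in this sketch
controller.abort() // e.g. when the user cancels the task
```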
@@ -125,7 +170,6 @@

for await (const chunk of stream) {
if (chunk.type === 'message_start') {
// Get initial token counts from message_start
if (chunk.message?.usage) {
usage.input_tokens = chunk.message.usage.input_tokens
usage.output_tokens = chunk.message.usage.output_tokens
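One hedged observation: with the caching beta enabled, Anthropic's documented `message_start` usage payload also carries `cache_creation_input_tokens` and `cache_read_input_tokens`, which could replace the `metadata?.cached` heuristic used at `message_stop` below. A defensive sketch (the two cache fields come from the public prompt-caching docs, not from this diff):

```ts
// Sketch: read cache token counts straight from the message_start event.
if (chunk.type === "message_start" && chunk.message?.usage) {
  const u = chunk.message.usage as {
    input_tokens: number
    output_tokens: number
    cache_creation_input_tokens?: number // per the prompt-caching docs
    cache_read_input_tokens?: number // per the prompt-caching docs
  }
  const cacheWriteTokens = u.cache_creation_input_tokens ?? 0
  const cacheReadTokens = u.cache_read_input_tokens ?? 0
  // These could feed cacheCreationInputTokens / cacheReadInputTokens in the
  // internal payload directly instead of being inferred after the fact.
}
```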
@@ -145,23 +189,19 @@
.join('')
yield { code: 2, body: { text } }
}
// Update output tokens from message_delta
if ('usage' in chunk && chunk.usage?.output_tokens) {
usage.output_tokens = chunk.usage.output_tokens
}
} else if (chunk.type === 'message_stop') {
if ('content' in chunk && Array.isArray(chunk.content)) {
// Calculate cost based on model pricing
const model = this.getModel()
const inputCost = (model.info.inputPrice / 1_000_000) * usage.input_tokens
const outputCost = (model.info.outputPrice / 1_000_000) * usage.output_tokens
const totalCost = inputCost + outputCost

// Check if response was cached using metadata if available
const metadata = (chunk as any).metadata
const isCached = metadata?.cached === true

// Set cache metrics based on caching status
const cacheCreationInputTokens = isCached ? 0 : usage.input_tokens
const cacheReadInputTokens = isCached ? usage.input_tokens : 0

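Note: the cost math in the `message_stop` branch assumes `inputPrice` and `outputPrice` are USD per million tokens. The same calculation factored into a helper, as a sketch (the function is mine; this PR prices cached and uncached input identically):

```ts
// Sketch: per-request cost from per-1M-token prices, mirroring the inline
// math in the message_stop branch above.
function requestCost(opts: {
  inputTokens: number
  outputTokens: number
  inputPrice: number // USD per 1M input tokens
  outputPrice: number // USD per 1M output tokens
}): number {
  return (
    (opts.inputPrice / 1_000_000) * opts.inputTokens +
    (opts.outputPrice / 1_000_000) * opts.outputTokens
  )
}
```

Since Anthropic's published pricing bills cache reads at a discount and cache writes at a premium relative to base input, a follow-up could price the cache token buckets separately.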
@@ -188,8 +228,8 @@
userCredits: 0,
inputTokens: usage.input_tokens,
outputTokens: usage.output_tokens,
cacheCreationInputTokens: cacheCreationInputTokens,
cacheReadInputTokens: cacheReadInputTokens
cacheCreationInputTokens,
cacheReadInputTokens
}
}
}
@@ -198,56 +238,15 @@
}
}

// Handle case where stream ends without a message_stop
if (content.length > 0) {
// Calculate cost based on model pricing
const model = this.getModel()
const inputCost = (model.info.inputPrice / 1_000_000) * usage.input_tokens
const outputCost = (model.info.outputPrice / 1_000_000) * usage.output_tokens
const totalCost = inputCost + outputCost

yield {
code: 1,
body: {
anthropic: {
id: `stream_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`,
type: 'message',
role: 'assistant',
content,
model: this.getModel().id,
stop_reason: 'end_turn',
stop_sequence: null,
usage: {
input_tokens: usage.input_tokens,
output_tokens: usage.output_tokens,
cache_creation_input_tokens: usage.input_tokens,
cache_read_input_tokens: 0
}
},
internal: {
cost: totalCost,
userCredits: 0,
inputTokens: usage.input_tokens,
outputTokens: usage.output_tokens,
cacheCreationInputTokens: usage.input_tokens,
cacheReadInputTokens: 0
}
}
}
return
}

throw new KoduError({
code: KODU_ERROR_CODES.NETWORK_REFUSED_TO_CONNECT
})

} catch (error) {
// Don't throw errors on abort
if (error instanceof Error && error.message === "aborted") {
return
}

// Handle other errors
if (error instanceof Error) {
if (error.message.includes("prompt is too long")) {
yield {