Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 90 additions & 29 deletions Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -317,11 +317,6 @@ public struct AnthropicLanguageModel: LanguageModel {
includeSchemaInPrompt: Bool,
options: GenerationOptions
) async throws -> LanguageModelSession.Response<Content> where Content: Generable {
// For now, only String is supported
guard type == String.self else {
fatalError("AnthropicLanguageModel only supports generating String content")
}

let url = baseURL.appendingPathComponent("v1/messages")
let headers = buildHeaders()

Expand All @@ -330,11 +325,13 @@ public struct AnthropicLanguageModel: LanguageModel {
try convertToolToAnthropicFormat(tool)
}

let responseSchema = type == String.self ? nil : try convertSchemaToAnthropicFormat(Content.generationSchema)
let params = try createMessageParams(
model: model,
system: nil,
messages: session.transcript.toAnthropicMessages(),
tools: anthropicTools.isEmpty ? nil : anthropicTools,
responseSchema: responseSchema,
options: options
)

Expand Down Expand Up @@ -362,9 +359,10 @@ public struct AnthropicLanguageModel: LanguageModel {
if !calls.isEmpty {
entries.append(.toolCalls(Transcript.ToolCalls(calls)))
}
let empty = try emptyResponseContent(for: type)
return LanguageModelSession.Response(
content: "" as! Content,
rawContent: GeneratedContent(""),
content: empty.content,
rawContent: empty.rawContent,
transcriptEntries: ArraySlice(entries)
)
case .invocations(let invocations):
Expand All @@ -384,9 +382,19 @@ public struct AnthropicLanguageModel: LanguageModel {
}
}.joined()

if type == String.self {
return LanguageModelSession.Response(
content: text as! Content,
rawContent: GeneratedContent(text),
transcriptEntries: ArraySlice(entries)
)
}

let rawContent = try GeneratedContent(json: text)
let content = try Content(rawContent)
return LanguageModelSession.Response(
content: text as! Content,
rawContent: GeneratedContent(text),
content: content,
rawContent: rawContent,
transcriptEntries: ArraySlice(entries)
)
}
Expand All @@ -398,11 +406,6 @@ public struct AnthropicLanguageModel: LanguageModel {
includeSchemaInPrompt: Bool,
options: GenerationOptions
) -> sending LanguageModelSession.ResponseStream<Content> where Content: Generable {
// For now, only String is supported
guard type == String.self else {
fatalError("AnthropicLanguageModel only supports generating String content")
}

let url = baseURL.appendingPathComponent("v1/messages")

let stream: AsyncThrowingStream<LanguageModelSession.ResponseStream<Content>.Snapshot, any Error> = .init {
Expand All @@ -416,11 +419,14 @@ public struct AnthropicLanguageModel: LanguageModel {
try convertToolToAnthropicFormat(tool)
}

let responseSchema =
type == String.self ? nil : try convertSchemaToAnthropicFormat(Content.generationSchema)
var params = try createMessageParams(
model: model,
system: nil,
messages: session.transcript.toAnthropicMessages(),
tools: anthropicTools.isEmpty ? nil : anthropicTools,
responseSchema: responseSchema,
options: options
)
params["stream"] = .bool(true)
Expand All @@ -438,18 +444,26 @@ public struct AnthropicLanguageModel: LanguageModel {
)

var accumulatedText = ""
let expectsStructuredResponse = type != String.self

for try await event in events {
switch event {
case .contentBlockDelta(let delta):
if case .textDelta(let textDelta) = delta.delta {
accumulatedText += textDelta.text

// Yield snapshot with partially generated content
let raw = GeneratedContent(accumulatedText)
let content: Content.PartiallyGenerated = (accumulatedText as! Content)
.asPartiallyGenerated()
continuation.yield(.init(content: content, rawContent: raw))
if expectsStructuredResponse {
if let snapshot: LanguageModelSession.ResponseStream<Content>.Snapshot =
try? partialSnapshot(from: accumulatedText)
{
continuation.yield(snapshot)
}
} else {
let raw = GeneratedContent(accumulatedText)
let content: Content.PartiallyGenerated = (accumulatedText as! Content)
.asPartiallyGenerated()
continuation.yield(.init(content: content, rawContent: raw))
}
}
case .messageStop:
continuation.finish()
Expand Down Expand Up @@ -491,6 +505,7 @@ private func createMessageParams(
system: String?,
messages: [AnthropicMessage],
tools: [AnthropicTool]?,
responseSchema: JSONSchema?,
options: GenerationOptions
) throws -> [String: JSONValue] {
var params: [String: JSONValue] = [
Expand All @@ -505,6 +520,24 @@ private func createMessageParams(
if let tools, !tools.isEmpty {
params["tools"] = try JSONValue(tools)
}
if let responseSchema {
// Structured outputs: https://platform.claude.com/docs/en/build-with-claude/structured-outputs
let schemaValue = try JSONValue(responseSchema)
if case .object(let schemaObject) = schemaValue, schemaObject.isEmpty {
// Anthropic rejects empty schemas; omit output_config in this case.
} else {
params["output_config"] = .object(
[
"format": .object(
[
"type": .string("json_schema"),
"schema": schemaValue,
]
)
]
)
}
}
if let temperature = options.temperature {
params["temperature"] = .double(temperature)
}
Expand Down Expand Up @@ -577,6 +610,41 @@ private enum ToolResolutionOutcome {
case invocations([ToolInvocationResult])
}

private func emptyResponseContent<Content: Generable>(
for type: Content.Type
) throws -> (content: Content, rawContent: GeneratedContent) {
if type == String.self {
let raw = GeneratedContent("")
return ("" as! Content, raw)
}

let emptyObject = GeneratedContent(properties: [:])
if let content = try? Content(emptyObject) {
return (content, emptyObject)
}

let nullContent = GeneratedContent(kind: .null)
if let content = try? Content(nullContent) {
return (content, nullContent)
}

throw GeneratedContentConversionError.typeMismatch
}

private func partialSnapshot<Content: Generable>(
from accumulatedText: String
) throws -> LanguageModelSession.ResponseStream<Content>.Snapshot {
let raw = try GeneratedContent(json: accumulatedText)
let content = try Content.PartiallyGenerated(raw)
return .init(content: content, rawContent: raw)
}

private func convertSchemaToAnthropicFormat(_ schema: GenerationSchema) throws -> JSONSchema {
let resolvedSchema = schema.withResolvedRoot() ?? schema
let data = try JSONEncoder().encode(resolvedSchema)
return try JSONDecoder().decode(JSONSchema.self, from: data)
}

private func resolveToolUses(
_ toolUses: [AnthropicToolUse],
session: LanguageModelSession
Expand Down Expand Up @@ -631,10 +699,8 @@ private func resolveToolUses(
for (index, call) in transcriptCalls.enumerated() {
switch decisions[index] {
case .stop:
// This branch should be unreachable,
// because `.stop` returns during decision collection.
// Keep it as a defensive guard,
// in case that logic changes.
// This branch should be unreachable because `.stop` returns during decision collection.
// Keep it as a defensive guard in case that logic changes.
return .stop(calls: transcriptCalls)
case .provideOutput(let segments):
let output = Transcript.ToolOutput(
Expand Down Expand Up @@ -686,12 +752,7 @@ private func resolveToolUses(

// Convert our GenerationSchema into Anthropic's expected JSON Schema payload
private func convertToolToAnthropicFormat(_ tool: any Tool) throws -> AnthropicTool {
// Resolve the schema root to ensure it has a type field (Anthropic requirement)
let resolvedSchema = tool.parameters.withResolvedRoot() ?? tool.parameters

// Encode our internal schema then decode to JSONSchema type
let data = try JSONEncoder().encode(resolvedSchema)
let schema = try JSONDecoder().decode(JSONSchema.self, from: data)
let schema = try convertSchemaToAnthropicFormat(tool.parameters)
return AnthropicTool(name: tool.name, description: tool.description, inputSchema: schema)
}

Expand Down
36 changes: 36 additions & 0 deletions Tests/AnyLanguageModelTests/AnthropicLanguageModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ import Testing

private let anthropicAPIKey: String? = ProcessInfo.processInfo.environment["ANTHROPIC_API_KEY"]

@Generable
private struct AnthropicStructuredForecast {
var summary: String
var temperatureCelsius: Int
}

@Suite("AnthropicLanguageModel", .enabled(if: anthropicAPIKey?.isEmpty == false))
struct AnthropicLanguageModelTests {
let model = AnthropicLanguageModel(
Expand Down Expand Up @@ -61,6 +67,24 @@ struct AnthropicLanguageModelTests {
#expect(!snapshots.last!.rawContent.jsonString.isEmpty)
}

@Test func streamingStructured() async throws {
let session = LanguageModelSession(model: model)

let stream = session.streamResponse(
to: "Provide a short weather forecast summary and a celsius temperature.",
generating: AnthropicStructuredForecast.self
)

var snapshots: [LanguageModelSession.ResponseStream<AnthropicStructuredForecast>.Snapshot] = []
for try await snapshot in stream {
snapshots.append(snapshot)
}

#expect(!snapshots.isEmpty)
#expect(!snapshots.last!.rawContent.jsonString.isEmpty)
#expect(!(snapshots.last!.content.summary ?? "").isEmpty)
}

@Test func withGenerationOptions() async throws {
let session = LanguageModelSession(model: model)

Expand All @@ -76,6 +100,18 @@ struct AnthropicLanguageModelTests {
#expect(!response.content.isEmpty)
}

@Test func structuredResponse() async throws {
let session = LanguageModelSession(model: model)

let response = try await session.respond(
to: "Summarize the weather with a short summary and a celsius temperature.",
generating: AnthropicStructuredForecast.self
)

#expect(!response.content.summary.isEmpty)
#expect(response.rawContent.jsonString.contains("summary"))
}

@Test func conversationContext() async throws {
let session = LanguageModelSession(model: model)

Expand Down