diff --git a/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift b/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift index f88b68a..bd21a1f 100644 --- a/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift @@ -317,11 +317,6 @@ public struct AnthropicLanguageModel: LanguageModel { includeSchemaInPrompt: Bool, options: GenerationOptions ) async throws -> LanguageModelSession.Response where Content: Generable { - // For now, only String is supported - guard type == String.self else { - fatalError("AnthropicLanguageModel only supports generating String content") - } - let url = baseURL.appendingPathComponent("v1/messages") let headers = buildHeaders() @@ -330,11 +325,13 @@ public struct AnthropicLanguageModel: LanguageModel { try convertToolToAnthropicFormat(tool) } + let responseSchema = type == String.self ? nil : try convertSchemaToAnthropicFormat(Content.generationSchema) let params = try createMessageParams( model: model, system: nil, messages: session.transcript.toAnthropicMessages(), tools: anthropicTools.isEmpty ? nil : anthropicTools, + responseSchema: responseSchema, options: options ) @@ -362,9 +359,10 @@ public struct AnthropicLanguageModel: LanguageModel { if !calls.isEmpty { entries.append(.toolCalls(Transcript.ToolCalls(calls))) } + let empty = try emptyResponseContent(for: type) return LanguageModelSession.Response( - content: "" as! Content, - rawContent: GeneratedContent(""), + content: empty.content, + rawContent: empty.rawContent, transcriptEntries: ArraySlice(entries) ) case .invocations(let invocations): @@ -384,9 +382,19 @@ public struct AnthropicLanguageModel: LanguageModel { } }.joined() + if type == String.self { + return LanguageModelSession.Response( + content: text as! Content, + rawContent: GeneratedContent(text), + transcriptEntries: ArraySlice(entries) + ) + } + + let rawContent = try GeneratedContent(json: text) + let content = try Content(rawContent) return LanguageModelSession.Response( - content: text as! Content, - rawContent: GeneratedContent(text), + content: content, + rawContent: rawContent, transcriptEntries: ArraySlice(entries) ) } @@ -398,11 +406,6 @@ public struct AnthropicLanguageModel: LanguageModel { includeSchemaInPrompt: Bool, options: GenerationOptions ) -> sending LanguageModelSession.ResponseStream where Content: Generable { - // For now, only String is supported - guard type == String.self else { - fatalError("AnthropicLanguageModel only supports generating String content") - } - let url = baseURL.appendingPathComponent("v1/messages") let stream: AsyncThrowingStream.Snapshot, any Error> = .init { @@ -416,11 +419,14 @@ public struct AnthropicLanguageModel: LanguageModel { try convertToolToAnthropicFormat(tool) } + let responseSchema = + type == String.self ? nil : try convertSchemaToAnthropicFormat(Content.generationSchema) var params = try createMessageParams( model: model, system: nil, messages: session.transcript.toAnthropicMessages(), tools: anthropicTools.isEmpty ? nil : anthropicTools, + responseSchema: responseSchema, options: options ) params["stream"] = .bool(true) @@ -438,6 +444,7 @@ public struct AnthropicLanguageModel: LanguageModel { ) var accumulatedText = "" + let expectsStructuredResponse = type != String.self for try await event in events { switch event { @@ -445,11 +452,18 @@ public struct AnthropicLanguageModel: LanguageModel { if case .textDelta(let textDelta) = delta.delta { accumulatedText += textDelta.text - // Yield snapshot with partially generated content - let raw = GeneratedContent(accumulatedText) - let content: Content.PartiallyGenerated = (accumulatedText as! Content) - .asPartiallyGenerated() - continuation.yield(.init(content: content, rawContent: raw)) + if expectsStructuredResponse { + if let snapshot: LanguageModelSession.ResponseStream.Snapshot = + try? partialSnapshot(from: accumulatedText) + { + continuation.yield(snapshot) + } + } else { + let raw = GeneratedContent(accumulatedText) + let content: Content.PartiallyGenerated = (accumulatedText as! Content) + .asPartiallyGenerated() + continuation.yield(.init(content: content, rawContent: raw)) + } } case .messageStop: continuation.finish() @@ -491,6 +505,7 @@ private func createMessageParams( system: String?, messages: [AnthropicMessage], tools: [AnthropicTool]?, + responseSchema: JSONSchema?, options: GenerationOptions ) throws -> [String: JSONValue] { var params: [String: JSONValue] = [ @@ -505,6 +520,24 @@ private func createMessageParams( if let tools, !tools.isEmpty { params["tools"] = try JSONValue(tools) } + if let responseSchema { + // Structured outputs: https://platform.claude.com/docs/en/build-with-claude/structured-outputs + let schemaValue = try JSONValue(responseSchema) + if case .object(let schemaObject) = schemaValue, schemaObject.isEmpty { + // Anthropic rejects empty schemas; omit output_config in this case. + } else { + params["output_config"] = .object( + [ + "format": .object( + [ + "type": .string("json_schema"), + "schema": schemaValue, + ] + ) + ] + ) + } + } if let temperature = options.temperature { params["temperature"] = .double(temperature) } @@ -577,6 +610,41 @@ private enum ToolResolutionOutcome { case invocations([ToolInvocationResult]) } +private func emptyResponseContent( + for type: Content.Type +) throws -> (content: Content, rawContent: GeneratedContent) { + if type == String.self { + let raw = GeneratedContent("") + return ("" as! Content, raw) + } + + let emptyObject = GeneratedContent(properties: [:]) + if let content = try? Content(emptyObject) { + return (content, emptyObject) + } + + let nullContent = GeneratedContent(kind: .null) + if let content = try? Content(nullContent) { + return (content, nullContent) + } + + throw GeneratedContentConversionError.typeMismatch +} + +private func partialSnapshot( + from accumulatedText: String +) throws -> LanguageModelSession.ResponseStream.Snapshot { + let raw = try GeneratedContent(json: accumulatedText) + let content = try Content.PartiallyGenerated(raw) + return .init(content: content, rawContent: raw) +} + +private func convertSchemaToAnthropicFormat(_ schema: GenerationSchema) throws -> JSONSchema { + let resolvedSchema = schema.withResolvedRoot() ?? schema + let data = try JSONEncoder().encode(resolvedSchema) + return try JSONDecoder().decode(JSONSchema.self, from: data) +} + private func resolveToolUses( _ toolUses: [AnthropicToolUse], session: LanguageModelSession @@ -631,10 +699,8 @@ private func resolveToolUses( for (index, call) in transcriptCalls.enumerated() { switch decisions[index] { case .stop: - // This branch should be unreachable, - // because `.stop` returns during decision collection. - // Keep it as a defensive guard, - // in case that logic changes. + // This branch should be unreachable because `.stop` returns during decision collection. + // Keep it as a defensive guard in case that logic changes. return .stop(calls: transcriptCalls) case .provideOutput(let segments): let output = Transcript.ToolOutput( @@ -686,12 +752,7 @@ private func resolveToolUses( // Convert our GenerationSchema into Anthropic's expected JSON Schema payload private func convertToolToAnthropicFormat(_ tool: any Tool) throws -> AnthropicTool { - // Resolve the schema root to ensure it has a type field (Anthropic requirement) - let resolvedSchema = tool.parameters.withResolvedRoot() ?? tool.parameters - - // Encode our internal schema then decode to JSONSchema type - let data = try JSONEncoder().encode(resolvedSchema) - let schema = try JSONDecoder().decode(JSONSchema.self, from: data) + let schema = try convertSchemaToAnthropicFormat(tool.parameters) return AnthropicTool(name: tool.name, description: tool.description, inputSchema: schema) } diff --git a/Tests/AnyLanguageModelTests/AnthropicLanguageModelTests.swift b/Tests/AnyLanguageModelTests/AnthropicLanguageModelTests.swift index ece1ec8..705418e 100644 --- a/Tests/AnyLanguageModelTests/AnthropicLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/AnthropicLanguageModelTests.swift @@ -5,6 +5,12 @@ import Testing private let anthropicAPIKey: String? = ProcessInfo.processInfo.environment["ANTHROPIC_API_KEY"] +@Generable +private struct AnthropicStructuredForecast { + var summary: String + var temperatureCelsius: Int +} + @Suite("AnthropicLanguageModel", .enabled(if: anthropicAPIKey?.isEmpty == false)) struct AnthropicLanguageModelTests { let model = AnthropicLanguageModel( @@ -61,6 +67,24 @@ struct AnthropicLanguageModelTests { #expect(!snapshots.last!.rawContent.jsonString.isEmpty) } + @Test func streamingStructured() async throws { + let session = LanguageModelSession(model: model) + + let stream = session.streamResponse( + to: "Provide a short weather forecast summary and a celsius temperature.", + generating: AnthropicStructuredForecast.self + ) + + var snapshots: [LanguageModelSession.ResponseStream.Snapshot] = [] + for try await snapshot in stream { + snapshots.append(snapshot) + } + + #expect(!snapshots.isEmpty) + #expect(!snapshots.last!.rawContent.jsonString.isEmpty) + #expect(!(snapshots.last!.content.summary ?? "").isEmpty) + } + @Test func withGenerationOptions() async throws { let session = LanguageModelSession(model: model) @@ -76,6 +100,18 @@ struct AnthropicLanguageModelTests { #expect(!response.content.isEmpty) } + @Test func structuredResponse() async throws { + let session = LanguageModelSession(model: model) + + let response = try await session.respond( + to: "Summarize the weather with a short summary and a celsius temperature.", + generating: AnthropicStructuredForecast.self + ) + + #expect(!response.content.summary.isEmpty) + #expect(response.rawContent.jsonString.contains("summary")) + } + @Test func conversationContext() async throws { let session = LanguageModelSession(model: model)