diff --git a/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift b/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift index da28255..8828751 100644 --- a/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift @@ -429,11 +429,6 @@ public struct OpenAILanguageModel: LanguageModel { includeSchemaInPrompt: Bool, options: GenerationOptions ) async throws -> LanguageModelSession.Response where Content: Generable { - // For now, only String is supported - guard type == String.self else { - fatalError("OpenAILanguageModel only supports generating String content") - } - // Convert tools if any are available in the session let openAITools: [OpenAITool]? = { guard !session.tools.isEmpty else { return nil } @@ -450,6 +445,7 @@ public struct OpenAILanguageModel: LanguageModel { return try await respondWithChatCompletions( messages: session.transcript.toOpenAIMessages(), tools: openAITools, + generating: type, options: options, session: session ) @@ -457,6 +453,7 @@ public struct OpenAILanguageModel: LanguageModel { return try await respondWithResponses( messages: session.transcript.toOpenAIMessages(), tools: openAITools, + generating: type, options: options, session: session ) @@ -466,6 +463,7 @@ public struct OpenAILanguageModel: LanguageModel { private func respondWithChatCompletions( messages: [OpenAIMessage], tools: [OpenAITool]?, + generating type: Content.Type, options: GenerationOptions, session: LanguageModelSession ) async throws -> LanguageModelSession.Response where Content: Generable { @@ -476,10 +474,11 @@ public struct OpenAILanguageModel: LanguageModel { // Loop until no more tool calls while true { - let params = ChatCompletions.createRequestBody( + let params = try ChatCompletions.createRequestBody( model: model, messages: messages, tools: tools, + generating: type, options: options, stream: false ) @@ -496,10 +495,18 @@ public struct OpenAILanguageModel: LanguageModel { ) guard let choice = resp.choices.first else { - return LanguageModelSession.Response( - content: "" as! Content, - rawContent: GeneratedContent(""), - transcriptEntries: ArraySlice(entries) + throw OpenAILanguageModelError.noResponseGenerated + } + + if let refusalMessage = choice.message.refusal { + let refusalEntry = Transcript.Entry.response( + Transcript.Response(assetIDs: [], segments: [.text(.init(content: refusalMessage))]) + ) + throw LanguageModelSession.GenerationError.refusal( + LanguageModelSession.GenerationError.Refusal(transcriptEntries: [refusalEntry]), + LanguageModelSession.GenerationError.Context( + debugDescription: "OpenAI model refused to generate response: \(refusalMessage)" + ) ) } @@ -514,9 +521,10 @@ public struct OpenAILanguageModel: LanguageModel { if !calls.isEmpty { entries.append(.toolCalls(Transcript.ToolCalls(calls))) } + let empty = try emptyResponseContent(for: type) return LanguageModelSession.Response( - content: "" as! Content, - rawContent: GeneratedContent(""), + content: empty.content, + rawContent: empty.rawContent, transcriptEntries: ArraySlice(entries) ) case .invocations(let invocations): @@ -540,9 +548,20 @@ public struct OpenAILanguageModel: LanguageModel { text = choice.message.content ?? "" break } + + if type == String.self { + return LanguageModelSession.Response( + content: text as! Content, + rawContent: GeneratedContent(text), + transcriptEntries: ArraySlice(entries) + ) + } + + let generatedContent = try GeneratedContent(json: text) + let content = try type.init(generatedContent) return LanguageModelSession.Response( - content: text as! Content, - rawContent: GeneratedContent(text), + content: content, + rawContent: generatedContent, transcriptEntries: ArraySlice(entries) ) } @@ -550,21 +569,24 @@ public struct OpenAILanguageModel: LanguageModel { private func respondWithResponses( messages: [OpenAIMessage], tools: [OpenAITool]?, + generating type: Content.Type, options: GenerationOptions, session: LanguageModelSession ) async throws -> LanguageModelSession.Response where Content: Generable { var entries: [Transcript.Entry] = [] var text = "" + var lastOutput: [JSONValue]? var messages = messages let url = baseURL.appendingPathComponent("responses") // Loop until no more tool calls while true { - let params = Responses.createRequestBody( + let params = try Responses.createRequestBody( model: model, messages: messages, tools: tools, + generating: type, options: options, stream: false ) @@ -581,6 +603,7 @@ public struct OpenAILanguageModel: LanguageModel { ) let toolCalls = extractToolCallsFromOutput(resp.output) + lastOutput = resp.output if !toolCalls.isEmpty { if let output = resp.output { for msg in output { @@ -593,9 +616,10 @@ public struct OpenAILanguageModel: LanguageModel { if !calls.isEmpty { entries.append(.toolCalls(Transcript.ToolCalls(calls))) } + let empty = try emptyResponseContent(for: type) return LanguageModelSession.Response( - content: "" as! Content, - rawContent: GeneratedContent(""), + content: empty.content, + rawContent: empty.rawContent, transcriptEntries: ArraySlice(entries) ) case .invocations(let invocations): @@ -621,11 +645,25 @@ public struct OpenAILanguageModel: LanguageModel { break } - return LanguageModelSession.Response( - content: text as! Content, - rawContent: GeneratedContent(text), - transcriptEntries: ArraySlice(entries) - ) + + if type == String.self { + return LanguageModelSession.Response( + content: text as! Content, + rawContent: GeneratedContent(text), + transcriptEntries: ArraySlice(entries) + ) + } + + if let jsonString = extractJSONFromOutput(lastOutput) { + let generatedContent = try GeneratedContent(json: jsonString) + let content = try type.init(generatedContent) + return LanguageModelSession.Response( + content: content, + rawContent: generatedContent, + transcriptEntries: ArraySlice(entries) + ) + } + throw OpenAILanguageModelError.noResponseGenerated } public func streamResponse( @@ -635,11 +673,6 @@ public struct OpenAILanguageModel: LanguageModel { includeSchemaInPrompt: Bool, options: GenerationOptions ) -> sending LanguageModelSession.ResponseStream where Content: Generable { - // For now, only String is supported - guard type == String.self else { - fatalError("OpenAILanguageModel only supports generating String content") - } - // Convert tools if any are available in the session let openAITools: [OpenAITool]? = { guard !session.tools.isEmpty else { return nil } @@ -653,120 +686,162 @@ public struct OpenAILanguageModel: LanguageModel { switch apiVariant { case .responses: - let params = Responses.createRequestBody( - model: model, - messages: session.transcript.toOpenAIMessages(), - tools: openAITools, - options: options, - stream: true - ) - let url = baseURL.appendingPathComponent("responses") let stream: AsyncThrowingStream.Snapshot, any Error> = .init { continuation in - let task = Task { @Sendable in - do { - let body = try JSONEncoder().encode(params) - - let events: AsyncThrowingStream = - urlSession.fetchEventStream( - .post, - url: url, - headers: [ - "Authorization": "Bearer \(tokenProvider())" - ], - body: body - ) + do { + let params = try Responses.createRequestBody( + model: model, + messages: session.transcript.toOpenAIMessages(), + tools: openAITools, + generating: type, + options: options, + stream: true + ) + let task = Task { @Sendable in + do { + let body = try JSONEncoder().encode(params) + + let events: AsyncThrowingStream = + urlSession.fetchEventStream( + .post, + url: url, + headers: [ + "Authorization": "Bearer \(tokenProvider())" + ], + body: body + ) - var accumulatedText = "" - - for try await event in events { - switch event { - case .outputTextDelta(let delta): - accumulatedText += delta - - // Yield snapshot with partially generated content - let raw = GeneratedContent(accumulatedText) - let content: Content.PartiallyGenerated = (accumulatedText as! Content) - .asPartiallyGenerated() - continuation.yield(.init(content: content, rawContent: raw)) - - case .toolCallCreated(_): - // Minimal streaming implementation ignores tool call events - break - case .toolCallDelta(_): - // Minimal streaming implementation ignores tool call deltas - break - case .completed(_): - continuation.finish() - case .ignored: - break + var accumulatedText = "" + + for try await event in events { + switch event { + case .outputTextDelta(let delta): + accumulatedText += delta + + var raw: GeneratedContent + let content: Content.PartiallyGenerated? + + if type == String.self { + raw = GeneratedContent(accumulatedText) + content = (accumulatedText as! Content).asPartiallyGenerated() + } else { + raw = + (try? GeneratedContent(json: accumulatedText)) + ?? GeneratedContent(accumulatedText) + if let parsed = try? type.init(raw) { + content = parsed.asPartiallyGenerated() + } else { + // Skip snapshots until the accumulated JSON parses. + content = nil + } + } + + if let content { + continuation.yield(.init(content: content, rawContent: raw)) + } + + case .toolCallCreated(_): + // Minimal streaming implementation ignores tool call events + break + case .toolCallDelta(_): + // Minimal streaming implementation ignores tool call deltas + break + case .completed(_): + continuation.finish() + case .ignored: + break + } } - } - continuation.finish() - } catch { - continuation.finish(throwing: error) + continuation.finish() + } catch { + continuation.finish(throwing: error) + } } + continuation.onTermination = { _ in task.cancel() } + } catch { + continuation.finish(throwing: error) } - continuation.onTermination = { _ in task.cancel() } } return LanguageModelSession.ResponseStream(stream: stream) case .chatCompletions: - let params = ChatCompletions.createRequestBody( - model: model, - messages: session.transcript.toOpenAIMessages(), - tools: openAITools, - options: options, - stream: true - ) - let url = baseURL.appendingPathComponent("chat/completions") let stream: AsyncThrowingStream.Snapshot, any Error> = .init { continuation in - let task = Task { @Sendable in - do { - let body = try JSONEncoder().encode(params) - - let events: AsyncThrowingStream = - urlSession.fetchEventStream( - .post, - url: url, - headers: [ - "Authorization": "Bearer \(tokenProvider())" - ], - body: body - ) - - var accumulatedText = "" - - for try await chunk in events { - if let choice = chunk.choices.first { - if let piece = choice.delta.content, !piece.isEmpty { - accumulatedText += piece + do { + let params = try ChatCompletions.createRequestBody( + model: model, + messages: session.transcript.toOpenAIMessages(), + tools: openAITools, + generating: type, + options: options, + stream: true + ) - let raw = GeneratedContent(accumulatedText) - let content: Content.PartiallyGenerated = (accumulatedText as! Content) - .asPartiallyGenerated() - continuation.yield(.init(content: content, rawContent: raw)) - } + let task = Task { @Sendable in + do { + let body = try JSONEncoder().encode(params) + + let events: AsyncThrowingStream = + urlSession.fetchEventStream( + .post, + url: url, + headers: [ + "Authorization": "Bearer \(tokenProvider())" + ], + body: body + ) - if choice.finishReason != nil { - continuation.finish() + var accumulatedText = "" + + for try await chunk in events { + if let choice = chunk.choices.first { + if let piece = choice.delta.content, !piece.isEmpty { + accumulatedText += piece + + var raw: GeneratedContent + let content: Content.PartiallyGenerated? + + if type == String.self { + raw = GeneratedContent(accumulatedText) + content = (accumulatedText as! Content).asPartiallyGenerated() + } else { + raw = + (try? GeneratedContent(json: accumulatedText)) + ?? GeneratedContent(accumulatedText) + if let parsed = try? type.init(raw) { + content = parsed.asPartiallyGenerated() + } else { + // Skip snapshots until the accumulated JSON parses. + content = nil + } + } + + if let content { + continuation.yield(.init(content: content, rawContent: raw)) + } + } + + if choice.finishReason != nil { + continuation.finish() + } } } - } - continuation.finish() - } catch { - continuation.finish(throwing: error) + continuation.finish() + } catch { + continuation.finish(throwing: error) + } } + continuation.onTermination = { _ in task.cancel() } + } catch { + continuation.finish(throwing: error) } - continuation.onTermination = { _ in task.cancel() } } return LanguageModelSession.ResponseStream(stream: stream) @@ -777,13 +852,14 @@ public struct OpenAILanguageModel: LanguageModel { // MARK: - API Variants private enum ChatCompletions { - static func createRequestBody( + static func createRequestBody( model: String, messages: [OpenAIMessage], tools: [OpenAITool]?, + generating type: Content.Type, options: GenerationOptions, stream: Bool - ) -> JSONValue { + ) throws -> JSONValue { var body: [String: JSONValue] = [ "model": .string(model), "messages": .array(messages.map { $0.jsonValue(for: .chatCompletions) }), @@ -794,6 +870,18 @@ private enum ChatCompletions { body["tools"] = .array(tools.map { $0.jsonValue(for: .chatCompletions) }) } + if type != String.self { + let jsonSchemaValue = try type.generationSchema.toJSONValueForOpenAIStrictMode() + body["response_format"] = .object([ + "type": .string("json_schema"), + "json_schema": .object([ + "name": .string("response_schema"), + "strict": .bool(true), + "schema": jsonSchemaValue, + ]), + ]) + } + if let temperature = options.temperature { body["temperature"] = .double(temperature) } @@ -899,11 +987,13 @@ private enum ChatCompletions { struct Message: Codable, Sendable { let role: String let content: String? + let refusal: String? let toolCalls: [OpenAIToolCall]? private enum CodingKeys: String, CodingKey { case role case content + case refusal case toolCalls = "tool_calls" } } @@ -911,13 +1001,14 @@ private enum ChatCompletions { } private enum Responses { - static func createRequestBody( + static func createRequestBody( model: String, messages: [OpenAIMessage], tools: [OpenAITool]?, + generating type: Content.Type, options: GenerationOptions, stream: Bool - ) -> JSONValue { + ) throws -> JSONValue { // Build input blocks from the user message content var body: [String: JSONValue] = [ @@ -1036,6 +1127,18 @@ private enum Responses { body["tools"] = .array(tools.map { $0.jsonValue(for: .responses) }) } + if type != String.self { + let jsonSchemaValue = try type.generationSchema.toJSONValueForOpenAIStrictMode() + body["text"] = .object([ + "format": .object([ + "type": .string("json_schema"), + "name": .string("response_schema"), + "strict": .bool(true), + "schema": jsonSchemaValue, + ]) + ]) + } + if let temperature = options.temperature { body["temperature"] = .double(temperature) } @@ -1636,6 +1739,19 @@ private func convertToolToOpenAIFormat(_ tool: any Tool) -> OpenAITool { return OpenAITool(type: "function", function: fn) } +private func emptyResponseContent( + for type: Content.Type +) throws -> (content: Content, rawContent: GeneratedContent) { + if type == String.self { + let raw = GeneratedContent("") + return ("" as! Content, raw) + } + + let raw = GeneratedContent(properties: [:]) + let content = try type.init(raw) + return (content, raw) +} + private func toGeneratedContent(_ jsonString: String?) throws -> GeneratedContent { guard let jsonString, !jsonString.isEmpty else { return GeneratedContent(properties: [:]) } return try GeneratedContent(json: jsonString) @@ -1666,6 +1782,30 @@ private func extractTextFromOutput(_ output: [JSONValue]?) -> String? { return textParts.isEmpty ? nil : textParts.joined() } +private func extractJSONFromOutput(_ output: [JSONValue]?) -> String? { + guard let output else { return nil } + + for block in output { + if case let .object(obj) = block, + case let .string(type)? = obj["type"], + type == "message", + case let .array(contentBlocks)? = obj["content"] + { + for contentBlock in contentBlocks { + if case let .object(contentObj) = contentBlock, + case let .string(contentType)? = contentObj["type"], + contentType == "output_text", + case let .string(jsonString)? = contentObj["text"] + { + return jsonString + } + } + } + } + + return nil +} + private func extractToolCallsFromOutput(_ output: [JSONValue]?) -> [OpenAIToolCall] { guard let output else { return [] } @@ -1758,3 +1898,52 @@ private func extractToolCallsFromOutput(_ output: [JSONValue]?) -> [OpenAIToolCa return toolCalls } + +// MARK: - Errors + +enum OpenAILanguageModelError: LocalizedError { + case noResponseGenerated + + var errorDescription: String? { + switch self { + case .noResponseGenerated: + return "No response was generated by the model" + } + } +} + +// MARK: - OpenAI Schema Helpers + +private extension GenerationSchema { + /// Converts this schema to a JSONValue with OpenAI strict mode requirements applied. + /// + /// OpenAI strict mode requires: + /// 1. `additionalProperties: false` at the root + /// 2. All properties (including optional ones) listed in `required` + func toJSONValueForOpenAIStrictMode() throws -> JSONValue { + let resolvedSchema = self.withResolvedRoot() ?? self + + let encoder = JSONEncoder() + encoder.userInfo[GenerationSchema.omitAdditionalPropertiesKey] = false + let schemaData = try encoder.encode(resolvedSchema) + let jsonSchema = try JSONDecoder().decode(JSONSchema.self, from: schemaData) + var jsonSchemaValue = try JSONValue(jsonSchema) + + if case .object(var schemaObj) = jsonSchemaValue { + schemaObj["additionalProperties"] = .bool(false) + + if case .object(let properties)? = schemaObj["properties"], + !properties.isEmpty + { + // OpenAI strict mode requires all properties to be listed as required, + // even if the underlying schema marks them optional. + let allPropertyNames = Array(properties.keys).sorted() + schemaObj["required"] = .array(allPropertyNames.map { .string($0) }) + } + + jsonSchemaValue = .object(schemaObj) + } + + return jsonSchemaValue + } +} diff --git a/Tests/AnyLanguageModelTests/OpenAILanguageModelTests.swift b/Tests/AnyLanguageModelTests/OpenAILanguageModelTests.swift index 12f15d3..2a7a4b8 100644 --- a/Tests/AnyLanguageModelTests/OpenAILanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/OpenAILanguageModelTests.swift @@ -150,6 +150,95 @@ struct OpenAILanguageModelTests { } #expect(foundToolOutput) } + + @Suite("Structured Output") + struct StructuredOutputTests { + @Generable + struct Person { + @Guide(description: "The person's full name") + var name: String + + @Guide(description: "The person's age in years") + var age: Int + + @Guide(description: "The person's email address") + var email: String? + } + + @Generable + struct Book { + @Guide(description: "The book's title") + var title: String + + @Guide(description: "The book's author") + var author: String + + @Guide(description: "The publication year") + var year: Int + } + + private var model: OpenAILanguageModel { + OpenAILanguageModel(apiKey: openaiAPIKey!, model: "gpt-4o-mini", apiVariant: .chatCompletions) + } + + @Test func basicStructuredOutput() async throws { + let session = LanguageModelSession(model: model) + let response = try await session.respond( + to: "Generate a person named John Doe, age 30, email john@example.com", + generating: Person.self + ) + + #expect(!response.content.name.isEmpty) + #expect(response.content.name.contains("John") || response.content.name.contains("Doe")) + #expect(response.content.age > 0) + #expect(response.content.age <= 100) + #expect(response.content.email != nil) + } + + @Test func structuredOutputWithOptionalField() async throws { + let session = LanguageModelSession(model: model) + let response = try await session.respond( + to: "Generate a person named Jane Smith, age 25, with no email", + generating: Person.self + ) + + #expect(!response.content.name.isEmpty) + #expect(response.content.name.contains("Jane") || response.content.name.contains("Smith")) + #expect(response.content.age > 0) + #expect(response.content.age <= 100) + #expect(response.content.email == nil || response.content.email?.isEmpty == true) + } + + @Test func structuredOutputWithNestedTypes() async throws { + let session = LanguageModelSession(model: model) + let response = try await session.respond( + to: "Generate a book titled 'The Swift Programming Language' by 'Apple Inc.' published in 2024", + generating: Book.self + ) + + #expect(!response.content.title.isEmpty) + #expect(!response.content.author.isEmpty) + #expect(response.content.year >= 2020) + } + + @Test func streamingStructuredOutput() async throws { + let session = LanguageModelSession(model: model) + let stream = session.streamResponse( + to: "Generate a person named Alice, age 28, email alice@example.com", + generating: Person.self + ) + + var snapshots: [LanguageModelSession.ResponseStream.Snapshot] = [] + for try await snapshot in stream { + snapshots.append(snapshot) + } + + #expect(!snapshots.isEmpty) + let finalSnapshot = snapshots.last! + #expect((finalSnapshot.content.name?.isEmpty ?? true) == false) + #expect((finalSnapshot.content.age ?? 0) > 0) + } + } } @Suite("OpenAILanguageModel Responses API", .enabled(if: openaiAPIKey?.isEmpty == false)) @@ -280,5 +369,94 @@ struct OpenAILanguageModelTests { } #expect(foundToolOutput) } + + @Suite("Structured Output") + struct StructuredOutputTests { + @Generable + struct Person { + @Guide(description: "The person's full name") + var name: String + + @Guide(description: "The person's age in years") + var age: Int + + @Guide(description: "The person's email address") + var email: String? + } + + @Generable + struct Book { + @Guide(description: "The book's title") + var title: String + + @Guide(description: "The book's author") + var author: String + + @Guide(description: "The publication year") + var year: Int + } + + private var model: OpenAILanguageModel { + OpenAILanguageModel(apiKey: openaiAPIKey!, model: "gpt-4o-mini", apiVariant: .responses) + } + + @Test func basicStructuredOutput() async throws { + let session = LanguageModelSession(model: model) + let response = try await session.respond( + to: "Generate a person named John Doe, age 30, email john@example.com", + generating: Person.self + ) + + #expect(!response.content.name.isEmpty) + #expect(response.content.name.contains("John") || response.content.name.contains("Doe")) + #expect(response.content.age > 0) + #expect(response.content.age <= 100) + #expect(response.content.email != nil) + } + + @Test func structuredOutputWithOptionalField() async throws { + let session = LanguageModelSession(model: model) + let response = try await session.respond( + to: "Generate a person named Jane Smith, age 25, with no email", + generating: Person.self + ) + + #expect(!response.content.name.isEmpty) + #expect(response.content.name.contains("Jane") || response.content.name.contains("Smith")) + #expect(response.content.age > 0) + #expect(response.content.age <= 100) + #expect(response.content.email == nil || response.content.email?.isEmpty == true) + } + + @Test func structuredOutputWithNestedTypes() async throws { + let session = LanguageModelSession(model: model) + let response = try await session.respond( + to: "Generate a book titled 'The Swift Programming Language' by 'Apple Inc.' published in 2024", + generating: Book.self + ) + + #expect(!response.content.title.isEmpty) + #expect(!response.content.author.isEmpty) + #expect(response.content.year >= 2020) + } + + @Test func streamingStructuredOutput() async throws { + let session = LanguageModelSession(model: model) + let stream = session.streamResponse( + to: "Generate a person named Alice, age 28, email alice@example.com", + generating: Person.self + ) + + var snapshots: [LanguageModelSession.ResponseStream.Snapshot] = [] + for try await snapshot in stream { + snapshots.append(snapshot) + } + + #expect(!snapshots.isEmpty) + let finalSnapshot = snapshots.last! + #expect((finalSnapshot.content.name?.isEmpty ?? true) == false) + #expect((finalSnapshot.content.age ?? 0) > 0) + } + } } }