diff --git a/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift b/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift index da91c80..2da15a4 100644 --- a/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift @@ -113,6 +113,10 @@ public struct GeminiLanguageModel: LanguageModel { /// When set to `.enabled`, the model will output valid JSON. /// When set to `.schema(_:)`, the model will output JSON /// conforming to the provided schema. + /// + /// - Note: When generating a non-`String` ``Generable`` type, the model + /// always uses the generated schema for structured output and ignores + /// this setting. public var jsonMode: JSONMode? /// Creates custom generation options for Gemini models. @@ -262,10 +266,6 @@ public struct GeminiLanguageModel: LanguageModel { includeSchemaInPrompt: Bool, options: GenerationOptions ) async throws -> LanguageModelSession.Response where Content: Generable { - guard type == String.self else { - fatalError("GeminiLanguageModel only supports generating String content") - } - // Extract effective configuration from custom options or fall back to model defaults let customOptions = options[custom: GeminiLanguageModel.self] let effectiveThinking = customOptions?.thinking ?? _thinking @@ -287,6 +287,7 @@ public struct GeminiLanguageModel: LanguageModel { let params = try createGenerateContentParams( contents: transcript.toGeminiContent(), tools: geminiTools, + generating: type, options: options, thinking: effectiveThinking, jsonMode: effectiveJsonMode @@ -319,9 +320,10 @@ public struct GeminiLanguageModel: LanguageModel { if !calls.isEmpty { transcript.append(.toolCalls(Transcript.ToolCalls(calls))) } + let empty = try emptyResponseContent(for: type) return LanguageModelSession.Response( - content: "" as! Content, - rawContent: GeneratedContent(""), + content: empty.content, + rawContent: empty.rawContent, transcriptEntries: ArraySlice(transcript) ) case .invocations(let invocations): @@ -346,9 +348,19 @@ public struct GeminiLanguageModel: LanguageModel { } }.joined() ?? "" + if type == String.self { + return LanguageModelSession.Response( + content: text as! Content, + rawContent: GeneratedContent(text), + transcriptEntries: ArraySlice(transcript) + ) + } + + let generatedContent = try GeneratedContent(json: text) + let content = try type.init(generatedContent) return LanguageModelSession.Response( - content: text as! Content, - rawContent: GeneratedContent(text), + content: content, + rawContent: generatedContent, transcriptEntries: ArraySlice(transcript) ) } @@ -362,10 +374,6 @@ public struct GeminiLanguageModel: LanguageModel { includeSchemaInPrompt: Bool, options: GenerationOptions ) -> sending LanguageModelSession.ResponseStream where Content: Generable { - guard type == String.self else { - fatalError("GeminiLanguageModel only supports generating String content") - } - // Extract effective configuration from custom options or fall back to model defaults let customOptions = options[custom: GeminiLanguageModel.self] let effectiveThinking = customOptions?.thinking ?? _thinking @@ -390,6 +398,7 @@ public struct GeminiLanguageModel: LanguageModel { let params = try createGenerateContentParams( contents: session.transcript.toGeminiContent(), tools: geminiTools, + generating: type, options: options, thinking: effectiveThinking, jsonMode: effectiveJsonMode @@ -416,10 +425,27 @@ public struct GeminiLanguageModel: LanguageModel { if case .text(let textPart) = part { accumulatedText += textPart.text - let raw = GeneratedContent(accumulatedText) - let content: Content.PartiallyGenerated = (accumulatedText as! Content) - .asPartiallyGenerated() - continuation.yield(.init(content: content, rawContent: raw)) + var raw: GeneratedContent + let content: Content.PartiallyGenerated? + + if type == String.self { + raw = GeneratedContent(accumulatedText) + content = (accumulatedText as! Content).asPartiallyGenerated() + } else { + raw = + (try? GeneratedContent(json: accumulatedText)) + ?? GeneratedContent(accumulatedText) + if let parsed = try? type.init(raw) { + content = parsed.asPartiallyGenerated() + } else { + // Skip invalid partial JSON until it parses cleanly. + content = nil + } + } + + if let content { + continuation.yield(.init(content: content, rawContent: raw)) + } } } } @@ -451,7 +477,12 @@ public struct GeminiLanguageModel: LanguageModel { if !tools.isEmpty { let functionDeclarations: [GeminiFunctionDeclaration] = try tools.map { tool in - try convertToolToGeminiFormat(tool) + let schema = try convertSchemaToGeminiFormat(tool.parameters) + return GeminiFunctionDeclaration( + name: tool.name, + description: tool.description, + parameters: schema + ) } geminiTools.append(.functionDeclarations(functionDeclarations)) } @@ -473,9 +504,18 @@ public struct GeminiLanguageModel: LanguageModel { } } -private func createGenerateContentParams( +private func convertSchemaToGeminiFormat(_ schema: GenerationSchema) throws -> JSONSchema { + let resolvedSchema = schema.withResolvedRoot() ?? schema + let encoder = JSONEncoder() + encoder.userInfo[GenerationSchema.omitAdditionalPropertiesKey] = true + let data = try encoder.encode(resolvedSchema) + return try JSONDecoder().decode(JSONSchema.self, from: data) +} + +private func createGenerateContentParams( contents: [GeminiContent], tools: [GeminiTool]?, + generating type: Content.Type, options: GenerationOptions, thinking: GeminiLanguageModel.CustomGenerationOptions.Thinking, jsonMode: GeminiLanguageModel.CustomGenerationOptions.JSONMode? @@ -518,7 +558,11 @@ private func createGenerateContentParams( } generationConfig["thinkingConfig"] = .object(thinkingConfig) - if let jsonMode { + if type != String.self { + let schema = try convertSchemaToGeminiFormat(type.generationSchema) + generationConfig["responseMimeType"] = .string("application/json") + generationConfig["responseSchema"] = try JSONValue(schema) + } else if let jsonMode { switch jsonMode { case .disabled: break @@ -652,19 +696,23 @@ private func resolveFunctionCalls( return .invocations(results) } -private func convertToolToGeminiFormat(_ tool: any Tool) throws -> GeminiFunctionDeclaration { - let resolvedSchema = tool.parameters.withResolvedRoot() ?? tool.parameters - - let encoder = JSONEncoder() - encoder.userInfo[GenerationSchema.omitAdditionalPropertiesKey] = true - let data = try encoder.encode(resolvedSchema) - let schema = try JSONDecoder().decode(JSONSchema.self, from: data) +private func emptyResponseContent( + for type: Content.Type +) throws -> (content: Content, rawContent: GeneratedContent) { + if type == String.self { + let raw = GeneratedContent("") + return ("" as! Content, raw) + } - return GeminiFunctionDeclaration( - name: tool.name, - description: tool.description, - parameters: schema - ) + let rawEmpty = GeneratedContent(properties: [:]) + do { + let content = try type.init(rawEmpty) + return (content, rawEmpty) + } catch { + let rawNull = try GeneratedContent(json: "null") + let content = try type.init(rawNull) + return (content, rawNull) + } } private func toGeneratedContent(_ value: [String: JSONValue]?) throws -> GeneratedContent { diff --git a/Tests/AnyLanguageModelTests/GeminiLanguageModelTests.swift b/Tests/AnyLanguageModelTests/GeminiLanguageModelTests.swift index 4caa0f7..9375851 100644 --- a/Tests/AnyLanguageModelTests/GeminiLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/GeminiLanguageModelTests.swift @@ -178,4 +178,93 @@ struct GeminiLanguageModelTests { #expect(response.content.contains("Alice")) #expect(response.content.contains("30")) } + + @Suite("Structured Output") + struct StructuredOutputTests { + @Generable + struct Person { + @Guide(description: "The person's full name") + var name: String + + @Guide(description: "The person's age in years") + var age: Int + + @Guide(description: "The person's email address") + var email: String? + } + + @Generable + struct Book { + @Guide(description: "The book's title") + var title: String + + @Guide(description: "The book's author") + var author: String + + @Guide(description: "The publication year") + var year: Int + } + + private var model: GeminiLanguageModel { + GeminiLanguageModel(apiKey: geminiAPIKey!, model: "gemini-2.5-flash") + } + + @Test func basicStructuredOutput() async throws { + let session = LanguageModelSession(model: model) + let response = try await session.respond( + to: "Generate a person named John Doe, age 30, email john@example.com", + generating: Person.self + ) + + #expect(!response.content.name.isEmpty) + #expect(response.content.name.contains("John") || response.content.name.contains("Doe")) + #expect(response.content.age > 0) + #expect(response.content.age <= 100) + #expect(response.content.email != nil) + } + + @Test func structuredOutputWithOptionalField() async throws { + let session = LanguageModelSession(model: model) + let response = try await session.respond( + to: "Generate a person named Jane Smith, age 25, with no email", + generating: Person.self + ) + + #expect(!response.content.name.isEmpty) + #expect(response.content.name.contains("Jane") || response.content.name.contains("Smith")) + #expect(response.content.age > 0) + #expect(response.content.age <= 100) + #expect(response.content.email == nil || response.content.email?.isEmpty == true) + } + + @Test func structuredOutputWithNestedTypes() async throws { + let session = LanguageModelSession(model: model) + let response = try await session.respond( + to: "Generate a book titled 'The Swift Programming Language' by 'Apple Inc.' published in 2024", + generating: Book.self + ) + + #expect(!response.content.title.isEmpty) + #expect(!response.content.author.isEmpty) + #expect(response.content.year >= 2020) + } + + @Test func streamingStructuredOutput() async throws { + let session = LanguageModelSession(model: model) + let stream = session.streamResponse( + to: "Generate a person named Alice, age 28, email alice@example.com", + generating: Person.self + ) + + var snapshots: [LanguageModelSession.ResponseStream.Snapshot] = [] + for try await snapshot in stream { + snapshots.append(snapshot) + } + + #expect(!snapshots.isEmpty) + let finalSnapshot = snapshots.last! + #expect((finalSnapshot.content.name?.isEmpty ?? true) == false) + #expect((finalSnapshot.content.age ?? 0) > 0) + } + } }