diff --git a/README.md b/README.md index 99a0569..b8429ff 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,36 @@ let response = try await session.respond { print(response.content) ``` +To observe or control tool execution, assign a delegate on the session: + +```swift +actor ToolExecutionObserver: ToolExecutionDelegate { + func didGenerateToolCalls(_ toolCalls: [Transcript.ToolCall], in session: LanguageModelSession) async { + print("Generated tool calls: \(toolCalls)") + } + + func toolCallDecision( + for toolCall: Transcript.ToolCall, + in session: LanguageModelSession + ) async -> ToolExecutionDecision { + // Return .stop to halt after tool calls, or .provideOutput(...) to bypass execution. + // This is a good place to ask the user for confirmation (for example, in a modal dialog). + .execute + } + + func didExecuteToolCall( + _ toolCall: Transcript.ToolCall, + output: Transcript.ToolOutput, + in session: LanguageModelSession + ) async { + print("Executed tool call: \(toolCall)") + } +} + +let session = LanguageModelSession(model: model, tools: [WeatherTool()]) +session.toolExecutionDelegate = ToolExecutionObserver() +``` + ## Features ### Supported Providers diff --git a/Sources/AnyLanguageModel/LanguageModelSession.swift b/Sources/AnyLanguageModel/LanguageModelSession.swift index 34f4ef8..1b56949 100644 --- a/Sources/AnyLanguageModel/LanguageModelSession.swift +++ b/Sources/AnyLanguageModel/LanguageModelSession.swift @@ -10,6 +10,16 @@ public final class LanguageModelSession: @unchecked Sendable { public let tools: [any Tool] public let instructions: Instructions? + /// A delegate that observes and controls tool execution. + /// + /// Set this property to intercept tool calls, provide custom output, + /// or stop after tool calls are generated. + /// + /// - Note: This property is exclusive to AnyLanguageModel + /// and using it means your code is no longer drop-in compatible + /// with the Foundation Models framework. + @ObservationIgnored public var toolExecutionDelegate: (any ToolExecutionDelegate)? + @ObservationIgnored private let respondingState = RespondingState() public convenience init( diff --git a/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift b/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift index 509f5ef..f88b68a 100644 --- a/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift @@ -356,11 +356,23 @@ public struct AnthropicLanguageModel: LanguageModel { } if !toolUses.isEmpty { - let invocations = try await resolveToolUses(toolUses, session: session) - if !invocations.isEmpty { - entries.append(.toolCalls(Transcript.ToolCalls(invocations.map(\.call)))) - for invocation in invocations { - entries.append(.toolOutput(invocation.output)) + let resolution = try await resolveToolUses(toolUses, session: session) + switch resolution { + case .stop(let calls): + if !calls.isEmpty { + entries.append(.toolCalls(Transcript.ToolCalls(calls))) + } + return LanguageModelSession.Response( + content: "" as! Content, + rawContent: GeneratedContent(""), + transcriptEntries: ArraySlice(entries) + ) + case .invocations(let invocations): + if !invocations.isEmpty { + entries.append(.toolCalls(Transcript.ToolCalls(invocations.map(\.call)))) + for invocation in invocations { + entries.append(.toolOutput(invocation.output)) + } } } } @@ -560,11 +572,16 @@ private struct ToolInvocationResult { let output: Transcript.ToolOutput } +private enum ToolResolutionOutcome { + case stop(calls: [Transcript.ToolCall]) + case invocations([ToolInvocationResult]) +} + private func resolveToolUses( _ toolUses: [AnthropicToolUse], session: LanguageModelSession -) async throws -> [ToolInvocationResult] { - if toolUses.isEmpty { return [] } +) async throws -> ToolResolutionOutcome { + if toolUses.isEmpty { return .invocations([]) } var toolsByName: [String: any Tool] = [:] for tool in session.tools { @@ -573,43 +590,98 @@ private func resolveToolUses( } } - var results: [ToolInvocationResult] = [] - results.reserveCapacity(toolUses.count) - + var transcriptCalls: [Transcript.ToolCall] = [] + transcriptCalls.reserveCapacity(toolUses.count) for use in toolUses { let args = try toGeneratedContent(use.input) let callID = use.id - let transcriptCall = Transcript.ToolCall( - id: callID, - toolName: use.name, - arguments: args - ) - - guard let tool = toolsByName[use.name] else { - let message = Transcript.Segment.text(.init(content: "Tool not found: \(use.name)")) - let output = Transcript.ToolOutput( + transcriptCalls.append( + Transcript.ToolCall( id: callID, toolName: use.name, - segments: [message] + arguments: args ) - results.append(ToolInvocationResult(call: transcriptCall, output: output)) - continue + ) + } + + if let delegate = session.toolExecutionDelegate { + await delegate.didGenerateToolCalls(transcriptCalls, in: session) + } + + guard !transcriptCalls.isEmpty else { return .invocations([]) } + + var decisions: [ToolExecutionDecision] = [] + decisions.reserveCapacity(transcriptCalls.count) + + if let delegate = session.toolExecutionDelegate { + for call in transcriptCalls { + let decision = await delegate.toolCallDecision(for: call, in: session) + if case .stop = decision { + return .stop(calls: transcriptCalls) + } + decisions.append(decision) } + } else { + decisions = Array(repeating: .execute, count: transcriptCalls.count) + } - do { - let segments = try await tool.makeOutputSegments(from: args) + var results: [ToolInvocationResult] = [] + results.reserveCapacity(transcriptCalls.count) + + for (index, call) in transcriptCalls.enumerated() { + switch decisions[index] { + case .stop: + // This branch should be unreachable, + // because `.stop` returns during decision collection. + // Keep it as a defensive guard, + // in case that logic changes. + return .stop(calls: transcriptCalls) + case .provideOutput(let segments): let output = Transcript.ToolOutput( - id: tool.name, - toolName: tool.name, + id: call.id, + toolName: call.toolName, segments: segments ) - results.append(ToolInvocationResult(call: transcriptCall, output: output)) - } catch { - throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + case .execute: + guard let tool = toolsByName[call.toolName] else { + let message = Transcript.Segment.text(.init(content: "Tool not found: \(call.toolName)")) + let output = Transcript.ToolOutput( + id: call.id, + toolName: call.toolName, + segments: [message] + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + continue + } + + do { + let segments = try await tool.makeOutputSegments(from: call.arguments) + let output = Transcript.ToolOutput( + id: call.id, + toolName: tool.name, + segments: segments + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + } catch { + if let delegate = session.toolExecutionDelegate { + await delegate.didFailToolCall(call, error: error, in: session) + } + throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error) + } } } - return results + return .invocations(results) } // Convert our GenerationSchema into Anthropic's expected JSON Schema payload diff --git a/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift b/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift index 886bbe2..da91c80 100644 --- a/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift @@ -313,17 +313,29 @@ public struct GeminiLanguageModel: LanguageModel { if !functionCalls.isEmpty { // Resolve function calls - let invocations = try await resolveFunctionCalls(functionCalls, session: session) - if !invocations.isEmpty { - transcript.append(.toolCalls(Transcript.ToolCalls(invocations.map(\.call)))) + let resolution = try await resolveFunctionCalls(functionCalls, session: session) + switch resolution { + case .stop(let calls): + if !calls.isEmpty { + transcript.append(.toolCalls(Transcript.ToolCalls(calls))) + } + return LanguageModelSession.Response( + content: "" as! Content, + rawContent: GeneratedContent(""), + transcriptEntries: ArraySlice(transcript) + ) + case .invocations(let invocations): + if !invocations.isEmpty { + transcript.append(.toolCalls(Transcript.ToolCalls(invocations.map(\.call)))) - for invocation in invocations { - transcript.append(.toolOutput(invocation.output)) + for invocation in invocations { + transcript.append(.toolOutput(invocation.output)) + } } - } - // Continue the loop to send the next request with tool results - continue + // Continue the loop to send the next request with tool results + continue + } } else { // No function calls, extract final text and return let text = @@ -530,11 +542,16 @@ private struct ToolInvocationResult { let output: Transcript.ToolOutput } +private enum ToolResolutionOutcome { + case stop(calls: [Transcript.ToolCall]) + case invocations([ToolInvocationResult]) +} + private func resolveFunctionCalls( _ functionCalls: [GeminiFunctionCall], session: LanguageModelSession -) async throws -> [ToolInvocationResult] { - if functionCalls.isEmpty { return [] } +) async throws -> ToolResolutionOutcome { + if functionCalls.isEmpty { return .invocations([]) } var toolsByName: [String: any Tool] = [:] for tool in session.tools { @@ -543,43 +560,96 @@ private func resolveFunctionCalls( } } - var results: [ToolInvocationResult] = [] - results.reserveCapacity(functionCalls.count) - + var transcriptCalls: [Transcript.ToolCall] = [] + transcriptCalls.reserveCapacity(functionCalls.count) for call in functionCalls { let args = try toGeneratedContent(call.args) let callID = UUID().uuidString - let transcriptCall = Transcript.ToolCall( - id: callID, - toolName: call.name, - arguments: args - ) - - guard let tool = toolsByName[call.name] else { - let message = Transcript.Segment.text(.init(content: "Tool not found: \(call.name)")) - let output = Transcript.ToolOutput( + transcriptCalls.append( + Transcript.ToolCall( id: callID, toolName: call.name, - segments: [message] + arguments: args ) - results.append(ToolInvocationResult(call: transcriptCall, output: output)) - continue + ) + } + + if let delegate = session.toolExecutionDelegate { + await delegate.didGenerateToolCalls(transcriptCalls, in: session) + } + + guard !transcriptCalls.isEmpty else { return .invocations([]) } + + var decisions: [ToolExecutionDecision] = [] + decisions.reserveCapacity(transcriptCalls.count) + + if let delegate = session.toolExecutionDelegate { + for call in transcriptCalls { + let decision = await delegate.toolCallDecision(for: call, in: session) + if case .stop = decision { + return .stop(calls: transcriptCalls) + } + decisions.append(decision) } + } else { + decisions = Array(repeating: .execute, count: transcriptCalls.count) + } - do { - let segments = try await tool.makeOutputSegments(from: args) + var results: [ToolInvocationResult] = [] + results.reserveCapacity(transcriptCalls.count) + + for (index, call) in transcriptCalls.enumerated() { + switch decisions[index] { + case .stop: + // This branch should be unreachable because `.stop` returns during decision collection. + // Keep it as a defensive guard in case that logic changes. + return .stop(calls: transcriptCalls) + case .provideOutput(let segments): let output = Transcript.ToolOutput( - id: tool.name, - toolName: tool.name, + id: call.id, + toolName: call.toolName, segments: segments ) - results.append(ToolInvocationResult(call: transcriptCall, output: output)) - } catch { - throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + case .execute: + guard let tool = toolsByName[call.toolName] else { + let message = Transcript.Segment.text(.init(content: "Tool not found: \(call.toolName)")) + let output = Transcript.ToolOutput( + id: call.id, + toolName: call.toolName, + segments: [message] + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + continue + } + + do { + let segments = try await tool.makeOutputSegments(from: call.arguments) + let output = Transcript.ToolOutput( + id: call.id, + toolName: tool.name, + segments: segments + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + } catch { + if let delegate = session.toolExecutionDelegate { + await delegate.didFailToolCall(call, error: error, in: session) + } + throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error) + } } } - return results + return .invocations(results) } private func convertToolToGeminiFormat(_ tool: any Tool) throws -> GeminiFunctionDeclaration { diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index a7ea74f..7b781b2 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -276,21 +276,33 @@ import Foundation // If there are tool calls, execute them and continue if !collectedToolCalls.isEmpty { - let invocations = try await resolveToolCalls(collectedToolCalls, session: session) - if !invocations.isEmpty { - allEntries.append(.toolCalls(Transcript.ToolCalls(invocations.map(\.call)))) + let resolution = try await resolveToolCalls(collectedToolCalls, session: session) + switch resolution { + case .stop(let calls): + if !calls.isEmpty { + allEntries.append(.toolCalls(Transcript.ToolCalls(calls))) + } + return LanguageModelSession.Response( + content: "" as! Content, + rawContent: GeneratedContent(""), + transcriptEntries: ArraySlice(allEntries) + ) + case .invocations(let invocations): + if !invocations.isEmpty { + allEntries.append(.toolCalls(Transcript.ToolCalls(invocations.map(\.call)))) - // Execute each tool and add results to chat - for invocation in invocations { - allEntries.append(.toolOutput(invocation.output)) + // Execute each tool and add results to chat + for invocation in invocations { + allEntries.append(.toolOutput(invocation.output)) - // Convert tool output to JSON string for MLX - let toolResultJSON = toolOutputToJSON(invocation.output) - chat.append(.tool(toolResultJSON)) - } + // Convert tool output to JSON string for MLX + let toolResultJSON = toolOutputToJSON(invocation.output) + chat.append(.tool(toolResultJSON)) + } - // Continue loop to generate with tool results - continue + // Continue loop to generate with tool results + continue + } } } @@ -604,11 +616,16 @@ import Foundation let output: Transcript.ToolOutput } + private enum ToolResolutionOutcome { + case stop(calls: [Transcript.ToolCall]) + case invocations([ToolInvocationResult]) + } + private func resolveToolCalls( _ toolCalls: [MLXLMCommon.ToolCall], session: LanguageModelSession - ) async throws -> [ToolInvocationResult] { - if toolCalls.isEmpty { return [] } + ) async throws -> ToolResolutionOutcome { + if toolCalls.isEmpty { return .invocations([]) } var toolsByName: [String: any Tool] = [:] for tool in session.tools { @@ -617,43 +634,96 @@ import Foundation } } - var results: [ToolInvocationResult] = [] - results.reserveCapacity(toolCalls.count) - + var transcriptCalls: [Transcript.ToolCall] = [] + transcriptCalls.reserveCapacity(toolCalls.count) for call in toolCalls { let args = try toGeneratedContent(call.function.arguments) let callID = UUID().uuidString - let transcriptCall = Transcript.ToolCall( - id: callID, - toolName: call.function.name, - arguments: args - ) - - guard let tool = toolsByName[call.function.name] else { - let message = Transcript.Segment.text(.init(content: "Tool not found: \(call.function.name)")) - let output = Transcript.ToolOutput( + transcriptCalls.append( + Transcript.ToolCall( id: callID, toolName: call.function.name, - segments: [message] + arguments: args ) - results.append(ToolInvocationResult(call: transcriptCall, output: output)) - continue + ) + } + + if let delegate = session.toolExecutionDelegate { + await delegate.didGenerateToolCalls(transcriptCalls, in: session) + } + + guard !transcriptCalls.isEmpty else { return .invocations([]) } + + var decisions: [ToolExecutionDecision] = [] + decisions.reserveCapacity(transcriptCalls.count) + + if let delegate = session.toolExecutionDelegate { + for call in transcriptCalls { + let decision = await delegate.toolCallDecision(for: call, in: session) + if case .stop = decision { + return .stop(calls: transcriptCalls) + } + decisions.append(decision) } + } else { + decisions = Array(repeating: .execute, count: transcriptCalls.count) + } - do { - let segments = try await tool.makeOutputSegments(from: args) + var results: [ToolInvocationResult] = [] + results.reserveCapacity(transcriptCalls.count) + + for (index, call) in transcriptCalls.enumerated() { + switch decisions[index] { + case .stop: + // This branch should be unreachable because `.stop` returns during decision collection. + // Keep it as a defensive guard in case that logic changes. + return .stop(calls: transcriptCalls) + case .provideOutput(let segments): let output = Transcript.ToolOutput( - id: tool.name, - toolName: tool.name, + id: call.id, + toolName: call.toolName, segments: segments ) - results.append(ToolInvocationResult(call: transcriptCall, output: output)) - } catch { - throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + case .execute: + guard let tool = toolsByName[call.toolName] else { + let message = Transcript.Segment.text(.init(content: "Tool not found: \(call.toolName)")) + let output = Transcript.ToolOutput( + id: call.id, + toolName: call.toolName, + segments: [message] + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + continue + } + + do { + let segments = try await tool.makeOutputSegments(from: call.arguments) + let output = Transcript.ToolOutput( + id: call.id, + toolName: tool.name, + segments: segments + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + } catch { + if let delegate = session.toolExecutionDelegate { + await delegate.didFailToolCall(call, error: error, in: session) + } + throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error) + } } } - return results + return .invocations(results) } private func toGeneratedContent(_ args: [String: MLXLMCommon.JSONValue]) throws -> GeneratedContent { diff --git a/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift b/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift index 12ec53f..4f49d13 100644 --- a/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift @@ -112,11 +112,23 @@ public struct OllamaLanguageModel: LanguageModel { var entries: [Transcript.Entry] = [] if let toolCalls = chatResponse.message.toolCalls, !toolCalls.isEmpty { - let invocations = try await resolveToolCalls(toolCalls, session: session) - if !invocations.isEmpty { - entries.append(.toolCalls(Transcript.ToolCalls(invocations.map(\.call)))) - for invocation in invocations { - entries.append(.toolOutput(invocation.output)) + let resolution = try await resolveToolCalls(toolCalls, session: session) + switch resolution { + case .stop(let calls): + if !calls.isEmpty { + entries.append(.toolCalls(Transcript.ToolCalls(calls))) + } + return LanguageModelSession.Response( + content: "" as! Content, + rawContent: GeneratedContent(""), + transcriptEntries: ArraySlice(entries) + ) + case .invocations(let invocations): + if !invocations.isEmpty { + entries.append(.toolCalls(Transcript.ToolCalls(invocations.map(\.call)))) + for invocation in invocations { + entries.append(.toolOutput(invocation.output)) + } } } } @@ -216,12 +228,17 @@ private struct ToolInvocationResult { let output: Transcript.ToolOutput } +private enum ToolResolutionOutcome { + case stop(calls: [Transcript.ToolCall]) + case invocations([ToolInvocationResult]) +} + private func resolveToolCalls( _ toolCalls: [OllamaToolCall], session: LanguageModelSession -) async throws -> [ToolInvocationResult] { +) async throws -> ToolResolutionOutcome { if toolCalls.isEmpty { - return [] + return .invocations([]) } var toolsByName: [String: any Tool] = [:] @@ -231,43 +248,96 @@ private func resolveToolCalls( } } - var results: [ToolInvocationResult] = [] - results.reserveCapacity(toolCalls.count) - + var transcriptCalls: [Transcript.ToolCall] = [] + transcriptCalls.reserveCapacity(toolCalls.count) for call in toolCalls { let args = try toGeneratedContent(call.function.arguments) let callID = call.id ?? UUID().uuidString - let transcriptCall = Transcript.ToolCall( - id: callID, - toolName: call.function.name, - arguments: args - ) - - guard let tool = toolsByName[call.function.name] else { - let message = Transcript.Segment.text(.init(content: "Tool not found: \(call.function.name)")) - let output = Transcript.ToolOutput( + transcriptCalls.append( + Transcript.ToolCall( id: callID, toolName: call.function.name, - segments: [message] + arguments: args ) - results.append(ToolInvocationResult(call: transcriptCall, output: output)) - continue + ) + } + + if let delegate = session.toolExecutionDelegate { + await delegate.didGenerateToolCalls(transcriptCalls, in: session) + } + + guard !transcriptCalls.isEmpty else { return .invocations([]) } + + var decisions: [ToolExecutionDecision] = [] + decisions.reserveCapacity(transcriptCalls.count) + + if let delegate = session.toolExecutionDelegate { + for call in transcriptCalls { + let decision = await delegate.toolCallDecision(for: call, in: session) + if case .stop = decision { + return .stop(calls: transcriptCalls) + } + decisions.append(decision) } + } else { + decisions = Array(repeating: .execute, count: transcriptCalls.count) + } - do { - let segments = try await tool.makeOutputSegments(from: args) + var results: [ToolInvocationResult] = [] + results.reserveCapacity(transcriptCalls.count) + + for (index, call) in transcriptCalls.enumerated() { + switch decisions[index] { + case .stop: + // This branch should be unreachable because `.stop` returns during decision collection. + // Keep it as a defensive guard in case that logic changes. + return .stop(calls: transcriptCalls) + case .provideOutput(let segments): let output = Transcript.ToolOutput( - id: tool.name, - toolName: tool.name, + id: call.id, + toolName: call.toolName, segments: segments ) - results.append(ToolInvocationResult(call: transcriptCall, output: output)) - } catch { - throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + case .execute: + guard let tool = toolsByName[call.toolName] else { + let message = Transcript.Segment.text(.init(content: "Tool not found: \(call.toolName)")) + let output = Transcript.ToolOutput( + id: call.id, + toolName: call.toolName, + segments: [message] + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + continue + } + + do { + let segments = try await tool.makeOutputSegments(from: call.arguments) + let output = Transcript.ToolOutput( + id: call.id, + toolName: tool.name, + segments: segments + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + } catch { + if let delegate = session.toolExecutionDelegate { + await delegate.didFailToolCall(call, error: error, in: session) + } + throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error) + } } } - return results + return .invocations(results) } // MARK: - Conversions diff --git a/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift b/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift index 6299ce0..da28255 100644 --- a/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift @@ -508,20 +508,32 @@ public struct OpenAILanguageModel: LanguageModel { if let value = try? JSONValue(toolCallMessage) { messages.append(OpenAIMessage(role: .raw(rawContent: value), content: .text(""))) } - let invocations = try await resolveToolCalls(toolCalls, session: session) - if !invocations.isEmpty { - entries.append(.toolCalls(Transcript.ToolCalls(invocations.map { $0.call }))) - for invocation in invocations { - let output = invocation.output - entries.append(.toolOutput(output)) - messages.append( - OpenAIMessage( - role: .tool(id: invocation.call.id), - content: .text(convertSegmentsToToolContentString(output.segments)) + let resolution = try await resolveToolCalls(toolCalls, session: session) + switch resolution { + case .stop(let calls): + if !calls.isEmpty { + entries.append(.toolCalls(Transcript.ToolCalls(calls))) + } + return LanguageModelSession.Response( + content: "" as! Content, + rawContent: GeneratedContent(""), + transcriptEntries: ArraySlice(entries) + ) + case .invocations(let invocations): + if !invocations.isEmpty { + entries.append(.toolCalls(Transcript.ToolCalls(invocations.map { $0.call }))) + for invocation in invocations { + let output = invocation.output + entries.append(.toolOutput(output)) + messages.append( + OpenAIMessage( + role: .tool(id: invocation.call.id), + content: .text(convertSegmentsToToolContentString(output.segments)) + ) ) - ) + } + continue } - continue } } @@ -575,21 +587,33 @@ public struct OpenAILanguageModel: LanguageModel { messages.append(OpenAIMessage(role: .raw(rawContent: msg), content: .text(""))) } } - let invocations = try await resolveToolCalls(toolCalls, session: session) - if !invocations.isEmpty { - entries.append(.toolCalls(Transcript.ToolCalls(invocations.map { $0.call }))) - - for invocation in invocations { - let output = invocation.output - entries.append(.toolOutput(output)) - messages.append( - OpenAIMessage( - role: .tool(id: invocation.call.id), - content: .text(convertSegmentsToToolContentString(output.segments)) + let resolution = try await resolveToolCalls(toolCalls, session: session) + switch resolution { + case .stop(let calls): + if !calls.isEmpty { + entries.append(.toolCalls(Transcript.ToolCalls(calls))) + } + return LanguageModelSession.Response( + content: "" as! Content, + rawContent: GeneratedContent(""), + transcriptEntries: ArraySlice(entries) + ) + case .invocations(let invocations): + if !invocations.isEmpty { + entries.append(.toolCalls(Transcript.ToolCalls(invocations.map { $0.call }))) + + for invocation in invocations { + let output = invocation.output + entries.append(.toolOutput(output)) + messages.append( + OpenAIMessage( + role: .tool(id: invocation.call.id), + content: .text(convertSegmentsToToolContentString(output.segments)) + ) ) - ) + } + continue } - continue } } @@ -1478,11 +1502,16 @@ private struct OpenAIToolInvocationResult { let output: Transcript.ToolOutput } +private enum OpenAIToolResolutionOutcome { + case stop(calls: [Transcript.ToolCall]) + case invocations([OpenAIToolInvocationResult]) +} + private func resolveToolCalls( _ toolCalls: [OpenAIToolCall], session: LanguageModelSession -) async throws -> [OpenAIToolInvocationResult] { - if toolCalls.isEmpty { return [] } +) async throws -> OpenAIToolResolutionOutcome { + if toolCalls.isEmpty { return .invocations([]) } var toolsByName: [String: any Tool] = [:] for tool in session.tools { @@ -1491,44 +1520,97 @@ private func resolveToolCalls( } } - var results: [OpenAIToolInvocationResult] = [] - results.reserveCapacity(toolCalls.count) - + var transcriptCalls: [Transcript.ToolCall] = [] + transcriptCalls.reserveCapacity(toolCalls.count) for call in toolCalls { guard let function = call.function else { continue } let args = try toGeneratedContent(function.arguments) let callID = call.id ?? UUID().uuidString - let transcriptCall = Transcript.ToolCall( - id: callID, - toolName: function.name, - arguments: args - ) - - guard let tool = toolsByName[function.name] else { - let message = Transcript.Segment.text(.init(content: "Tool not found: \(function.name)")) - let output = Transcript.ToolOutput( + transcriptCalls.append( + Transcript.ToolCall( id: callID, toolName: function.name, - segments: [message] + arguments: args ) - results.append(OpenAIToolInvocationResult(call: transcriptCall, output: output)) - continue + ) + } + + if let delegate = session.toolExecutionDelegate { + await delegate.didGenerateToolCalls(transcriptCalls, in: session) + } + + guard !transcriptCalls.isEmpty else { return .invocations([]) } + + var decisions: [ToolExecutionDecision] = [] + decisions.reserveCapacity(transcriptCalls.count) + + if let delegate = session.toolExecutionDelegate { + for call in transcriptCalls { + let decision = await delegate.toolCallDecision(for: call, in: session) + if case .stop = decision { + return .stop(calls: transcriptCalls) + } + decisions.append(decision) } + } else { + decisions = Array(repeating: .execute, count: transcriptCalls.count) + } - do { - let segments = try await tool.makeOutputSegments(from: args) + var results: [OpenAIToolInvocationResult] = [] + results.reserveCapacity(transcriptCalls.count) + + for (index, call) in transcriptCalls.enumerated() { + switch decisions[index] { + case .stop: + // This branch should be unreachable because `.stop` returns during decision collection. + // Keep it as a defensive guard in case that logic changes. + return .stop(calls: transcriptCalls) + case .provideOutput(let segments): let output = Transcript.ToolOutput( - id: callID, - toolName: tool.name, + id: call.id, + toolName: call.toolName, segments: segments ) - results.append(OpenAIToolInvocationResult(call: transcriptCall, output: output)) - } catch { - throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(OpenAIToolInvocationResult(call: call, output: output)) + case .execute: + guard let tool = toolsByName[call.toolName] else { + let message = Transcript.Segment.text(.init(content: "Tool not found: \(call.toolName)")) + let output = Transcript.ToolOutput( + id: call.id, + toolName: call.toolName, + segments: [message] + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(OpenAIToolInvocationResult(call: call, output: output)) + continue + } + + do { + let segments = try await tool.makeOutputSegments(from: call.arguments) + let output = Transcript.ToolOutput( + id: call.id, + toolName: tool.name, + segments: segments + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(OpenAIToolInvocationResult(call: call, output: output)) + } catch { + if let delegate = session.toolExecutionDelegate { + await delegate.didFailToolCall(call, error: error, in: session) + } + throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error) + } } } - return results + return .invocations(results) } // MARK: - Converters diff --git a/Sources/AnyLanguageModel/ToolExecution.swift b/Sources/AnyLanguageModel/ToolExecution.swift new file mode 100644 index 0000000..3dd9743 --- /dev/null +++ b/Sources/AnyLanguageModel/ToolExecution.swift @@ -0,0 +1,94 @@ +/// A decision about how to handle a tool call. +/// +/// - Note: This API is exclusive to AnyLanguageModel +/// and using it means your code is no longer drop-in compatible +/// with the Foundation Models framework. +public enum ToolExecutionDecision: Sendable { + /// Execute the tool call using the associated tool. + case execute + + /// Stop the session after tool calls are generated without executing them. + case stop + + /// Provide tool output without executing the tool. + /// + /// Use this to supply results from an external system or cached responses. + case provideOutput([Transcript.Segment]) +} + +/// A delegate that observes and controls tool execution for a session. +/// +/// - Note: This API is exclusive to AnyLanguageModel +/// and using it means your code is no longer drop-in compatible +/// with the Foundation Models framework. +public protocol ToolExecutionDelegate: Sendable { + /// Notifies the delegate when the model generates tool calls. + /// + /// - Parameters: + /// - toolCalls: The tool calls produced by the model. + /// - session: The session that generated the tool calls. + func didGenerateToolCalls(_ toolCalls: [Transcript.ToolCall], in session: LanguageModelSession) async + + /// Asks the delegate how to handle a tool call. + /// + /// Return `.execute` to run the tool, `.stop` to halt after tool calls are generated, + /// or `.provideOutput` to supply output without executing the tool. + /// - Parameters: + /// - toolCall: The tool call to evaluate. + /// - session: The session requesting the decision. + func toolCallDecision(for toolCall: Transcript.ToolCall, in session: LanguageModelSession) async + -> ToolExecutionDecision + + /// Notifies the delegate after a tool call produces output. + /// + /// - Parameters: + /// - toolCall: The tool call that was handled. + /// - output: The output sent back to the model. + /// - session: The session that executed the tool call. + func didExecuteToolCall( + _ toolCall: Transcript.ToolCall, + output: Transcript.ToolOutput, + in session: LanguageModelSession + ) async + + /// Notifies the delegate when a tool call fails. + /// + /// - Parameters: + /// - toolCall: The tool call that failed. + /// - error: The underlying error raised during execution. + /// - session: The session that attempted the tool call. + func didFailToolCall( + _ toolCall: Transcript.ToolCall, + error: any Error, + in session: LanguageModelSession + ) async +} + +// MARK: - Default Implementations + +extension ToolExecutionDelegate { + /// Provides a default no-op implementation. + public func didGenerateToolCalls(_ toolCalls: [Transcript.ToolCall], in session: LanguageModelSession) async {} + + /// Provides a default decision that executes the tool call. + public func toolCallDecision( + for toolCall: Transcript.ToolCall, + in session: LanguageModelSession + ) async -> ToolExecutionDecision { + .execute + } + + /// Provides a default no-op implementation. + public func didExecuteToolCall( + _ toolCall: Transcript.ToolCall, + output: Transcript.ToolOutput, + in session: LanguageModelSession + ) async {} + + /// Provides a default no-op implementation. + public func didFailToolCall( + _ toolCall: Transcript.ToolCall, + error: any Error, + in session: LanguageModelSession + ) async {} +} diff --git a/Tests/AnyLanguageModelTests/AnthropicLanguageModelTests.swift b/Tests/AnyLanguageModelTests/AnthropicLanguageModelTests.swift index 66f263f..ece1ec8 100644 --- a/Tests/AnyLanguageModelTests/AnthropicLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/AnthropicLanguageModelTests.swift @@ -94,7 +94,8 @@ struct AnthropicLanguageModelTests { var foundToolOutput = false for case let .toolOutput(toolOutput) in response.transcriptEntries { - #expect(toolOutput.id == "getWeather") + #expect(!toolOutput.id.isEmpty) + #expect(toolOutput.toolName == "getWeather") foundToolOutput = true } #expect(foundToolOutput) diff --git a/Tests/AnyLanguageModelTests/GeminiLanguageModelTests.swift b/Tests/AnyLanguageModelTests/GeminiLanguageModelTests.swift index 19a22bb..4caa0f7 100644 --- a/Tests/AnyLanguageModelTests/GeminiLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/GeminiLanguageModelTests.swift @@ -97,7 +97,8 @@ struct GeminiLanguageModelTests { var foundToolOutput = false for case let .toolOutput(toolOutput) in response.transcriptEntries { - #expect(toolOutput.id == "getWeather") + #expect(!toolOutput.id.isEmpty) + #expect(toolOutput.toolName == "getWeather") foundToolOutput = true } #expect(foundToolOutput) diff --git a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift index cf640e8..62827fd 100644 --- a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift @@ -82,7 +82,8 @@ import Testing var foundToolOutput = false for case let .toolOutput(toolOutput) in response.transcriptEntries { - #expect(toolOutput.id == weatherTool.name) + #expect(!toolOutput.id.isEmpty) + #expect(toolOutput.toolName == weatherTool.name) foundToolOutput = true } #expect(foundToolOutput) diff --git a/Tests/AnyLanguageModelTests/OllamaLanguageModelTests.swift b/Tests/AnyLanguageModelTests/OllamaLanguageModelTests.swift index 50c4c43..234d17e 100644 --- a/Tests/AnyLanguageModelTests/OllamaLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/OllamaLanguageModelTests.swift @@ -93,7 +93,8 @@ struct OllamaLanguageModelTests { var foundToolOutput = false for case let .toolOutput(toolOutput) in response.transcriptEntries { - #expect(toolOutput.id == weatherTool.name) + #expect(!toolOutput.id.isEmpty) + #expect(toolOutput.toolName == weatherTool.name) foundToolOutput = true } #expect(foundToolOutput) diff --git a/Tests/AnyLanguageModelTests/ToolExecutionDelegateTests.swift b/Tests/AnyLanguageModelTests/ToolExecutionDelegateTests.swift new file mode 100644 index 0000000..971692f --- /dev/null +++ b/Tests/AnyLanguageModelTests/ToolExecutionDelegateTests.swift @@ -0,0 +1,302 @@ +import Testing + +@testable import AnyLanguageModel + +private actor ToolExecutionDelegateSpy: ToolExecutionDelegate { + private(set) var generatedToolCalls: [Transcript.ToolCall] = [] + private(set) var executedToolCalls: [Transcript.ToolCall] = [] + private(set) var executedOutputs: [Transcript.ToolOutput] = [] + private(set) var failures: [any Error] = [] + + private let decisionProvider: @Sendable (Transcript.ToolCall) async -> ToolExecutionDecision + + init(decisionProvider: @escaping @Sendable (Transcript.ToolCall) async -> ToolExecutionDecision) { + self.decisionProvider = decisionProvider + } + + func didGenerateToolCalls(_ toolCalls: [Transcript.ToolCall], in session: LanguageModelSession) async { + generatedToolCalls.append(contentsOf: toolCalls) + } + + func toolCallDecision( + for toolCall: Transcript.ToolCall, + in session: LanguageModelSession + ) async -> ToolExecutionDecision { + await decisionProvider(toolCall) + } + + func didExecuteToolCall( + _ toolCall: Transcript.ToolCall, + output: Transcript.ToolOutput, + in session: LanguageModelSession + ) async { + executedToolCalls.append(toolCall) + executedOutputs.append(output) + } + + func didFailToolCall( + _ toolCall: Transcript.ToolCall, + error: any Error, + in session: LanguageModelSession + ) async { + failures.append(error) + } +} + +private struct ThrowingTool: Tool { + let name = "throwingTool" + let description = "A tool that throws" + @Generable + struct Arguments { + @Guide(description: "Ignored") + var message: String + } + func call(arguments: Arguments) async throws -> String { + throw ThrowingToolError.testError + } +} + +private enum ThrowingToolError: Error, Equatable { case testError } + +private struct ToolCallingTestModel: LanguageModel { + typealias UnavailableReason = Never + + let toolCalls: [Transcript.ToolCall] + let responseText: String + + init(toolCalls: [Transcript.ToolCall], responseText: String = "done") { + self.toolCalls = toolCalls + self.responseText = responseText + } + + func respond( + within session: LanguageModelSession, + to prompt: Prompt, + generating type: Content.Type, + includeSchemaInPrompt: Bool, + options: GenerationOptions + ) async throws -> LanguageModelSession.Response where Content: Generable { + guard type == String.self else { + fatalError("ToolCallingTestModel only supports generating String content") + } + + var entries: [Transcript.Entry] = [] + + if !toolCalls.isEmpty { + if let delegate = session.toolExecutionDelegate { + await delegate.didGenerateToolCalls(toolCalls, in: session) + } + + var decisions: [ToolExecutionDecision] = [] + decisions.reserveCapacity(toolCalls.count) + + if let delegate = session.toolExecutionDelegate { + for call in toolCalls { + let decision = await delegate.toolCallDecision(for: call, in: session) + if case .stop = decision { + entries.append(.toolCalls(Transcript.ToolCalls(toolCalls))) + return LanguageModelSession.Response( + content: "" as! Content, + rawContent: GeneratedContent(""), + transcriptEntries: ArraySlice(entries) + ) + } + decisions.append(decision) + } + } else { + decisions = Array(repeating: .execute, count: toolCalls.count) + } + + entries.append(.toolCalls(Transcript.ToolCalls(toolCalls))) + + var toolsByName: [String: any Tool] = [:] + for tool in session.tools { + if toolsByName[tool.name] == nil { + toolsByName[tool.name] = tool + } + } + + for (index, call) in toolCalls.enumerated() { + switch decisions[index] { + case .stop: + // This branch should be unreachable because `.stop` returns during decision collection. + // Keep it as a defensive guard in case that logic changes. + entries = [.toolCalls(Transcript.ToolCalls(toolCalls))] + return LanguageModelSession.Response( + content: "" as! Content, + rawContent: GeneratedContent(""), + transcriptEntries: ArraySlice(entries) + ) + case .provideOutput(let segments): + let output = Transcript.ToolOutput( + id: call.id, + toolName: call.toolName, + segments: segments + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + entries.append(.toolOutput(output)) + case .execute: + guard let tool = toolsByName[call.toolName] else { + let message = Transcript.Segment.text(.init(content: "Tool not found: \(call.toolName)")) + let output = Transcript.ToolOutput( + id: call.id, + toolName: call.toolName, + segments: [message] + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + entries.append(.toolOutput(output)) + continue + } + + do { + let segments = try await tool.makeOutputSegments(from: call.arguments) + let output = Transcript.ToolOutput( + id: call.id, + toolName: tool.name, + segments: segments + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + entries.append(.toolOutput(output)) + } catch { + if let delegate = session.toolExecutionDelegate { + await delegate.didFailToolCall(call, error: error, in: session) + } + throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error) + } + } + } + } + + return LanguageModelSession.Response( + content: responseText as! Content, + rawContent: GeneratedContent(responseText), + transcriptEntries: ArraySlice(entries) + ) + } + + func streamResponse( + within session: LanguageModelSession, + to prompt: Prompt, + generating type: Content.Type, + includeSchemaInPrompt: Bool, + options: GenerationOptions + ) -> sending LanguageModelSession.ResponseStream where Content: Generable { + let rawContent = GeneratedContent(responseText) + return LanguageModelSession.ResponseStream(content: responseText as! Content, rawContent: rawContent) + } +} + +@Suite("ToolExecutionDelegate") +struct ToolExecutionDelegateTests { + @Test func stopAfterToolCalls() async throws { + let arguments = try GeneratedContent(json: #"{"city":"Cupertino"}"#) + let toolCall = Transcript.ToolCall(id: "call-1", toolName: WeatherTool().name, arguments: arguments) + let delegate = ToolExecutionDelegateSpy { _ in .stop } + let toolSpy = spy(on: WeatherTool()) + let session = LanguageModelSession( + model: ToolCallingTestModel(toolCalls: [toolCall]), + tools: [toolSpy] + ) + session.toolExecutionDelegate = delegate + + let response = try await session.respond(to: "Hi") + + #expect(response.content.isEmpty) + #expect( + response.transcriptEntries.contains { entry in + if case .toolCalls = entry { return true } + return false + } + ) + #expect( + !response.transcriptEntries.contains { entry in + if case .toolOutput = entry { return true } + return false + } + ) + + let calls = await toolSpy.calls + #expect(calls.isEmpty) + + let generatedCalls = await delegate.generatedToolCalls + #expect(generatedCalls == [toolCall]) + } + + @Test func provideOutputBypassesExecution() async throws { + let arguments = try GeneratedContent(json: #"{"city":"Cupertino"}"#) + let toolCall = Transcript.ToolCall(id: "call-2", toolName: WeatherTool().name, arguments: arguments) + let delegate = ToolExecutionDelegateSpy { _ in + .provideOutput([.text(.init(content: "Stubbed"))]) + } + let toolSpy = spy(on: WeatherTool()) + let session = LanguageModelSession( + model: ToolCallingTestModel(toolCalls: [toolCall]), + tools: [toolSpy] + ) + session.toolExecutionDelegate = delegate + + let response = try await session.respond(to: "Hi") + + #expect(!response.transcriptEntries.isEmpty) + #expect( + response.transcriptEntries.contains { entry in + if case let .toolOutput(output) = entry { + return output.segments.contains { segment in + if case .text(let text) = segment { return text.content == "Stubbed" } + return false + } + } + return false + } + ) + + let calls = await toolSpy.calls + #expect(calls.isEmpty) + + let executedOutputs = await delegate.executedOutputs + #expect(executedOutputs.count == 1) + } + + @Test func executeRunsToolAndNotifiesDelegate() async throws { + let arguments = try GeneratedContent(json: #"{"city":"Cupertino"}"#) + let toolCall = Transcript.ToolCall(id: "call-3", toolName: WeatherTool().name, arguments: arguments) + let delegate = ToolExecutionDelegateSpy { _ in .execute } + let toolSpy = spy(on: WeatherTool()) + let session = LanguageModelSession( + model: ToolCallingTestModel(toolCalls: [toolCall]), + tools: [toolSpy] + ) + session.toolExecutionDelegate = delegate + + _ = try await session.respond(to: "Hi") + + let calls = await toolSpy.calls + #expect(calls.count == 1) + + let executedCalls = await delegate.executedToolCalls + #expect(executedCalls.count == 1) + } + + @Test func didFailToolCallNotifiesDelegateWhenToolThrows() async throws { + let arguments = try GeneratedContent(json: #"{"message":"fail"}"#) + let toolCall = Transcript.ToolCall(id: "call-fail", toolName: ThrowingTool().name, arguments: arguments) + let delegate = ToolExecutionDelegateSpy { _ in .execute } + let session = LanguageModelSession( + model: ToolCallingTestModel(toolCalls: [toolCall]), + tools: [ThrowingTool()] + ) + session.toolExecutionDelegate = delegate + + _ = try? await session.respond(to: "Hi") + + let failures = await delegate.failures + #expect(failures.count == 1) + #expect((failures.first as? ThrowingToolError) == .testError) + } +}