From 9cdf9cbd8a6ab37b78db044189ed4dca002eacff Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 5 Feb 2026 03:01:09 -0800 Subject: [PATCH 1/9] Add tool execution delegate property to language model session --- .../LanguageModelSession.swift | 49 ++++ .../ToolExecutionDelegateTests.swift | 268 ++++++++++++++++++ 2 files changed, 317 insertions(+) create mode 100644 Tests/AnyLanguageModelTests/ToolExecutionDelegateTests.swift diff --git a/Sources/AnyLanguageModel/LanguageModelSession.swift b/Sources/AnyLanguageModel/LanguageModelSession.swift index 34f4ef8..cf435bb 100644 --- a/Sources/AnyLanguageModel/LanguageModelSession.swift +++ b/Sources/AnyLanguageModel/LanguageModelSession.swift @@ -1,6 +1,53 @@ import Foundation import Observation +/// A decision about how a tool call should be handled. +public enum ToolExecutionDecision: Sendable { + case execute + case stop + case provideOutput([Transcript.Segment]) +} + +/// A delegate that observes and controls tool execution for a session. +public protocol ToolExecutionDelegate: Sendable { + func didGenerateToolCalls(_ toolCalls: [Transcript.ToolCall], in session: LanguageModelSession) async + func toolCallDecision(for toolCall: Transcript.ToolCall, in session: LanguageModelSession) async + -> ToolExecutionDecision + func didExecuteToolCall( + _ toolCall: Transcript.ToolCall, + output: Transcript.ToolOutput, + in session: LanguageModelSession + ) async + func didFailToolCall( + _ toolCall: Transcript.ToolCall, + error: any Error, + in session: LanguageModelSession + ) async +} + +extension ToolExecutionDelegate { + public func didGenerateToolCalls(_ toolCalls: [Transcript.ToolCall], in session: LanguageModelSession) async {} + + public func toolCallDecision( + for toolCall: Transcript.ToolCall, + in session: LanguageModelSession + ) async -> ToolExecutionDecision { + .execute + } + + public func didExecuteToolCall( + _ toolCall: Transcript.ToolCall, + output: Transcript.ToolOutput, + in session: LanguageModelSession + ) async {} + + public func didFailToolCall( + _ toolCall: Transcript.ToolCall, + error: any Error, + in session: LanguageModelSession + ) async {} +} + @Observable public final class LanguageModelSession: @unchecked Sendable { public private(set) var isResponding: Bool = false @@ -9,6 +56,8 @@ public final class LanguageModelSession: @unchecked Sendable { private let model: any LanguageModel public let tools: [any Tool] public let instructions: Instructions? + /// An optional delegate that observes and controls tool execution. + @ObservationIgnored public var toolExecutionDelegate: (any ToolExecutionDelegate)? @ObservationIgnored private let respondingState = RespondingState() diff --git a/Tests/AnyLanguageModelTests/ToolExecutionDelegateTests.swift b/Tests/AnyLanguageModelTests/ToolExecutionDelegateTests.swift new file mode 100644 index 0000000..86b2a4b --- /dev/null +++ b/Tests/AnyLanguageModelTests/ToolExecutionDelegateTests.swift @@ -0,0 +1,268 @@ +import Testing + +@testable import AnyLanguageModel + +private actor ToolExecutionDelegateSpy: ToolExecutionDelegate { + private(set) var generatedToolCalls: [Transcript.ToolCall] = [] + private(set) var executedToolCalls: [Transcript.ToolCall] = [] + private(set) var executedOutputs: [Transcript.ToolOutput] = [] + private(set) var failures: [any Error] = [] + + private let decisionProvider: @Sendable (Transcript.ToolCall) async -> ToolExecutionDecision + + init(decisionProvider: @escaping @Sendable (Transcript.ToolCall) async -> ToolExecutionDecision) { + self.decisionProvider = decisionProvider + } + + func didGenerateToolCalls(_ toolCalls: [Transcript.ToolCall], in session: LanguageModelSession) async { + generatedToolCalls = toolCalls + } + + func toolCallDecision( + for toolCall: Transcript.ToolCall, + in session: LanguageModelSession + ) async -> ToolExecutionDecision { + await decisionProvider(toolCall) + } + + func didExecuteToolCall( + _ toolCall: Transcript.ToolCall, + output: Transcript.ToolOutput, + in session: LanguageModelSession + ) async { + executedToolCalls.append(toolCall) + executedOutputs.append(output) + } + + func didFailToolCall( + _ toolCall: Transcript.ToolCall, + error: any Error, + in session: LanguageModelSession + ) async { + failures.append(error) + } +} + +private struct ToolCallingTestModel: LanguageModel { + typealias UnavailableReason = Never + + let toolCalls: [Transcript.ToolCall] + let responseText: String + + init(toolCalls: [Transcript.ToolCall], responseText: String = "done") { + self.toolCalls = toolCalls + self.responseText = responseText + } + + func respond( + within session: LanguageModelSession, + to prompt: Prompt, + generating type: Content.Type, + includeSchemaInPrompt: Bool, + options: GenerationOptions + ) async throws -> LanguageModelSession.Response where Content: Generable { + guard type == String.self else { + fatalError("ToolCallingTestModel only supports generating String content") + } + + var entries: [Transcript.Entry] = [] + + if !toolCalls.isEmpty { + if let delegate = session.toolExecutionDelegate { + await delegate.didGenerateToolCalls(toolCalls, in: session) + } + + var decisions: [ToolExecutionDecision] = [] + decisions.reserveCapacity(toolCalls.count) + + if let delegate = session.toolExecutionDelegate { + for call in toolCalls { + let decision = await delegate.toolCallDecision(for: call, in: session) + if case .stop = decision { + entries.append(.toolCalls(Transcript.ToolCalls(toolCalls))) + return LanguageModelSession.Response( + content: "" as! Content, + rawContent: GeneratedContent(""), + transcriptEntries: ArraySlice(entries) + ) + } + decisions.append(decision) + } + } else { + decisions = Array(repeating: .execute, count: toolCalls.count) + } + + entries.append(.toolCalls(Transcript.ToolCalls(toolCalls))) + + var toolsByName: [String: any Tool] = [:] + for tool in session.tools { + if toolsByName[tool.name] == nil { + toolsByName[tool.name] = tool + } + } + + for (index, call) in toolCalls.enumerated() { + switch decisions[index] { + case .stop: + entries = [.toolCalls(Transcript.ToolCalls(toolCalls))] + return LanguageModelSession.Response( + content: "" as! Content, + rawContent: GeneratedContent(""), + transcriptEntries: ArraySlice(entries) + ) + case .provideOutput(let segments): + let output = Transcript.ToolOutput( + id: call.id, + toolName: call.toolName, + segments: segments + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + entries.append(.toolOutput(output)) + case .execute: + guard let tool = toolsByName[call.toolName] else { + let message = Transcript.Segment.text(.init(content: "Tool not found: \(call.toolName)")) + let output = Transcript.ToolOutput( + id: call.id, + toolName: call.toolName, + segments: [message] + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + entries.append(.toolOutput(output)) + continue + } + + do { + let segments = try await tool.makeOutputSegments(from: call.arguments) + let output = Transcript.ToolOutput( + id: call.id, + toolName: tool.name, + segments: segments + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + entries.append(.toolOutput(output)) + } catch { + if let delegate = session.toolExecutionDelegate { + await delegate.didFailToolCall(call, error: error, in: session) + } + throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error) + } + } + } + } + + return LanguageModelSession.Response( + content: responseText as! Content, + rawContent: GeneratedContent(responseText), + transcriptEntries: ArraySlice(entries) + ) + } + + func streamResponse( + within session: LanguageModelSession, + to prompt: Prompt, + generating type: Content.Type, + includeSchemaInPrompt: Bool, + options: GenerationOptions + ) -> sending LanguageModelSession.ResponseStream where Content: Generable { + let rawContent = GeneratedContent(responseText) + return LanguageModelSession.ResponseStream(content: responseText as! Content, rawContent: rawContent) + } +} + +@Suite("ToolExecutionDelegate") +struct ToolExecutionDelegateTests { + @Test func stopAfterToolCalls() async throws { + let arguments = try GeneratedContent(json: #"{"city":"Cupertino"}"#) + let toolCall = Transcript.ToolCall(id: "call-1", toolName: WeatherTool().name, arguments: arguments) + let delegate = ToolExecutionDelegateSpy { _ in .stop } + let toolSpy = spy(on: WeatherTool()) + let session = LanguageModelSession( + model: ToolCallingTestModel(toolCalls: [toolCall]), + tools: [toolSpy] + ) + session.toolExecutionDelegate = delegate + + let response = try await session.respond(to: "Hi") + + #expect(response.content.isEmpty) + #expect( + response.transcriptEntries.contains { entry in + if case .toolCalls = entry { return true } + return false + } + ) + #expect( + !response.transcriptEntries.contains { entry in + if case .toolOutput = entry { return true } + return false + } + ) + + let calls = await toolSpy.calls + #expect(calls.isEmpty) + + let generatedCalls = await delegate.generatedToolCalls + #expect(generatedCalls == [toolCall]) + } + + @Test func provideOutputBypassesExecution() async throws { + let arguments = try GeneratedContent(json: #"{"city":"Cupertino"}"#) + let toolCall = Transcript.ToolCall(id: "call-2", toolName: WeatherTool().name, arguments: arguments) + let delegate = ToolExecutionDelegateSpy { _ in + .provideOutput([.text(.init(content: "Stubbed"))]) + } + let toolSpy = spy(on: WeatherTool()) + let session = LanguageModelSession( + model: ToolCallingTestModel(toolCalls: [toolCall]), + tools: [toolSpy] + ) + session.toolExecutionDelegate = delegate + + let response = try await session.respond(to: "Hi") + + #expect(!response.transcriptEntries.isEmpty) + #expect( + response.transcriptEntries.contains { entry in + if case let .toolOutput(output) = entry { + return output.segments.contains { segment in + if case .text(let text) = segment { return text.content == "Stubbed" } + return false + } + } + return false + } + ) + + let calls = await toolSpy.calls + #expect(calls.isEmpty) + + let executedOutputs = await delegate.executedOutputs + #expect(executedOutputs.count == 1) + } + + @Test func executeRunsToolAndNotifiesDelegate() async throws { + let arguments = try GeneratedContent(json: #"{"city":"Cupertino"}"#) + let toolCall = Transcript.ToolCall(id: "call-3", toolName: WeatherTool().name, arguments: arguments) + let delegate = ToolExecutionDelegateSpy { _ in .execute } + let toolSpy = spy(on: WeatherTool()) + let session = LanguageModelSession( + model: ToolCallingTestModel(toolCalls: [toolCall]), + tools: [toolSpy] + ) + session.toolExecutionDelegate = delegate + + _ = try await session.respond(to: "Hi") + + let calls = await toolSpy.calls + #expect(calls.count == 1) + + let executedCalls = await delegate.executedToolCalls + #expect(executedCalls.count == 1) + } +} From 6a385117db746103b34d3e1cd92a028e44a88ea7 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 5 Feb 2026 03:01:45 -0800 Subject: [PATCH 2/9] Update models to use tool calling delegate --- .../Models/AnthropicLanguageModel.swift | 128 ++++++++++--- .../Models/GeminiLanguageModel.swift | 134 +++++++++---- .../Models/MLXLanguageModel.swift | 142 ++++++++++---- .../Models/OllamaLanguageModel.swift | 128 ++++++++++--- .../Models/OpenAILanguageModel.swift | 180 +++++++++++++----- 5 files changed, 532 insertions(+), 180 deletions(-) diff --git a/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift b/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift index 509f5ef..118f6d1 100644 --- a/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift @@ -356,11 +356,23 @@ public struct AnthropicLanguageModel: LanguageModel { } if !toolUses.isEmpty { - let invocations = try await resolveToolUses(toolUses, session: session) - if !invocations.isEmpty { - entries.append(.toolCalls(Transcript.ToolCalls(invocations.map(\.call)))) - for invocation in invocations { - entries.append(.toolOutput(invocation.output)) + let resolution = try await resolveToolUses(toolUses, session: session) + switch resolution { + case .stop(let calls): + if !calls.isEmpty { + entries.append(.toolCalls(Transcript.ToolCalls(calls))) + } + return LanguageModelSession.Response( + content: "" as! Content, + rawContent: GeneratedContent(""), + transcriptEntries: ArraySlice(entries) + ) + case .invocations(let invocations): + if !invocations.isEmpty { + entries.append(.toolCalls(Transcript.ToolCalls(invocations.map(\.call)))) + for invocation in invocations { + entries.append(.toolOutput(invocation.output)) + } } } } @@ -560,11 +572,16 @@ private struct ToolInvocationResult { let output: Transcript.ToolOutput } +private enum ToolResolutionOutcome { + case stop(calls: [Transcript.ToolCall]) + case invocations([ToolInvocationResult]) +} + private func resolveToolUses( _ toolUses: [AnthropicToolUse], session: LanguageModelSession -) async throws -> [ToolInvocationResult] { - if toolUses.isEmpty { return [] } +) async throws -> ToolResolutionOutcome { + if toolUses.isEmpty { return .invocations([]) } var toolsByName: [String: any Tool] = [:] for tool in session.tools { @@ -573,43 +590,94 @@ private func resolveToolUses( } } - var results: [ToolInvocationResult] = [] - results.reserveCapacity(toolUses.count) - + var transcriptCalls: [Transcript.ToolCall] = [] + transcriptCalls.reserveCapacity(toolUses.count) for use in toolUses { let args = try toGeneratedContent(use.input) let callID = use.id - let transcriptCall = Transcript.ToolCall( - id: callID, - toolName: use.name, - arguments: args - ) - - guard let tool = toolsByName[use.name] else { - let message = Transcript.Segment.text(.init(content: "Tool not found: \(use.name)")) - let output = Transcript.ToolOutput( + transcriptCalls.append( + Transcript.ToolCall( id: callID, toolName: use.name, - segments: [message] + arguments: args ) - results.append(ToolInvocationResult(call: transcriptCall, output: output)) - continue + ) + } + + if let delegate = session.toolExecutionDelegate { + await delegate.didGenerateToolCalls(transcriptCalls, in: session) + } + + guard !transcriptCalls.isEmpty else { return .invocations([]) } + + var decisions: [ToolExecutionDecision] = [] + decisions.reserveCapacity(transcriptCalls.count) + + if let delegate = session.toolExecutionDelegate { + for call in transcriptCalls { + let decision = await delegate.toolCallDecision(for: call, in: session) + if case .stop = decision { + return .stop(calls: transcriptCalls) + } + decisions.append(decision) } + } else { + decisions = Array(repeating: .execute, count: transcriptCalls.count) + } - do { - let segments = try await tool.makeOutputSegments(from: args) + var results: [ToolInvocationResult] = [] + results.reserveCapacity(transcriptCalls.count) + + for (index, call) in transcriptCalls.enumerated() { + switch decisions[index] { + case .stop: + return .stop(calls: transcriptCalls) + case .provideOutput(let segments): let output = Transcript.ToolOutput( - id: tool.name, - toolName: tool.name, + id: call.toolName, + toolName: call.toolName, segments: segments ) - results.append(ToolInvocationResult(call: transcriptCall, output: output)) - } catch { - throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + case .execute: + guard let tool = toolsByName[call.toolName] else { + let message = Transcript.Segment.text(.init(content: "Tool not found: \(call.toolName)")) + let output = Transcript.ToolOutput( + id: call.id, + toolName: call.toolName, + segments: [message] + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + continue + } + + do { + let segments = try await tool.makeOutputSegments(from: call.arguments) + let output = Transcript.ToolOutput( + id: tool.name, + toolName: tool.name, + segments: segments + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + } catch { + if let delegate = session.toolExecutionDelegate { + await delegate.didFailToolCall(call, error: error, in: session) + } + throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error) + } } } - return results + return .invocations(results) } // Convert our GenerationSchema into Anthropic's expected JSON Schema payload diff --git a/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift b/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift index 886bbe2..673075a 100644 --- a/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift @@ -313,17 +313,29 @@ public struct GeminiLanguageModel: LanguageModel { if !functionCalls.isEmpty { // Resolve function calls - let invocations = try await resolveFunctionCalls(functionCalls, session: session) - if !invocations.isEmpty { - transcript.append(.toolCalls(Transcript.ToolCalls(invocations.map(\.call)))) + let resolution = try await resolveFunctionCalls(functionCalls, session: session) + switch resolution { + case .stop(let calls): + if !calls.isEmpty { + transcript.append(.toolCalls(Transcript.ToolCalls(calls))) + } + return LanguageModelSession.Response( + content: "" as! Content, + rawContent: GeneratedContent(""), + transcriptEntries: ArraySlice(transcript) + ) + case .invocations(let invocations): + if !invocations.isEmpty { + transcript.append(.toolCalls(Transcript.ToolCalls(invocations.map(\.call)))) - for invocation in invocations { - transcript.append(.toolOutput(invocation.output)) + for invocation in invocations { + transcript.append(.toolOutput(invocation.output)) + } } - } - // Continue the loop to send the next request with tool results - continue + // Continue the loop to send the next request with tool results + continue + } } else { // No function calls, extract final text and return let text = @@ -530,11 +542,16 @@ private struct ToolInvocationResult { let output: Transcript.ToolOutput } +private enum ToolResolutionOutcome { + case stop(calls: [Transcript.ToolCall]) + case invocations([ToolInvocationResult]) +} + private func resolveFunctionCalls( _ functionCalls: [GeminiFunctionCall], session: LanguageModelSession -) async throws -> [ToolInvocationResult] { - if functionCalls.isEmpty { return [] } +) async throws -> ToolResolutionOutcome { + if functionCalls.isEmpty { return .invocations([]) } var toolsByName: [String: any Tool] = [:] for tool in session.tools { @@ -543,43 +560,94 @@ private func resolveFunctionCalls( } } - var results: [ToolInvocationResult] = [] - results.reserveCapacity(functionCalls.count) - + var transcriptCalls: [Transcript.ToolCall] = [] + transcriptCalls.reserveCapacity(functionCalls.count) for call in functionCalls { let args = try toGeneratedContent(call.args) let callID = UUID().uuidString - let transcriptCall = Transcript.ToolCall( - id: callID, - toolName: call.name, - arguments: args - ) - - guard let tool = toolsByName[call.name] else { - let message = Transcript.Segment.text(.init(content: "Tool not found: \(call.name)")) - let output = Transcript.ToolOutput( + transcriptCalls.append( + Transcript.ToolCall( id: callID, toolName: call.name, - segments: [message] + arguments: args ) - results.append(ToolInvocationResult(call: transcriptCall, output: output)) - continue + ) + } + + if let delegate = session.toolExecutionDelegate { + await delegate.didGenerateToolCalls(transcriptCalls, in: session) + } + + guard !transcriptCalls.isEmpty else { return .invocations([]) } + + var decisions: [ToolExecutionDecision] = [] + decisions.reserveCapacity(transcriptCalls.count) + + if let delegate = session.toolExecutionDelegate { + for call in transcriptCalls { + let decision = await delegate.toolCallDecision(for: call, in: session) + if case .stop = decision { + return .stop(calls: transcriptCalls) + } + decisions.append(decision) } + } else { + decisions = Array(repeating: .execute, count: transcriptCalls.count) + } - do { - let segments = try await tool.makeOutputSegments(from: args) + var results: [ToolInvocationResult] = [] + results.reserveCapacity(transcriptCalls.count) + + for (index, call) in transcriptCalls.enumerated() { + switch decisions[index] { + case .stop: + return .stop(calls: transcriptCalls) + case .provideOutput(let segments): let output = Transcript.ToolOutput( - id: tool.name, - toolName: tool.name, + id: call.toolName, + toolName: call.toolName, segments: segments ) - results.append(ToolInvocationResult(call: transcriptCall, output: output)) - } catch { - throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + case .execute: + guard let tool = toolsByName[call.toolName] else { + let message = Transcript.Segment.text(.init(content: "Tool not found: \(call.toolName)")) + let output = Transcript.ToolOutput( + id: call.id, + toolName: call.toolName, + segments: [message] + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + continue + } + + do { + let segments = try await tool.makeOutputSegments(from: call.arguments) + let output = Transcript.ToolOutput( + id: tool.name, + toolName: tool.name, + segments: segments + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + } catch { + if let delegate = session.toolExecutionDelegate { + await delegate.didFailToolCall(call, error: error, in: session) + } + throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error) + } } } - return results + return .invocations(results) } private func convertToolToGeminiFormat(_ tool: any Tool) throws -> GeminiFunctionDeclaration { diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index a7ea74f..9f271bf 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -276,21 +276,33 @@ import Foundation // If there are tool calls, execute them and continue if !collectedToolCalls.isEmpty { - let invocations = try await resolveToolCalls(collectedToolCalls, session: session) - if !invocations.isEmpty { - allEntries.append(.toolCalls(Transcript.ToolCalls(invocations.map(\.call)))) + let resolution = try await resolveToolCalls(collectedToolCalls, session: session) + switch resolution { + case .stop(let calls): + if !calls.isEmpty { + allEntries.append(.toolCalls(Transcript.ToolCalls(calls))) + } + return LanguageModelSession.Response( + content: "" as! Content, + rawContent: GeneratedContent(""), + transcriptEntries: ArraySlice(allEntries) + ) + case .invocations(let invocations): + if !invocations.isEmpty { + allEntries.append(.toolCalls(Transcript.ToolCalls(invocations.map(\.call)))) - // Execute each tool and add results to chat - for invocation in invocations { - allEntries.append(.toolOutput(invocation.output)) + // Execute each tool and add results to chat + for invocation in invocations { + allEntries.append(.toolOutput(invocation.output)) - // Convert tool output to JSON string for MLX - let toolResultJSON = toolOutputToJSON(invocation.output) - chat.append(.tool(toolResultJSON)) - } + // Convert tool output to JSON string for MLX + let toolResultJSON = toolOutputToJSON(invocation.output) + chat.append(.tool(toolResultJSON)) + } - // Continue loop to generate with tool results - continue + // Continue loop to generate with tool results + continue + } } } @@ -604,11 +616,16 @@ import Foundation let output: Transcript.ToolOutput } + private enum ToolResolutionOutcome { + case stop(calls: [Transcript.ToolCall]) + case invocations([ToolInvocationResult]) + } + private func resolveToolCalls( _ toolCalls: [MLXLMCommon.ToolCall], session: LanguageModelSession - ) async throws -> [ToolInvocationResult] { - if toolCalls.isEmpty { return [] } + ) async throws -> ToolResolutionOutcome { + if toolCalls.isEmpty { return .invocations([]) } var toolsByName: [String: any Tool] = [:] for tool in session.tools { @@ -617,43 +634,94 @@ import Foundation } } - var results: [ToolInvocationResult] = [] - results.reserveCapacity(toolCalls.count) - + var transcriptCalls: [Transcript.ToolCall] = [] + transcriptCalls.reserveCapacity(toolCalls.count) for call in toolCalls { let args = try toGeneratedContent(call.function.arguments) let callID = UUID().uuidString - let transcriptCall = Transcript.ToolCall( - id: callID, - toolName: call.function.name, - arguments: args - ) - - guard let tool = toolsByName[call.function.name] else { - let message = Transcript.Segment.text(.init(content: "Tool not found: \(call.function.name)")) - let output = Transcript.ToolOutput( + transcriptCalls.append( + Transcript.ToolCall( id: callID, toolName: call.function.name, - segments: [message] + arguments: args ) - results.append(ToolInvocationResult(call: transcriptCall, output: output)) - continue + ) + } + + if let delegate = session.toolExecutionDelegate { + await delegate.didGenerateToolCalls(transcriptCalls, in: session) + } + + guard !transcriptCalls.isEmpty else { return .invocations([]) } + + var decisions: [ToolExecutionDecision] = [] + decisions.reserveCapacity(transcriptCalls.count) + + if let delegate = session.toolExecutionDelegate { + for call in transcriptCalls { + let decision = await delegate.toolCallDecision(for: call, in: session) + if case .stop = decision { + return .stop(calls: transcriptCalls) + } + decisions.append(decision) } + } else { + decisions = Array(repeating: .execute, count: transcriptCalls.count) + } - do { - let segments = try await tool.makeOutputSegments(from: args) + var results: [ToolInvocationResult] = [] + results.reserveCapacity(transcriptCalls.count) + + for (index, call) in transcriptCalls.enumerated() { + switch decisions[index] { + case .stop: + return .stop(calls: transcriptCalls) + case .provideOutput(let segments): let output = Transcript.ToolOutput( - id: tool.name, - toolName: tool.name, + id: call.toolName, + toolName: call.toolName, segments: segments ) - results.append(ToolInvocationResult(call: transcriptCall, output: output)) - } catch { - throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + case .execute: + guard let tool = toolsByName[call.toolName] else { + let message = Transcript.Segment.text(.init(content: "Tool not found: \(call.toolName)")) + let output = Transcript.ToolOutput( + id: call.id, + toolName: call.toolName, + segments: [message] + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + continue + } + + do { + let segments = try await tool.makeOutputSegments(from: call.arguments) + let output = Transcript.ToolOutput( + id: tool.name, + toolName: tool.name, + segments: segments + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + } catch { + if let delegate = session.toolExecutionDelegate { + await delegate.didFailToolCall(call, error: error, in: session) + } + throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error) + } } } - return results + return .invocations(results) } private func toGeneratedContent(_ args: [String: MLXLMCommon.JSONValue]) throws -> GeneratedContent { diff --git a/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift b/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift index 12ec53f..4a10134 100644 --- a/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift @@ -112,11 +112,23 @@ public struct OllamaLanguageModel: LanguageModel { var entries: [Transcript.Entry] = [] if let toolCalls = chatResponse.message.toolCalls, !toolCalls.isEmpty { - let invocations = try await resolveToolCalls(toolCalls, session: session) - if !invocations.isEmpty { - entries.append(.toolCalls(Transcript.ToolCalls(invocations.map(\.call)))) - for invocation in invocations { - entries.append(.toolOutput(invocation.output)) + let resolution = try await resolveToolCalls(toolCalls, session: session) + switch resolution { + case .stop(let calls): + if !calls.isEmpty { + entries.append(.toolCalls(Transcript.ToolCalls(calls))) + } + return LanguageModelSession.Response( + content: "" as! Content, + rawContent: GeneratedContent(""), + transcriptEntries: ArraySlice(entries) + ) + case .invocations(let invocations): + if !invocations.isEmpty { + entries.append(.toolCalls(Transcript.ToolCalls(invocations.map(\.call)))) + for invocation in invocations { + entries.append(.toolOutput(invocation.output)) + } } } } @@ -216,12 +228,17 @@ private struct ToolInvocationResult { let output: Transcript.ToolOutput } +private enum ToolResolutionOutcome { + case stop(calls: [Transcript.ToolCall]) + case invocations([ToolInvocationResult]) +} + private func resolveToolCalls( _ toolCalls: [OllamaToolCall], session: LanguageModelSession -) async throws -> [ToolInvocationResult] { +) async throws -> ToolResolutionOutcome { if toolCalls.isEmpty { - return [] + return .invocations([]) } var toolsByName: [String: any Tool] = [:] @@ -231,43 +248,94 @@ private func resolveToolCalls( } } - var results: [ToolInvocationResult] = [] - results.reserveCapacity(toolCalls.count) - + var transcriptCalls: [Transcript.ToolCall] = [] + transcriptCalls.reserveCapacity(toolCalls.count) for call in toolCalls { let args = try toGeneratedContent(call.function.arguments) let callID = call.id ?? UUID().uuidString - let transcriptCall = Transcript.ToolCall( - id: callID, - toolName: call.function.name, - arguments: args - ) - - guard let tool = toolsByName[call.function.name] else { - let message = Transcript.Segment.text(.init(content: "Tool not found: \(call.function.name)")) - let output = Transcript.ToolOutput( + transcriptCalls.append( + Transcript.ToolCall( id: callID, toolName: call.function.name, - segments: [message] + arguments: args ) - results.append(ToolInvocationResult(call: transcriptCall, output: output)) - continue + ) + } + + if let delegate = session.toolExecutionDelegate { + await delegate.didGenerateToolCalls(transcriptCalls, in: session) + } + + guard !transcriptCalls.isEmpty else { return .invocations([]) } + + var decisions: [ToolExecutionDecision] = [] + decisions.reserveCapacity(transcriptCalls.count) + + if let delegate = session.toolExecutionDelegate { + for call in transcriptCalls { + let decision = await delegate.toolCallDecision(for: call, in: session) + if case .stop = decision { + return .stop(calls: transcriptCalls) + } + decisions.append(decision) } + } else { + decisions = Array(repeating: .execute, count: transcriptCalls.count) + } - do { - let segments = try await tool.makeOutputSegments(from: args) + var results: [ToolInvocationResult] = [] + results.reserveCapacity(transcriptCalls.count) + + for (index, call) in transcriptCalls.enumerated() { + switch decisions[index] { + case .stop: + return .stop(calls: transcriptCalls) + case .provideOutput(let segments): let output = Transcript.ToolOutput( - id: tool.name, - toolName: tool.name, + id: call.toolName, + toolName: call.toolName, segments: segments ) - results.append(ToolInvocationResult(call: transcriptCall, output: output)) - } catch { - throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + case .execute: + guard let tool = toolsByName[call.toolName] else { + let message = Transcript.Segment.text(.init(content: "Tool not found: \(call.toolName)")) + let output = Transcript.ToolOutput( + id: call.id, + toolName: call.toolName, + segments: [message] + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + continue + } + + do { + let segments = try await tool.makeOutputSegments(from: call.arguments) + let output = Transcript.ToolOutput( + id: tool.name, + toolName: tool.name, + segments: segments + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(ToolInvocationResult(call: call, output: output)) + } catch { + if let delegate = session.toolExecutionDelegate { + await delegate.didFailToolCall(call, error: error, in: session) + } + throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error) + } } } - return results + return .invocations(results) } // MARK: - Conversions diff --git a/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift b/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift index 6299ce0..dfdf18d 100644 --- a/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift @@ -508,20 +508,32 @@ public struct OpenAILanguageModel: LanguageModel { if let value = try? JSONValue(toolCallMessage) { messages.append(OpenAIMessage(role: .raw(rawContent: value), content: .text(""))) } - let invocations = try await resolveToolCalls(toolCalls, session: session) - if !invocations.isEmpty { - entries.append(.toolCalls(Transcript.ToolCalls(invocations.map { $0.call }))) - for invocation in invocations { - let output = invocation.output - entries.append(.toolOutput(output)) - messages.append( - OpenAIMessage( - role: .tool(id: invocation.call.id), - content: .text(convertSegmentsToToolContentString(output.segments)) + let resolution = try await resolveToolCalls(toolCalls, session: session) + switch resolution { + case .stop(let calls): + if !calls.isEmpty { + entries.append(.toolCalls(Transcript.ToolCalls(calls))) + } + return LanguageModelSession.Response( + content: "" as! Content, + rawContent: GeneratedContent(""), + transcriptEntries: ArraySlice(entries) + ) + case .invocations(let invocations): + if !invocations.isEmpty { + entries.append(.toolCalls(Transcript.ToolCalls(invocations.map { $0.call }))) + for invocation in invocations { + let output = invocation.output + entries.append(.toolOutput(output)) + messages.append( + OpenAIMessage( + role: .tool(id: invocation.call.id), + content: .text(convertSegmentsToToolContentString(output.segments)) + ) ) - ) + } + continue } - continue } } @@ -575,21 +587,33 @@ public struct OpenAILanguageModel: LanguageModel { messages.append(OpenAIMessage(role: .raw(rawContent: msg), content: .text(""))) } } - let invocations = try await resolveToolCalls(toolCalls, session: session) - if !invocations.isEmpty { - entries.append(.toolCalls(Transcript.ToolCalls(invocations.map { $0.call }))) - - for invocation in invocations { - let output = invocation.output - entries.append(.toolOutput(output)) - messages.append( - OpenAIMessage( - role: .tool(id: invocation.call.id), - content: .text(convertSegmentsToToolContentString(output.segments)) + let resolution = try await resolveToolCalls(toolCalls, session: session) + switch resolution { + case .stop(let calls): + if !calls.isEmpty { + entries.append(.toolCalls(Transcript.ToolCalls(calls))) + } + return LanguageModelSession.Response( + content: "" as! Content, + rawContent: GeneratedContent(""), + transcriptEntries: ArraySlice(entries) + ) + case .invocations(let invocations): + if !invocations.isEmpty { + entries.append(.toolCalls(Transcript.ToolCalls(invocations.map { $0.call }))) + + for invocation in invocations { + let output = invocation.output + entries.append(.toolOutput(output)) + messages.append( + OpenAIMessage( + role: .tool(id: invocation.call.id), + content: .text(convertSegmentsToToolContentString(output.segments)) + ) ) - ) + } + continue } - continue } } @@ -1478,11 +1502,16 @@ private struct OpenAIToolInvocationResult { let output: Transcript.ToolOutput } +private enum OpenAIToolResolutionOutcome { + case stop(calls: [Transcript.ToolCall]) + case invocations([OpenAIToolInvocationResult]) +} + private func resolveToolCalls( _ toolCalls: [OpenAIToolCall], session: LanguageModelSession -) async throws -> [OpenAIToolInvocationResult] { - if toolCalls.isEmpty { return [] } +) async throws -> OpenAIToolResolutionOutcome { + if toolCalls.isEmpty { return .invocations([]) } var toolsByName: [String: any Tool] = [:] for tool in session.tools { @@ -1491,44 +1520,95 @@ private func resolveToolCalls( } } - var results: [OpenAIToolInvocationResult] = [] - results.reserveCapacity(toolCalls.count) - + var transcriptCalls: [Transcript.ToolCall] = [] + transcriptCalls.reserveCapacity(toolCalls.count) for call in toolCalls { guard let function = call.function else { continue } let args = try toGeneratedContent(function.arguments) let callID = call.id ?? UUID().uuidString - let transcriptCall = Transcript.ToolCall( - id: callID, - toolName: function.name, - arguments: args - ) - - guard let tool = toolsByName[function.name] else { - let message = Transcript.Segment.text(.init(content: "Tool not found: \(function.name)")) - let output = Transcript.ToolOutput( + transcriptCalls.append( + Transcript.ToolCall( id: callID, toolName: function.name, - segments: [message] + arguments: args ) - results.append(OpenAIToolInvocationResult(call: transcriptCall, output: output)) - continue + ) + } + + if let delegate = session.toolExecutionDelegate { + await delegate.didGenerateToolCalls(transcriptCalls, in: session) + } + + guard !transcriptCalls.isEmpty else { return .invocations([]) } + + var decisions: [ToolExecutionDecision] = [] + decisions.reserveCapacity(transcriptCalls.count) + + if let delegate = session.toolExecutionDelegate { + for call in transcriptCalls { + let decision = await delegate.toolCallDecision(for: call, in: session) + if case .stop = decision { + return .stop(calls: transcriptCalls) + } + decisions.append(decision) } + } else { + decisions = Array(repeating: .execute, count: transcriptCalls.count) + } + + var results: [OpenAIToolInvocationResult] = [] + results.reserveCapacity(transcriptCalls.count) - do { - let segments = try await tool.makeOutputSegments(from: args) + for (index, call) in transcriptCalls.enumerated() { + switch decisions[index] { + case .stop: + return .stop(calls: transcriptCalls) + case .provideOutput(let segments): let output = Transcript.ToolOutput( - id: callID, - toolName: tool.name, + id: call.id, + toolName: call.toolName, segments: segments ) - results.append(OpenAIToolInvocationResult(call: transcriptCall, output: output)) - } catch { - throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(OpenAIToolInvocationResult(call: call, output: output)) + case .execute: + guard let tool = toolsByName[call.toolName] else { + let message = Transcript.Segment.text(.init(content: "Tool not found: \(call.toolName)")) + let output = Transcript.ToolOutput( + id: call.id, + toolName: call.toolName, + segments: [message] + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(OpenAIToolInvocationResult(call: call, output: output)) + continue + } + + do { + let segments = try await tool.makeOutputSegments(from: call.arguments) + let output = Transcript.ToolOutput( + id: call.id, + toolName: tool.name, + segments: segments + ) + if let delegate = session.toolExecutionDelegate { + await delegate.didExecuteToolCall(call, output: output, in: session) + } + results.append(OpenAIToolInvocationResult(call: call, output: output)) + } catch { + if let delegate = session.toolExecutionDelegate { + await delegate.didFailToolCall(call, error: error, in: session) + } + throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error) + } } } - return results + return .invocations(results) } // MARK: - Converters From 18ff52fcb19977c117c36e0eebb182a5ea5e0935 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 5 Feb 2026 03:01:59 -0800 Subject: [PATCH 3/9] Document use of tool calling delegate in README --- README.md | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/README.md b/README.md index 99a0569..b8429ff 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,36 @@ let response = try await session.respond { print(response.content) ``` +To observe or control tool execution, assign a delegate on the session: + +```swift +actor ToolExecutionObserver: ToolExecutionDelegate { + func didGenerateToolCalls(_ toolCalls: [Transcript.ToolCall], in session: LanguageModelSession) async { + print("Generated tool calls: \(toolCalls)") + } + + func toolCallDecision( + for toolCall: Transcript.ToolCall, + in session: LanguageModelSession + ) async -> ToolExecutionDecision { + // Return .stop to halt after tool calls, or .provideOutput(...) to bypass execution. + // This is a good place to ask the user for confirmation (for example, in a modal dialog). + .execute + } + + func didExecuteToolCall( + _ toolCall: Transcript.ToolCall, + output: Transcript.ToolOutput, + in session: LanguageModelSession + ) async { + print("Executed tool call: \(toolCall)") + } +} + +let session = LanguageModelSession(model: model, tools: [WeatherTool()]) +session.toolExecutionDelegate = ToolExecutionObserver() +``` + ## Features ### Supported Providers From 9f89af62aee7ad348024802de68f94586c7f0302 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 5 Feb 2026 03:04:57 -0800 Subject: [PATCH 4/9] Add documentation comments to new APIs --- .../LanguageModelSession.swift | 37 ++++++++++++++++++- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/Sources/AnyLanguageModel/LanguageModelSession.swift b/Sources/AnyLanguageModel/LanguageModelSession.swift index cf435bb..cc7bec9 100644 --- a/Sources/AnyLanguageModel/LanguageModelSession.swift +++ b/Sources/AnyLanguageModel/LanguageModelSession.swift @@ -1,23 +1,49 @@ import Foundation import Observation -/// A decision about how a tool call should be handled. +/// A decision about how to handle a tool call. public enum ToolExecutionDecision: Sendable { + /// Execute the tool call using the associated tool. case execute + /// Stop the session after tool calls are generated without executing them. case stop + /// Provide tool output without executing the tool. + /// + /// Use this to supply results from an external system or cached responses. case provideOutput([Transcript.Segment]) } /// A delegate that observes and controls tool execution for a session. public protocol ToolExecutionDelegate: Sendable { + /// Notifies the delegate when the model generates tool calls. + /// - Parameters: + /// - toolCalls: The tool calls produced by the model. + /// - session: The session that generated the tool calls. func didGenerateToolCalls(_ toolCalls: [Transcript.ToolCall], in session: LanguageModelSession) async + /// Asks the delegate how to handle a tool call. + /// + /// Return `.execute` to run the tool, `.stop` to halt after tool calls are generated, + /// or `.provideOutput` to supply output without executing the tool. + /// - Parameters: + /// - toolCall: The tool call to evaluate. + /// - session: The session requesting the decision. func toolCallDecision(for toolCall: Transcript.ToolCall, in session: LanguageModelSession) async -> ToolExecutionDecision + /// Notifies the delegate after a tool call produces output. + /// - Parameters: + /// - toolCall: The tool call that was handled. + /// - output: The output sent back to the model. + /// - session: The session that executed the tool call. func didExecuteToolCall( _ toolCall: Transcript.ToolCall, output: Transcript.ToolOutput, in session: LanguageModelSession ) async + /// Notifies the delegate when a tool call fails. + /// - Parameters: + /// - toolCall: The tool call that failed. + /// - error: The underlying error raised during execution. + /// - session: The session that attempted the tool call. func didFailToolCall( _ toolCall: Transcript.ToolCall, error: any Error, @@ -26,8 +52,10 @@ public protocol ToolExecutionDelegate: Sendable { } extension ToolExecutionDelegate { + /// Provides a default no-op implementation. public func didGenerateToolCalls(_ toolCalls: [Transcript.ToolCall], in session: LanguageModelSession) async {} + /// Provides a default decision that executes the tool call. public func toolCallDecision( for toolCall: Transcript.ToolCall, in session: LanguageModelSession @@ -35,12 +63,14 @@ extension ToolExecutionDelegate { .execute } + /// Provides a default no-op implementation. public func didExecuteToolCall( _ toolCall: Transcript.ToolCall, output: Transcript.ToolOutput, in session: LanguageModelSession ) async {} + /// Provides a default no-op implementation. public func didFailToolCall( _ toolCall: Transcript.ToolCall, error: any Error, @@ -56,7 +86,10 @@ public final class LanguageModelSession: @unchecked Sendable { private let model: any LanguageModel public let tools: [any Tool] public let instructions: Instructions? - /// An optional delegate that observes and controls tool execution. + /// A delegate that observes and controls tool execution. + /// + /// Set this property to intercept tool calls, provide custom output, + /// or stop after tool calls are generated. @ObservationIgnored public var toolExecutionDelegate: (any ToolExecutionDelegate)? @ObservationIgnored private let respondingState = RespondingState() From a3c3b3131337b9ea2a4f384f5c1bf8e03e33a964 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 5 Feb 2026 03:41:09 -0800 Subject: [PATCH 5/9] Fix tool name / id distinction in sources and tests --- Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift | 4 ++-- Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift | 4 ++-- Sources/AnyLanguageModel/Models/MLXLanguageModel.swift | 4 ++-- Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift | 4 ++-- Tests/AnyLanguageModelTests/AnthropicLanguageModelTests.swift | 3 ++- Tests/AnyLanguageModelTests/GeminiLanguageModelTests.swift | 3 ++- Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift | 3 ++- Tests/AnyLanguageModelTests/OllamaLanguageModelTests.swift | 3 ++- 8 files changed, 16 insertions(+), 12 deletions(-) diff --git a/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift b/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift index 118f6d1..1d6e160 100644 --- a/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift @@ -634,7 +634,7 @@ private func resolveToolUses( return .stop(calls: transcriptCalls) case .provideOutput(let segments): let output = Transcript.ToolOutput( - id: call.toolName, + id: call.id, toolName: call.toolName, segments: segments ) @@ -660,7 +660,7 @@ private func resolveToolUses( do { let segments = try await tool.makeOutputSegments(from: call.arguments) let output = Transcript.ToolOutput( - id: tool.name, + id: call.id, toolName: tool.name, segments: segments ) diff --git a/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift b/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift index 673075a..ce12be7 100644 --- a/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift @@ -604,7 +604,7 @@ private func resolveFunctionCalls( return .stop(calls: transcriptCalls) case .provideOutput(let segments): let output = Transcript.ToolOutput( - id: call.toolName, + id: call.id, toolName: call.toolName, segments: segments ) @@ -630,7 +630,7 @@ private func resolveFunctionCalls( do { let segments = try await tool.makeOutputSegments(from: call.arguments) let output = Transcript.ToolOutput( - id: tool.name, + id: call.id, toolName: tool.name, segments: segments ) diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 9f271bf..cc85830 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -678,7 +678,7 @@ import Foundation return .stop(calls: transcriptCalls) case .provideOutput(let segments): let output = Transcript.ToolOutput( - id: call.toolName, + id: call.id, toolName: call.toolName, segments: segments ) @@ -704,7 +704,7 @@ import Foundation do { let segments = try await tool.makeOutputSegments(from: call.arguments) let output = Transcript.ToolOutput( - id: tool.name, + id: call.id, toolName: tool.name, segments: segments ) diff --git a/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift b/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift index 4a10134..e8e5117 100644 --- a/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift @@ -292,7 +292,7 @@ private func resolveToolCalls( return .stop(calls: transcriptCalls) case .provideOutput(let segments): let output = Transcript.ToolOutput( - id: call.toolName, + id: call.id, toolName: call.toolName, segments: segments ) @@ -318,7 +318,7 @@ private func resolveToolCalls( do { let segments = try await tool.makeOutputSegments(from: call.arguments) let output = Transcript.ToolOutput( - id: tool.name, + id: call.id, toolName: tool.name, segments: segments ) diff --git a/Tests/AnyLanguageModelTests/AnthropicLanguageModelTests.swift b/Tests/AnyLanguageModelTests/AnthropicLanguageModelTests.swift index 66f263f..ece1ec8 100644 --- a/Tests/AnyLanguageModelTests/AnthropicLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/AnthropicLanguageModelTests.swift @@ -94,7 +94,8 @@ struct AnthropicLanguageModelTests { var foundToolOutput = false for case let .toolOutput(toolOutput) in response.transcriptEntries { - #expect(toolOutput.id == "getWeather") + #expect(!toolOutput.id.isEmpty) + #expect(toolOutput.toolName == "getWeather") foundToolOutput = true } #expect(foundToolOutput) diff --git a/Tests/AnyLanguageModelTests/GeminiLanguageModelTests.swift b/Tests/AnyLanguageModelTests/GeminiLanguageModelTests.swift index 19a22bb..4caa0f7 100644 --- a/Tests/AnyLanguageModelTests/GeminiLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/GeminiLanguageModelTests.swift @@ -97,7 +97,8 @@ struct GeminiLanguageModelTests { var foundToolOutput = false for case let .toolOutput(toolOutput) in response.transcriptEntries { - #expect(toolOutput.id == "getWeather") + #expect(!toolOutput.id.isEmpty) + #expect(toolOutput.toolName == "getWeather") foundToolOutput = true } #expect(foundToolOutput) diff --git a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift index cf640e8..62827fd 100644 --- a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift @@ -82,7 +82,8 @@ import Testing var foundToolOutput = false for case let .toolOutput(toolOutput) in response.transcriptEntries { - #expect(toolOutput.id == weatherTool.name) + #expect(!toolOutput.id.isEmpty) + #expect(toolOutput.toolName == weatherTool.name) foundToolOutput = true } #expect(foundToolOutput) diff --git a/Tests/AnyLanguageModelTests/OllamaLanguageModelTests.swift b/Tests/AnyLanguageModelTests/OllamaLanguageModelTests.swift index 50c4c43..234d17e 100644 --- a/Tests/AnyLanguageModelTests/OllamaLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/OllamaLanguageModelTests.swift @@ -93,7 +93,8 @@ struct OllamaLanguageModelTests { var foundToolOutput = false for case let .toolOutput(toolOutput) in response.transcriptEntries { - #expect(toolOutput.id == weatherTool.name) + #expect(!toolOutput.id.isEmpty) + #expect(toolOutput.toolName == weatherTool.name) foundToolOutput = true } #expect(foundToolOutput) From f956824a8dede415bae4c46b8b8bc9d69010e9d0 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 5 Feb 2026 03:57:25 -0800 Subject: [PATCH 6/9] Add comments to explain defensive .stop handling --- Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift | 4 ++++ Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift | 2 ++ Sources/AnyLanguageModel/Models/MLXLanguageModel.swift | 2 ++ Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift | 2 ++ Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift | 2 ++ Tests/AnyLanguageModelTests/ToolExecutionDelegateTests.swift | 2 ++ 6 files changed, 14 insertions(+) diff --git a/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift b/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift index 1d6e160..f88b68a 100644 --- a/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift @@ -631,6 +631,10 @@ private func resolveToolUses( for (index, call) in transcriptCalls.enumerated() { switch decisions[index] { case .stop: + // This branch should be unreachable, + // because `.stop` returns during decision collection. + // Keep it as a defensive guard, + // in case that logic changes. return .stop(calls: transcriptCalls) case .provideOutput(let segments): let output = Transcript.ToolOutput( diff --git a/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift b/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift index ce12be7..da91c80 100644 --- a/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift @@ -601,6 +601,8 @@ private func resolveFunctionCalls( for (index, call) in transcriptCalls.enumerated() { switch decisions[index] { case .stop: + // This branch should be unreachable because `.stop` returns during decision collection. + // Keep it as a defensive guard in case that logic changes. return .stop(calls: transcriptCalls) case .provideOutput(let segments): let output = Transcript.ToolOutput( diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index cc85830..7b781b2 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -675,6 +675,8 @@ import Foundation for (index, call) in transcriptCalls.enumerated() { switch decisions[index] { case .stop: + // This branch should be unreachable because `.stop` returns during decision collection. + // Keep it as a defensive guard in case that logic changes. return .stop(calls: transcriptCalls) case .provideOutput(let segments): let output = Transcript.ToolOutput( diff --git a/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift b/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift index e8e5117..4f49d13 100644 --- a/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift @@ -289,6 +289,8 @@ private func resolveToolCalls( for (index, call) in transcriptCalls.enumerated() { switch decisions[index] { case .stop: + // This branch should be unreachable because `.stop` returns during decision collection. + // Keep it as a defensive guard in case that logic changes. return .stop(calls: transcriptCalls) case .provideOutput(let segments): let output = Transcript.ToolOutput( diff --git a/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift b/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift index dfdf18d..da28255 100644 --- a/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift @@ -1562,6 +1562,8 @@ private func resolveToolCalls( for (index, call) in transcriptCalls.enumerated() { switch decisions[index] { case .stop: + // This branch should be unreachable because `.stop` returns during decision collection. + // Keep it as a defensive guard in case that logic changes. return .stop(calls: transcriptCalls) case .provideOutput(let segments): let output = Transcript.ToolOutput( diff --git a/Tests/AnyLanguageModelTests/ToolExecutionDelegateTests.swift b/Tests/AnyLanguageModelTests/ToolExecutionDelegateTests.swift index 86b2a4b..58cceda 100644 --- a/Tests/AnyLanguageModelTests/ToolExecutionDelegateTests.swift +++ b/Tests/AnyLanguageModelTests/ToolExecutionDelegateTests.swift @@ -104,6 +104,8 @@ private struct ToolCallingTestModel: LanguageModel { for (index, call) in toolCalls.enumerated() { switch decisions[index] { case .stop: + // This branch should be unreachable because `.stop` returns during decision collection. + // Keep it as a defensive guard in case that logic changes. entries = [.toolCalls(Transcript.ToolCalls(toolCalls))] return LanguageModelSession.Response( content: "" as! Content, From b71e6e3b890b76cc7e67820655d19ed72c08f339 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 5 Feb 2026 04:08:23 -0800 Subject: [PATCH 7/9] Extract tool execution code into its own file --- .../LanguageModelSession.swift | 77 ----------------- Sources/AnyLanguageModel/ToolExecution.swift | 86 +++++++++++++++++++ 2 files changed, 86 insertions(+), 77 deletions(-) create mode 100644 Sources/AnyLanguageModel/ToolExecution.swift diff --git a/Sources/AnyLanguageModel/LanguageModelSession.swift b/Sources/AnyLanguageModel/LanguageModelSession.swift index cc7bec9..7f136dc 100644 --- a/Sources/AnyLanguageModel/LanguageModelSession.swift +++ b/Sources/AnyLanguageModel/LanguageModelSession.swift @@ -1,83 +1,6 @@ import Foundation import Observation -/// A decision about how to handle a tool call. -public enum ToolExecutionDecision: Sendable { - /// Execute the tool call using the associated tool. - case execute - /// Stop the session after tool calls are generated without executing them. - case stop - /// Provide tool output without executing the tool. - /// - /// Use this to supply results from an external system or cached responses. - case provideOutput([Transcript.Segment]) -} - -/// A delegate that observes and controls tool execution for a session. -public protocol ToolExecutionDelegate: Sendable { - /// Notifies the delegate when the model generates tool calls. - /// - Parameters: - /// - toolCalls: The tool calls produced by the model. - /// - session: The session that generated the tool calls. - func didGenerateToolCalls(_ toolCalls: [Transcript.ToolCall], in session: LanguageModelSession) async - /// Asks the delegate how to handle a tool call. - /// - /// Return `.execute` to run the tool, `.stop` to halt after tool calls are generated, - /// or `.provideOutput` to supply output without executing the tool. - /// - Parameters: - /// - toolCall: The tool call to evaluate. - /// - session: The session requesting the decision. - func toolCallDecision(for toolCall: Transcript.ToolCall, in session: LanguageModelSession) async - -> ToolExecutionDecision - /// Notifies the delegate after a tool call produces output. - /// - Parameters: - /// - toolCall: The tool call that was handled. - /// - output: The output sent back to the model. - /// - session: The session that executed the tool call. - func didExecuteToolCall( - _ toolCall: Transcript.ToolCall, - output: Transcript.ToolOutput, - in session: LanguageModelSession - ) async - /// Notifies the delegate when a tool call fails. - /// - Parameters: - /// - toolCall: The tool call that failed. - /// - error: The underlying error raised during execution. - /// - session: The session that attempted the tool call. - func didFailToolCall( - _ toolCall: Transcript.ToolCall, - error: any Error, - in session: LanguageModelSession - ) async -} - -extension ToolExecutionDelegate { - /// Provides a default no-op implementation. - public func didGenerateToolCalls(_ toolCalls: [Transcript.ToolCall], in session: LanguageModelSession) async {} - - /// Provides a default decision that executes the tool call. - public func toolCallDecision( - for toolCall: Transcript.ToolCall, - in session: LanguageModelSession - ) async -> ToolExecutionDecision { - .execute - } - - /// Provides a default no-op implementation. - public func didExecuteToolCall( - _ toolCall: Transcript.ToolCall, - output: Transcript.ToolOutput, - in session: LanguageModelSession - ) async {} - - /// Provides a default no-op implementation. - public func didFailToolCall( - _ toolCall: Transcript.ToolCall, - error: any Error, - in session: LanguageModelSession - ) async {} -} - @Observable public final class LanguageModelSession: @unchecked Sendable { public private(set) var isResponding: Bool = false diff --git a/Sources/AnyLanguageModel/ToolExecution.swift b/Sources/AnyLanguageModel/ToolExecution.swift new file mode 100644 index 0000000..30931fc --- /dev/null +++ b/Sources/AnyLanguageModel/ToolExecution.swift @@ -0,0 +1,86 @@ +/// A decision about how to handle a tool call. +public enum ToolExecutionDecision: Sendable { + /// Execute the tool call using the associated tool. + case execute + + /// Stop the session after tool calls are generated without executing them. + case stop + + /// Provide tool output without executing the tool. + /// + /// Use this to supply results from an external system or cached responses. + case provideOutput([Transcript.Segment]) +} + +/// A delegate that observes and controls tool execution for a session. +public protocol ToolExecutionDelegate: Sendable { + /// Notifies the delegate when the model generates tool calls. + /// + /// - Parameters: + /// - toolCalls: The tool calls produced by the model. + /// - session: The session that generated the tool calls. + func didGenerateToolCalls(_ toolCalls: [Transcript.ToolCall], in session: LanguageModelSession) async + + /// Asks the delegate how to handle a tool call. + /// + /// Return `.execute` to run the tool, `.stop` to halt after tool calls are generated, + /// or `.provideOutput` to supply output without executing the tool. + /// - Parameters: + /// - toolCall: The tool call to evaluate. + /// - session: The session requesting the decision. + func toolCallDecision(for toolCall: Transcript.ToolCall, in session: LanguageModelSession) async + -> ToolExecutionDecision + + /// Notifies the delegate after a tool call produces output. + /// + /// - Parameters: + /// - toolCall: The tool call that was handled. + /// - output: The output sent back to the model. + /// - session: The session that executed the tool call. + func didExecuteToolCall( + _ toolCall: Transcript.ToolCall, + output: Transcript.ToolOutput, + in session: LanguageModelSession + ) async + + /// Notifies the delegate when a tool call fails. + /// + /// - Parameters: + /// - toolCall: The tool call that failed. + /// - error: The underlying error raised during execution. + /// - session: The session that attempted the tool call. + func didFailToolCall( + _ toolCall: Transcript.ToolCall, + error: any Error, + in session: LanguageModelSession + ) async +} + +// MARK: - Default Implementations + +extension ToolExecutionDelegate { + /// Provides a default no-op implementation. + public func didGenerateToolCalls(_ toolCalls: [Transcript.ToolCall], in session: LanguageModelSession) async {} + + /// Provides a default decision that executes the tool call. + public func toolCallDecision( + for toolCall: Transcript.ToolCall, + in session: LanguageModelSession + ) async -> ToolExecutionDecision { + .execute + } + + /// Provides a default no-op implementation. + public func didExecuteToolCall( + _ toolCall: Transcript.ToolCall, + output: Transcript.ToolOutput, + in session: LanguageModelSession + ) async {} + + /// Provides a default no-op implementation. + public func didFailToolCall( + _ toolCall: Transcript.ToolCall, + error: any Error, + in session: LanguageModelSession + ) async {} +} From 17ebacebae83b3a8ab1efaf3b7bb075098866ad7 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 5 Feb 2026 04:15:57 -0800 Subject: [PATCH 8/9] Document custom features of AnyLanguageModel that break drop-in API compatibility --- Sources/AnyLanguageModel/LanguageModelSession.swift | 5 +++++ Sources/AnyLanguageModel/ToolExecution.swift | 8 ++++++++ 2 files changed, 13 insertions(+) diff --git a/Sources/AnyLanguageModel/LanguageModelSession.swift b/Sources/AnyLanguageModel/LanguageModelSession.swift index 7f136dc..1b56949 100644 --- a/Sources/AnyLanguageModel/LanguageModelSession.swift +++ b/Sources/AnyLanguageModel/LanguageModelSession.swift @@ -9,10 +9,15 @@ public final class LanguageModelSession: @unchecked Sendable { private let model: any LanguageModel public let tools: [any Tool] public let instructions: Instructions? + /// A delegate that observes and controls tool execution. /// /// Set this property to intercept tool calls, provide custom output, /// or stop after tool calls are generated. + /// + /// - Note: This property is exclusive to AnyLanguageModel + /// and using it means your code is no longer drop-in compatible + /// with the Foundation Models framework. @ObservationIgnored public var toolExecutionDelegate: (any ToolExecutionDelegate)? @ObservationIgnored private let respondingState = RespondingState() diff --git a/Sources/AnyLanguageModel/ToolExecution.swift b/Sources/AnyLanguageModel/ToolExecution.swift index 30931fc..3dd9743 100644 --- a/Sources/AnyLanguageModel/ToolExecution.swift +++ b/Sources/AnyLanguageModel/ToolExecution.swift @@ -1,4 +1,8 @@ /// A decision about how to handle a tool call. +/// +/// - Note: This API is exclusive to AnyLanguageModel +/// and using it means your code is no longer drop-in compatible +/// with the Foundation Models framework. public enum ToolExecutionDecision: Sendable { /// Execute the tool call using the associated tool. case execute @@ -13,6 +17,10 @@ public enum ToolExecutionDecision: Sendable { } /// A delegate that observes and controls tool execution for a session. +/// +/// - Note: This API is exclusive to AnyLanguageModel +/// and using it means your code is no longer drop-in compatible +/// with the Foundation Models framework. public protocol ToolExecutionDelegate: Sendable { /// Notifies the delegate when the model generates tool calls. /// From fb8d51eea50d3520d1e482f0cfd2952b078b166b Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 5 Feb 2026 04:43:53 -0800 Subject: [PATCH 9/9] Incorporate feedback from review --- .../ToolExecutionDelegateTests.swift | 34 ++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/Tests/AnyLanguageModelTests/ToolExecutionDelegateTests.swift b/Tests/AnyLanguageModelTests/ToolExecutionDelegateTests.swift index 58cceda..971692f 100644 --- a/Tests/AnyLanguageModelTests/ToolExecutionDelegateTests.swift +++ b/Tests/AnyLanguageModelTests/ToolExecutionDelegateTests.swift @@ -15,7 +15,7 @@ private actor ToolExecutionDelegateSpy: ToolExecutionDelegate { } func didGenerateToolCalls(_ toolCalls: [Transcript.ToolCall], in session: LanguageModelSession) async { - generatedToolCalls = toolCalls + generatedToolCalls.append(contentsOf: toolCalls) } func toolCallDecision( @@ -43,6 +43,21 @@ private actor ToolExecutionDelegateSpy: ToolExecutionDelegate { } } +private struct ThrowingTool: Tool { + let name = "throwingTool" + let description = "A tool that throws" + @Generable + struct Arguments { + @Guide(description: "Ignored") + var message: String + } + func call(arguments: Arguments) async throws -> String { + throw ThrowingToolError.testError + } +} + +private enum ThrowingToolError: Error, Equatable { case testError } + private struct ToolCallingTestModel: LanguageModel { typealias UnavailableReason = Never @@ -267,4 +282,21 @@ struct ToolExecutionDelegateTests { let executedCalls = await delegate.executedToolCalls #expect(executedCalls.count == 1) } + + @Test func didFailToolCallNotifiesDelegateWhenToolThrows() async throws { + let arguments = try GeneratedContent(json: #"{"message":"fail"}"#) + let toolCall = Transcript.ToolCall(id: "call-fail", toolName: ThrowingTool().name, arguments: arguments) + let delegate = ToolExecutionDelegateSpy { _ in .execute } + let session = LanguageModelSession( + model: ToolCallingTestModel(toolCalls: [toolCall]), + tools: [ThrowingTool()] + ) + session.toolExecutionDelegate = delegate + + _ = try? await session.respond(to: "Hi") + + let failures = await delegate.failures + #expect(failures.count == 1) + #expect((failures.first as? ThrowingToolError) == .testError) + } }