Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,36 @@ let response = try await session.respond {
print(response.content)
```

To observe or control tool execution, assign a delegate on the session:

```swift
/// Example delegate that observes and controls tool execution on a `LanguageModelSession`.
///
/// Declared as an `actor` so any state you add later is isolated; the
/// delegate's requirements are `async`, so actor isolation composes naturally.
/// NOTE(review): the delegate protocol appears to also support a failure hook
/// (`didFailToolCall`) — confirm against the protocol declaration before
/// relying on this example as the full surface.
actor ToolExecutionObserver: ToolExecutionDelegate {
// Called once per model turn, after tool calls are generated but before any run.
func didGenerateToolCalls(_ toolCalls: [Transcript.ToolCall], in session: LanguageModelSession) async {
print("Generated tool calls: \(toolCalls)")
}

// Called once per tool call to decide how to proceed.
func toolCallDecision(
for toolCall: Transcript.ToolCall,
in session: LanguageModelSession
) async -> ToolExecutionDecision {
// Return .execute to run the tool normally,
// .stop to halt after tool calls, or .provideOutput(...) to bypass execution.
// This is a good place to ask the user for confirmation (for example, in a modal dialog).
.execute
}

// Called after a tool call produces output (whether executed or provided by the delegate).
func didExecuteToolCall(
_ toolCall: Transcript.ToolCall,
output: Transcript.ToolOutput,
in session: LanguageModelSession
) async {
print("Executed tool call: \(toolCall)")
}
}

let session = LanguageModelSession(model: model, tools: [WeatherTool()])
session.toolExecutionDelegate = ToolExecutionObserver()
```

## Features

### Supported Providers
Expand Down
10 changes: 10 additions & 0 deletions Sources/AnyLanguageModel/LanguageModelSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@ public final class LanguageModelSession: @unchecked Sendable {
public let tools: [any Tool]
public let instructions: Instructions?

/// A delegate that observes and controls tool execution.
///
/// Set this property to intercept tool calls, provide custom output,
/// or stop after tool calls are generated.
///
/// - Note: This property is exclusive to AnyLanguageModel
/// and using it means your code is no longer drop-in compatible
/// with the Foundation Models framework.
@ObservationIgnored public var toolExecutionDelegate: (any ToolExecutionDelegate)?

@ObservationIgnored private let respondingState = RespondingState()

public convenience init(
Expand Down
132 changes: 102 additions & 30 deletions Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -356,11 +356,23 @@ public struct AnthropicLanguageModel: LanguageModel {
}

if !toolUses.isEmpty {
let invocations = try await resolveToolUses(toolUses, session: session)
if !invocations.isEmpty {
entries.append(.toolCalls(Transcript.ToolCalls(invocations.map(\.call))))
for invocation in invocations {
entries.append(.toolOutput(invocation.output))
let resolution = try await resolveToolUses(toolUses, session: session)
switch resolution {
case .stop(let calls):
if !calls.isEmpty {
entries.append(.toolCalls(Transcript.ToolCalls(calls)))
}
return LanguageModelSession.Response(
content: "" as! Content,
rawContent: GeneratedContent(""),
transcriptEntries: ArraySlice(entries)
)
case .invocations(let invocations):
if !invocations.isEmpty {
entries.append(.toolCalls(Transcript.ToolCalls(invocations.map(\.call))))
for invocation in invocations {
entries.append(.toolOutput(invocation.output))
}
}
}
}
Expand Down Expand Up @@ -560,11 +572,16 @@ private struct ToolInvocationResult {
let output: Transcript.ToolOutput
}

/// The outcome of resolving a batch of tool uses returned by the model.
private enum ToolResolutionOutcome {
/// Halt the turn after recording the generated calls without executing them
/// (produced when the delegate's decision is `.stop`).
case stop(calls: [Transcript.ToolCall])
/// All calls were resolved; each result pairs a transcript call with its output.
case invocations([ToolInvocationResult])
}

private func resolveToolUses(
_ toolUses: [AnthropicToolUse],
session: LanguageModelSession
) async throws -> [ToolInvocationResult] {
if toolUses.isEmpty { return [] }
) async throws -> ToolResolutionOutcome {
if toolUses.isEmpty { return .invocations([]) }

var toolsByName: [String: any Tool] = [:]
for tool in session.tools {
Expand All @@ -573,43 +590,98 @@ private func resolveToolUses(
}
}

var results: [ToolInvocationResult] = []
results.reserveCapacity(toolUses.count)

var transcriptCalls: [Transcript.ToolCall] = []
transcriptCalls.reserveCapacity(toolUses.count)
for use in toolUses {
let args = try toGeneratedContent(use.input)
let callID = use.id
let transcriptCall = Transcript.ToolCall(
id: callID,
toolName: use.name,
arguments: args
)

guard let tool = toolsByName[use.name] else {
let message = Transcript.Segment.text(.init(content: "Tool not found: \(use.name)"))
let output = Transcript.ToolOutput(
transcriptCalls.append(
Transcript.ToolCall(
id: callID,
toolName: use.name,
segments: [message]
arguments: args
)
results.append(ToolInvocationResult(call: transcriptCall, output: output))
continue
)
}

if let delegate = session.toolExecutionDelegate {
await delegate.didGenerateToolCalls(transcriptCalls, in: session)
}

guard !transcriptCalls.isEmpty else { return .invocations([]) }

var decisions: [ToolExecutionDecision] = []
decisions.reserveCapacity(transcriptCalls.count)

if let delegate = session.toolExecutionDelegate {
for call in transcriptCalls {
let decision = await delegate.toolCallDecision(for: call, in: session)
if case .stop = decision {
return .stop(calls: transcriptCalls)
}
decisions.append(decision)
}
} else {
decisions = Array(repeating: .execute, count: transcriptCalls.count)
}

do {
let segments = try await tool.makeOutputSegments(from: args)
var results: [ToolInvocationResult] = []
results.reserveCapacity(transcriptCalls.count)

for (index, call) in transcriptCalls.enumerated() {
switch decisions[index] {
case .stop:
// This branch should be unreachable,
// because `.stop` returns during decision collection.
// Keep it as a defensive guard,
// in case that logic changes.
return .stop(calls: transcriptCalls)
case .provideOutput(let segments):
let output = Transcript.ToolOutput(
id: tool.name,
toolName: tool.name,
id: call.id,
toolName: call.toolName,
segments: segments
)
results.append(ToolInvocationResult(call: transcriptCall, output: output))
} catch {
throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error)
if let delegate = session.toolExecutionDelegate {
await delegate.didExecuteToolCall(call, output: output, in: session)
}
results.append(ToolInvocationResult(call: call, output: output))
case .execute:
guard let tool = toolsByName[call.toolName] else {
let message = Transcript.Segment.text(.init(content: "Tool not found: \(call.toolName)"))
let output = Transcript.ToolOutput(
id: call.id,
toolName: call.toolName,
segments: [message]
)
if let delegate = session.toolExecutionDelegate {
await delegate.didExecuteToolCall(call, output: output, in: session)
}
results.append(ToolInvocationResult(call: call, output: output))
continue
}

do {
let segments = try await tool.makeOutputSegments(from: call.arguments)
let output = Transcript.ToolOutput(
id: call.id,
toolName: tool.name,
segments: segments
)
if let delegate = session.toolExecutionDelegate {
await delegate.didExecuteToolCall(call, output: output, in: session)
}
results.append(ToolInvocationResult(call: call, output: output))
} catch {
if let delegate = session.toolExecutionDelegate {
await delegate.didFailToolCall(call, error: error, in: session)
}
throw LanguageModelSession.ToolCallError(tool: tool, underlyingError: error)
}
}
}

return results
return .invocations(results)
}

// Convert our GenerationSchema into Anthropic's expected JSON Schema payload
Expand Down
Loading