diff --git a/ROADMAP.md b/ROADMAP.md index de975fb..2f6fdfe 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -39,15 +39,16 @@ - `AgentCardResolver` actor with TTL-based caching for multi-agent discovery - Updated A2AServer sample to use `A2AVapor` -## Short Term - -### A2AHummingbird Integration -- Add `A2AHummingbird` target for Hummingbird 2.0+ integration -- Same pattern as A2AVapor — separate product, no forced dependency +### v0.4.0 — Streaming Resilience +- SSE reconnection with configurable retry, exponential backoff, and jitter (`SSEConfiguration`) +- `SSELineParser` for proper SSE field parsing (`data:`, `id:`, `retry:`, `event:`) +- Server-side SSE `id:` and `retry:` field emission for reconnection support +- `Last-Event-ID` header on reconnect with event deduplication +- `ConnectionState` enum and `StreamingSession` type for connection health monitoring +- `sendStreamingMessageWithSession` / `subscribeToTaskWithSession` — rich streaming APIs +- Existing streaming methods unchanged (non-breaking) -### Client Enhancements -- SSE reconnection with automatic retry and last-event-id -- Connection health monitoring +## Short Term ## Medium Term diff --git a/Sources/A2A/A2A.docc/A2A.md b/Sources/A2A/A2A.docc/A2A.md index 95b3db8..6a893ad 100644 --- a/Sources/A2A/A2A.docc/A2A.md +++ b/Sources/A2A/A2A.docc/A2A.md @@ -83,6 +83,9 @@ let router = A2ARouter(handler: handler) - ``EventQueueManager`` - ``EventSubscription`` - ``StreamResponseSequence`` +- ``StreamingSession`` +- ``ConnectionState`` +- ``SSEConfiguration`` ### Request & Response Types diff --git a/Sources/A2A/Client/A2AClient.swift b/Sources/A2A/Client/A2AClient.swift index 3c1ff5c..1d6efda 100644 --- a/Sources/A2A/Client/A2AClient.swift +++ b/Sources/A2A/Client/A2AClient.swift @@ -45,6 +45,9 @@ public final class A2AClient: Sendable { /// JSON decoder configured for A2A. private let decoder: JSONDecoder + /// SSE streaming reconnection configuration. + private let sseConfiguration: SSEConfiguration + /// Auto-incrementing request ID. private let requestIdCounter = RequestIdCounter() @@ -52,12 +55,14 @@ public final class A2AClient: Sendable { baseURL: URL, session: URLSession = .shared, authHeaders: [String: String] = [:], - interceptors: [any A2AClientInterceptor] = [] + interceptors: [any A2AClientInterceptor] = [], + sseConfiguration: SSEConfiguration = .default ) { self.baseURL = baseURL self.session = session self.authHeaders = authHeaders self.interceptors = interceptors + self.sseConfiguration = sseConfiguration self.encoder = JSONEncoder() self.decoder = JSONDecoder() } @@ -85,7 +90,13 @@ public final class A2AClient: Sendable { /// Sends a streaming message and returns an AsyncSequence of stream responses. public func sendStreamingMessage(_ params: SendMessageRequest) async throws -> AsyncThrowingStream { - try await streamingCall(method: .sendStreamingMessage, params: params) + let session = try await streamingCallWithSession(method: .sendStreamingMessage, params: params) + return session.events + } + + /// Sends a streaming message and returns a ``StreamingSession`` with connection state monitoring. + public func sendStreamingMessageWithSession(_ params: SendMessageRequest) async throws -> StreamingSession { + try await streamingCallWithSession(method: .sendStreamingMessage, params: params) } // MARK: - Task Management @@ -107,7 +118,13 @@ public final class A2AClient: Sendable { /// Subscribes to task updates via SSE. public func subscribeToTask(_ params: SubscribeToTaskRequest) async throws -> AsyncThrowingStream { - try await streamingCall(method: .subscribeToTask, params: params) + let session = try await streamingCallWithSession(method: .subscribeToTask, params: params) + return session.events + } + + /// Subscribes to task updates and returns a ``StreamingSession`` with connection state monitoring. + public func subscribeToTaskWithSession(_ params: SubscribeToTaskRequest) async throws -> StreamingSession { + try await streamingCallWithSession(method: .subscribeToTask, params: params) } // MARK: - Push Notifications @@ -182,10 +199,10 @@ public final class A2AClient: Sendable { return result } - private func streamingCall( + private func streamingCallWithSession( method: A2AMethod, params: Params - ) async throws -> AsyncThrowingStream { + ) async throws -> StreamingSession { let id = requestIdCounter.next() let rpcRequest = JSONRPCRequest(id: .int(id), method: method, params: params) @@ -202,85 +219,207 @@ public final class A2AClient: Sendable { try await interceptor.before(request: &httpRequest, method: method) } + // Capture as immutable for use in closures + let baseRequest = httpRequest let decoder = self.decoder + let urlSession = self.session + let sseConfig = self.sseConfiguration + + let (connectionStateStream, connectionStateContinuation) = AsyncStream.makeStream() #if canImport(FoundationNetworking) // Linux: FoundationNetworking doesn't support URLSession.bytes, // so we fall back to fetching the complete response and parsing SSE lines. - let (data, response) = try await session.data(for: httpRequest) - try validateHTTPResponse(response) + // Reconnection is supported via retry on connection failure. + let events = AsyncThrowingStream { continuation in + let task = Task { + var parser = SSELineParser() + var attempt = 0 + + retryLoop: while true { + do { + var reconnectRequest = baseRequest + if let lastId = parser.lastEventId { + reconnectRequest.setValue(lastId, forHTTPHeaderField: "Last-Event-ID") + } - return AsyncThrowingStream { continuation in - do { - let text = String(data: data, encoding: .utf8) ?? "" - for line in text.components(separatedBy: .newlines) { - let trimmed = line.trimmingCharacters(in: .whitespaces) - guard trimmed.hasPrefix("data:") else { continue } - - let jsonString = String(trimmed.dropFirst(5)).trimmingCharacters(in: .whitespaces) - guard !jsonString.isEmpty else { continue } - guard let jsonData = jsonString.data(using: .utf8) else { continue } - - let rpcResponse = try decoder.decode(JSONRPCResponse.self, from: jsonData) - if let error = rpcResponse.error { - continuation.finish(throwing: A2AError( - code: A2AErrorCode(rawValue: error.code) ?? .internalError, - message: error.message, - data: error.data - )) + let (data, response) = try await urlSession.data(for: reconnectRequest) + try self.validateHTTPResponse(response) + + if attempt > 0 { + connectionStateContinuation.yield(.connected) + } + attempt = 0 + + let text = String(data: data, encoding: .utf8) ?? "" + for line in text.split(separator: "\n", omittingEmptySubsequences: false).map(String.init) { + let field = parser.parse(line: line) + switch field { + case .data(let jsonString): + guard !jsonString.isEmpty else { continue } + guard let jsonData = jsonString.data(using: .utf8) else { continue } + + let rpcResponse = try decoder.decode(JSONRPCResponse.self, from: jsonData) + if let error = rpcResponse.error { + continuation.finish(throwing: A2AError( + code: A2AErrorCode(rawValue: error.code) ?? .internalError, + message: error.message, + data: error.data + )) + connectionStateContinuation.finish() + return + } + if let result = rpcResponse.result { + continuation.yield(result) + } + default: + break + } + } + // Normal completion + continuation.finish() + connectionStateContinuation.finish() return - } - if let result = rpcResponse.result { - continuation.yield(result) + + } catch let error as A2AError { + // JSON-RPC errors are not retryable + continuation.finish(throwing: error) + connectionStateContinuation.yield(.disconnected(error)) + connectionStateContinuation.finish() + return + } catch { + guard !Task.isCancelled else { + continuation.finish(throwing: error) + connectionStateContinuation.finish() + return + } + + attempt += 1 + if attempt > sseConfig.maxRetries { + continuation.finish(throwing: error) + connectionStateContinuation.yield(.disconnected(error)) + connectionStateContinuation.finish() + return + } + + connectionStateContinuation.yield(.reconnecting(attempt: attempt, maxAttempts: sseConfig.maxRetries)) + let delay = parser.serverRetryInterval ?? sseConfig.delay(forAttempt: attempt - 1) + try? await Task.sleep(nanoseconds: UInt64(delay * 1_000_000_000)) + continue retryLoop } } - continuation.finish() - } catch { - continuation.finish(throwing: error) + } + continuation.onTermination = { _ in + task.cancel() + connectionStateContinuation.finish() } } - #else - let (bytes, response) = try await session.bytes(for: httpRequest) - try validateHTTPResponse(response) - return AsyncThrowingStream { continuation in + #else + let events = AsyncThrowingStream { continuation in let task = Task { - do { - for try await line in bytes.lines { - guard !Task.isCancelled else { break } - - // SSE format: "data: {json}\n" - let trimmed = line.trimmingCharacters(in: .whitespaces) - guard trimmed.hasPrefix("data:") else { continue } - - let jsonString = String(trimmed.dropFirst(5)).trimmingCharacters(in: .whitespaces) - guard !jsonString.isEmpty else { continue } - guard let jsonData = jsonString.data(using: .utf8) else { continue } - - // Parse JSON-RPC response wrapping the StreamResponse - let rpcResponse = try decoder.decode(JSONRPCResponse.self, from: jsonData) - if let error = rpcResponse.error { - continuation.finish(throwing: A2AError( - code: A2AErrorCode(rawValue: error.code) ?? .internalError, - message: error.message, - data: error.data - )) + var parser = SSELineParser() + var lastSeenId: Int? + var attempt = 0 + + retryLoop: while true { + do { + guard !Task.isCancelled else { + continuation.finish() + connectionStateContinuation.finish() return } - if let result = rpcResponse.result { - continuation.yield(result) + + var reconnectRequest = baseRequest + if let lastId = parser.lastEventId { + reconnectRequest.setValue(lastId, forHTTPHeaderField: "Last-Event-ID") + } + + let (bytes, response) = try await urlSession.bytes(for: reconnectRequest) + try self.validateHTTPResponse(response) + + if attempt > 0 { + connectionStateContinuation.yield(.connected) + } + attempt = 0 + + for try await line in bytes.lines { + guard !Task.isCancelled else { break } + + let field = parser.parse(line: line) + switch field { + case .data(let jsonString): + guard !jsonString.isEmpty else { continue } + guard let jsonData = jsonString.data(using: .utf8) else { continue } + + // Deduplicate after reconnect + if let lastId = parser.lastEventId, let idNum = Int(lastId), + let seen = lastSeenId, idNum <= seen { + continue + } + if let lastId = parser.lastEventId, let idNum = Int(lastId) { + lastSeenId = idNum + } + + let rpcResponse = try decoder.decode(JSONRPCResponse.self, from: jsonData) + if let error = rpcResponse.error { + continuation.finish(throwing: A2AError( + code: A2AErrorCode(rawValue: error.code) ?? .internalError, + message: error.message, + data: error.data + )) + connectionStateContinuation.finish() + return + } + if let result = rpcResponse.result { + continuation.yield(result) + } + default: + break + } } + // Normal completion + continuation.finish() + connectionStateContinuation.finish() + return + + } catch let error as A2AError { + // JSON-RPC errors are not retryable + continuation.finish(throwing: error) + connectionStateContinuation.yield(.disconnected(error)) + connectionStateContinuation.finish() + return + } catch { + guard !Task.isCancelled else { + continuation.finish(throwing: error) + connectionStateContinuation.finish() + return + } + + attempt += 1 + if attempt > sseConfig.maxRetries { + continuation.finish(throwing: error) + connectionStateContinuation.yield(.disconnected(error)) + connectionStateContinuation.finish() + return + } + + connectionStateContinuation.yield(.reconnecting(attempt: attempt, maxAttempts: sseConfig.maxRetries)) + let delay = parser.serverRetryInterval ?? sseConfig.delay(forAttempt: attempt - 1) + try? await Task.sleep(nanoseconds: UInt64(delay * 1_000_000_000)) + continue retryLoop } - continuation.finish() - } catch { - continuation.finish(throwing: error) } } continuation.onTermination = { _ in task.cancel() + connectionStateContinuation.finish() } } #endif + + connectionStateContinuation.yield(.connected) + return StreamingSession(events: events, connectionState: connectionStateStream) } private func applyHeaders(to request: inout URLRequest) { diff --git a/Sources/A2A/Client/ConnectionState.swift b/Sources/A2A/Client/ConnectionState.swift new file mode 100644 index 0000000..13c6bb2 --- /dev/null +++ b/Sources/A2A/Client/ConnectionState.swift @@ -0,0 +1,16 @@ +import Foundation + +/// Represents the state of an SSE streaming connection. +/// +/// Use with ``StreamingSession/connectionState`` to monitor connection health +/// during streaming operations. +public enum ConnectionState: Sendable { + /// The connection is active and receiving events. + case connected + + /// The connection dropped and is being re-established. + case reconnecting(attempt: Int, maxAttempts: Int) + + /// The connection has been permanently lost. + case disconnected(any Error) +} diff --git a/Sources/A2A/Client/SSEConfiguration.swift b/Sources/A2A/Client/SSEConfiguration.swift new file mode 100644 index 0000000..26d2931 --- /dev/null +++ b/Sources/A2A/Client/SSEConfiguration.swift @@ -0,0 +1,65 @@ +import Foundation + +/// Configuration for SSE streaming reconnection behavior. +/// +/// When a streaming connection drops unexpectedly, the client can automatically +/// retry with exponential backoff. Use ``default`` for sensible defaults or +/// ``disabled`` to opt out of reconnection. +/// +/// ```swift +/// // Default reconnection (3 retries with exponential backoff) +/// let client = A2AClient(baseURL: url) +/// +/// // Custom configuration +/// let client = A2AClient( +/// baseURL: url, +/// sseConfiguration: SSEConfiguration(maxRetries: 5, initialRetryInterval: 2.0) +/// ) +/// +/// // Disable reconnection +/// let client = A2AClient(baseURL: url, sseConfiguration: .disabled) +/// ``` +public struct SSEConfiguration: Sendable, Equatable { + /// Maximum number of reconnection attempts before giving up. + public var maxRetries: Int + + /// Initial delay between reconnection attempts in seconds. + public var initialRetryInterval: TimeInterval + + /// Maximum delay between reconnection attempts in seconds. + public var maxRetryInterval: TimeInterval + + /// Multiplier applied to the retry interval after each failed attempt. + public var backoffMultiplier: Double + + /// Fraction of the retry interval to use as random jitter (0.0–1.0). + public var jitterFraction: Double + + public init( + maxRetries: Int = 3, + initialRetryInterval: TimeInterval = 1.0, + maxRetryInterval: TimeInterval = 30.0, + backoffMultiplier: Double = 2.0, + jitterFraction: Double = 0.1 + ) { + self.maxRetries = maxRetries + self.initialRetryInterval = initialRetryInterval + self.maxRetryInterval = maxRetryInterval + self.backoffMultiplier = backoffMultiplier + self.jitterFraction = jitterFraction + } + + /// Default configuration with 3 retries and exponential backoff. + public static let `default` = SSEConfiguration() + + /// Disabled reconnection — errors are thrown immediately. + public static let disabled = SSEConfiguration(maxRetries: 0) + + /// Calculates the delay for a given retry attempt, incorporating backoff and jitter. + internal func delay(forAttempt attempt: Int) -> TimeInterval { + let base = initialRetryInterval * pow(backoffMultiplier, Double(attempt)) + let clamped = min(base, maxRetryInterval) + let jitter = clamped * jitterFraction * Double.random(in: -1...1) + return max(0, clamped + jitter) + } +} diff --git a/Sources/A2A/Client/SSELineParser.swift b/Sources/A2A/Client/SSELineParser.swift new file mode 100644 index 0000000..6af07e2 --- /dev/null +++ b/Sources/A2A/Client/SSELineParser.swift @@ -0,0 +1,62 @@ +import Foundation + +/// Internal parser for Server-Sent Events (SSE) lines. +/// +/// Handles the SSE protocol fields: `data:`, `id:`, `retry:`, and `event:`. +/// Tracks the last event ID and server-suggested retry interval for reconnection. +struct SSELineParser: Sendable { + /// The last received event ID, used for `Last-Event-ID` header on reconnect. + private(set) var lastEventId: String? + + /// Server-suggested retry interval in seconds, if received. + private(set) var serverRetryInterval: TimeInterval? + + enum Field: Sendable, Equatable { + case data(String) + case id(String) + case retry(Int) + case event(String) + case comment + case empty + } + + /// Parses a single SSE line and returns the field type. + mutating func parse(line: String) -> Field { + let trimmed = line.trimmingCharacters(in: .whitespaces) + + if trimmed.isEmpty { + return .empty + } + + if trimmed.hasPrefix(":") { + return .comment + } + + if trimmed.hasPrefix("data:") { + let value = String(trimmed.dropFirst(5)).trimmingCharacters(in: .whitespaces) + return .data(value) + } + + if trimmed.hasPrefix("id:") { + let value = String(trimmed.dropFirst(3)).trimmingCharacters(in: .whitespaces) + lastEventId = value + return .id(value) + } + + if trimmed.hasPrefix("retry:") { + let value = String(trimmed.dropFirst(6)).trimmingCharacters(in: .whitespaces) + if let ms = Int(value) { + serverRetryInterval = TimeInterval(ms) / 1000.0 + return .retry(ms) + } + return .comment + } + + if trimmed.hasPrefix("event:") { + let value = String(trimmed.dropFirst(6)).trimmingCharacters(in: .whitespaces) + return .event(value) + } + + return .comment + } +} diff --git a/Sources/A2A/Client/StreamingSession.swift b/Sources/A2A/Client/StreamingSession.swift new file mode 100644 index 0000000..44b622d --- /dev/null +++ b/Sources/A2A/Client/StreamingSession.swift @@ -0,0 +1,37 @@ +import Foundation + +/// A streaming session that provides both A2A events and connection state updates. +/// +/// Use this type to monitor connection health during streaming operations. +/// Obtain a `StreamingSession` via ``A2AClient/sendStreamingMessageWithSession(_:)`` +/// or ``A2AClient/subscribeToTaskWithSession(_:)``. +/// +/// ```swift +/// let session = try await client.sendStreamingMessageWithSession(request) +/// +/// // Monitor connection state in a separate task +/// Task { +/// for await state in session.connectionState { +/// switch state { +/// case .connected: +/// print("Connected") +/// case .reconnecting(let attempt, let max): +/// print("Reconnecting (\(attempt)/\(max))...") +/// case .disconnected(let error): +/// print("Disconnected: \(error)") +/// } +/// } +/// } +/// +/// // Consume events +/// for try await event in session.events { +/// // handle event +/// } +/// ``` +public struct StreamingSession: Sendable { + /// The stream of A2A events. + public let events: AsyncThrowingStream + + /// Connection state changes during the streaming session. + public let connectionState: AsyncStream +} diff --git a/Sources/A2A/Server/A2AServer.swift b/Sources/A2A/Server/A2AServer.swift index 059b775..c3333ed 100644 --- a/Sources/A2A/Server/A2AServer.swift +++ b/Sources/A2A/Server/A2AServer.swift @@ -232,14 +232,24 @@ public struct A2ARouter: Sendable { return AsyncThrowingStream { continuation in let task = Task { + var eventCounter = 0 do { for try await event in stream { guard !Task.isCancelled else { break } + eventCounter += 1 let response = JSONRPCResponse(id: resolvedId, result: event) let jsonData = try encoder.encode(response) guard let jsonString = String(data: jsonData, encoding: .utf8) else { continue } - let sseData = "data: \(jsonString)\n\n".data(using: .utf8)! - continuation.yield(sseData) + + var sse = "" + if eventCounter == 1 { + sse += "retry: 3000\n" + } + sse += "id: \(eventCounter)\n" + sse += "data: \(jsonString)\n" + sse += "\n" + + continuation.yield(sse.data(using: .utf8)!) } continuation.finish() } catch let error as A2AError { @@ -249,7 +259,8 @@ public struct A2ARouter: Sendable { ) if let errorData = try? encoder.encode(errorResponse), let errorString = String(data: errorData, encoding: .utf8) { - let sseData = "data: \(errorString)\n\n".data(using: .utf8)! + eventCounter += 1 + let sseData = "id: \(eventCounter)\ndata: \(errorString)\n\n".data(using: .utf8)! continuation.yield(sseData) } continuation.finish(throwing: error) diff --git a/Tests/A2ATests/SSEConfigurationTests.swift b/Tests/A2ATests/SSEConfigurationTests.swift new file mode 100644 index 0000000..09e5160 --- /dev/null +++ b/Tests/A2ATests/SSEConfigurationTests.swift @@ -0,0 +1,72 @@ +import Testing +import Foundation +@testable import A2A + +@Suite("SSEConfiguration") +struct SSEConfigurationTests { + + @Test func defaultConfiguration() { + let config = SSEConfiguration.default + #expect(config.maxRetries == 3) + #expect(config.initialRetryInterval == 1.0) + #expect(config.maxRetryInterval == 30.0) + #expect(config.backoffMultiplier == 2.0) + #expect(config.jitterFraction == 0.1) + } + + @Test func disabledConfiguration() { + let config = SSEConfiguration.disabled + #expect(config.maxRetries == 0) + } + + @Test func delayExponentialBackoff() { + let config = SSEConfiguration( + initialRetryInterval: 1.0, + maxRetryInterval: 30.0, + backoffMultiplier: 2.0, + jitterFraction: 0.0 // No jitter for deterministic testing + ) + + let delay0 = config.delay(forAttempt: 0) + #expect(delay0 == 1.0) // 1.0 * 2^0 = 1.0 + + let delay1 = config.delay(forAttempt: 1) + #expect(delay1 == 2.0) // 1.0 * 2^1 = 2.0 + + let delay2 = config.delay(forAttempt: 2) + #expect(delay2 == 4.0) // 1.0 * 2^2 = 4.0 + } + + @Test func delayClampsToMax() { + let config = SSEConfiguration( + initialRetryInterval: 1.0, + maxRetryInterval: 5.0, + backoffMultiplier: 2.0, + jitterFraction: 0.0 + ) + + let delay10 = config.delay(forAttempt: 10) + #expect(delay10 == 5.0) // Clamped to maxRetryInterval + } + + @Test func delayWithJitterInRange() { + let config = SSEConfiguration( + initialRetryInterval: 10.0, + maxRetryInterval: 30.0, + backoffMultiplier: 1.0, + jitterFraction: 0.1 + ) + + // With jitter=0.1, delay should be in range [9.0, 11.0] + for _ in 0..<100 { + let delay = config.delay(forAttempt: 0) + #expect(delay >= 9.0) + #expect(delay <= 11.0) + } + } + + @Test func equatable() { + #expect(SSEConfiguration.default == SSEConfiguration()) + #expect(SSEConfiguration.default != SSEConfiguration.disabled) + } +} diff --git a/Tests/A2ATests/SSELineParserTests.swift b/Tests/A2ATests/SSELineParserTests.swift new file mode 100644 index 0000000..ab0f7e1 --- /dev/null +++ b/Tests/A2ATests/SSELineParserTests.swift @@ -0,0 +1,97 @@ +import Testing +@testable import A2A + +@Suite("SSELineParser") +struct SSELineParserTests { + + @Test func parsesDataField() { + var parser = SSELineParser() + let field = parser.parse(line: "data: {\"hello\":\"world\"}") + #expect(field == .data("{\"hello\":\"world\"}")) + } + + @Test func parsesDataFieldNoSpace() { + var parser = SSELineParser() + let field = parser.parse(line: "data:{\"hello\":\"world\"}") + #expect(field == .data("{\"hello\":\"world\"}")) + } + + @Test func parsesEmptyDataField() { + var parser = SSELineParser() + let field = parser.parse(line: "data:") + #expect(field == .data("")) + } + + @Test func parsesIdField() { + var parser = SSELineParser() + let field = parser.parse(line: "id: 42") + #expect(field == .id("42")) + #expect(parser.lastEventId == "42") + } + + @Test func tracksLastEventId() { + var parser = SSELineParser() + _ = parser.parse(line: "id: 1") + _ = parser.parse(line: "id: 2") + _ = parser.parse(line: "id: 3") + #expect(parser.lastEventId == "3") + } + + @Test func parsesRetryField() { + var parser = SSELineParser() + let field = parser.parse(line: "retry: 3000") + #expect(field == .retry(3000)) + #expect(parser.serverRetryInterval == 3.0) + } + + @Test func parsesRetryFieldUpdatesInterval() { + var parser = SSELineParser() + _ = parser.parse(line: "retry: 1000") + #expect(parser.serverRetryInterval == 1.0) + _ = parser.parse(line: "retry: 5000") + #expect(parser.serverRetryInterval == 5.0) + } + + @Test func invalidRetryTreatedAsComment() { + var parser = SSELineParser() + let field = parser.parse(line: "retry: abc") + #expect(field == .comment) + #expect(parser.serverRetryInterval == nil) + } + + @Test func parsesEventField() { + var parser = SSELineParser() + let field = parser.parse(line: "event: message") + #expect(field == .event("message")) + } + + @Test func parsesComment() { + var parser = SSELineParser() + let field = parser.parse(line: ": this is a comment") + #expect(field == .comment) + } + + @Test func parsesEmptyLine() { + var parser = SSELineParser() + let field = parser.parse(line: "") + #expect(field == .empty) + } + + @Test func parsesWhitespaceOnlyLine() { + var parser = SSELineParser() + let field = parser.parse(line: " ") + #expect(field == .empty) + } + + @Test func unknownFieldTreatedAsComment() { + var parser = SSELineParser() + let field = parser.parse(line: "unknown: value") + #expect(field == .comment) + } + + @Test func initialStateIsNil() { + let parser = SSELineParser() + #expect(parser.lastEventId == nil) + #expect(parser.serverRetryInterval == nil) + } +} diff --git a/Tests/A2ATests/SSEStreamingTests.swift b/Tests/A2ATests/SSEStreamingTests.swift new file mode 100644 index 0000000..08cde26 --- /dev/null +++ b/Tests/A2ATests/SSEStreamingTests.swift @@ -0,0 +1,202 @@ +import Testing +import Foundation +@testable import A2A + +/// Handler that supports streaming for SSE format tests. +struct SSETestHandler: A2AAgentHandler { + func agentCard() async throws -> AgentCard { + AgentCard( + name: "SSE Test Agent", + description: "Test", + supportedInterfaces: [AgentInterface(url: "http://localhost:8080")], + version: "1.0.0", + skills: [AgentSkill(id: "test", name: "Test", description: "Test", tags: [])] + ) + } + + func handleSendMessage(_ request: SendMessageRequest) async throws -> SendMessageResponse { + throw A2AError.unsupportedOperation("Use streaming") + } + + func handleSendStreamingMessage(_ request: SendMessageRequest) async throws -> AsyncThrowingStream { + AsyncThrowingStream { continuation in + let task = A2ATask( + id: "task-1", + contextId: "ctx-1", + status: TaskStatus(state: .working) + ) + continuation.yield(.task(task)) + continuation.yield(.statusUpdate(TaskStatusUpdateEvent( + taskId: "task-1", + contextId: "ctx-1", + status: TaskStatus(state: .completed, message: Message(role: .agent, parts: [.text("Done")])) + ))) + continuation.finish() + } + } + + func handleGetTask(_ request: GetTaskRequest) async throws -> A2ATask { + throw A2AError(code: .taskNotFound) + } + + func handleCancelTask(_ request: CancelTaskRequest) async throws -> A2ATask { + throw A2AError(code: .taskNotFound) + } +} + +@Suite("SSE Streaming Format") +struct SSEStreamingTests { + let encoder = JSONEncoder() + let decoder = JSONDecoder() + + @Test func streamEmitsIdAndRetryFields() async throws { + let router = A2ARouter(handler: SSETestHandler()) + let request = JSONRPCRequest( + id: .int(1), + method: .sendStreamingMessage, + params: SendMessageRequest( + message: Message(role: .user, parts: [.text("Hello")]) + ) + ) + let body = try encoder.encode(request) + let result = try await router.route(body: body) + + guard case .stream(let stream) = result else { + Issue.record("Expected .stream result") + return + } + + var chunks: [String] = [] + for try await data in stream { + if let str = String(data: data, encoding: .utf8) { + chunks.append(str) + } + } + + #expect(chunks.count == 2) + + // First chunk should have retry: field + let first = chunks[0] + #expect(first.contains("retry: 3000")) + #expect(first.contains("id: 1")) + #expect(first.contains("data: ")) + + // Second chunk should have id but no retry + let second = chunks[1] + #expect(!second.contains("retry:")) + #expect(second.contains("id: 2")) + #expect(second.contains("data: ")) + } + + @Test func streamDataIsValidJSONRPC() async throws { + let router = A2ARouter(handler: SSETestHandler()) + let request = JSONRPCRequest( + id: .int(42), + method: .sendStreamingMessage, + params: SendMessageRequest( + message: Message(role: .user, parts: [.text("Hello")]) + ) + ) + let body = try encoder.encode(request) + let result = try await router.route(body: body) + + guard case .stream(let stream) = result else { + Issue.record("Expected .stream result") + return + } + + var parser = SSELineParser() + var responses: [JSONRPCResponse] = [] + + for try await data in stream { + let text = String(data: data, encoding: .utf8) ?? "" + for line in text.split(separator: "\n", omittingEmptySubsequences: false).map(String.init) { + let field = parser.parse(line: line) + if case .data(let jsonString) = field, !jsonString.isEmpty { + let jsonData = jsonString.data(using: .utf8)! + let rpcResponse = try decoder.decode(JSONRPCResponse.self, from: jsonData) + responses.append(rpcResponse) + } + } + } + + #expect(responses.count == 2) + #expect(responses[0].isSuccess) + #expect(responses[1].isSuccess) + + // First should be task, second should be status update + if case .task(let task) = responses[0].result { + #expect(task.id == "task-1") + #expect(task.status.state == .working) + } else { + Issue.record("Expected .task response") + } + + if case .statusUpdate(let update) = responses[1].result { + #expect(update.status.state == .completed) + } else { + Issue.record("Expected .statusUpdate response") + } + + // Parser should have tracked event IDs + #expect(parser.lastEventId == "2") + #expect(parser.serverRetryInterval == 3.0) + } + + @Test func sseChunksEndWithDoubleNewline() async throws { + let router = A2ARouter(handler: SSETestHandler()) + let request = JSONRPCRequest( + id: .int(1), + method: .sendStreamingMessage, + params: SendMessageRequest( + message: Message(role: .user, parts: [.text("Hello")]) + ) + ) + let body = try encoder.encode(request) + let result = try await router.route(body: body) + + guard case .stream(let stream) = result else { + Issue.record("Expected .stream result") + return + } + + for try await data in stream { + let text = String(data: data, encoding: .utf8) ?? "" + // Each SSE event must end with \n\n per spec + #expect(text.hasSuffix("\n\n"), "SSE chunk must end with double newline") + } + } + + @Test func streamingSessionStructure() async throws { + // Test that StreamingSession correctly passes through events and state + let (eventStream, eventContinuation) = AsyncThrowingStream.makeStream() + let (stateStream, stateContinuation) = AsyncStream.makeStream() + + let session = StreamingSession(events: eventStream, connectionState: stateStream) + + // Send some state and finish + stateContinuation.yield(.connected) + stateContinuation.yield(.reconnecting(attempt: 1, maxAttempts: 3)) + stateContinuation.yield(.connected) + stateContinuation.finish() + + eventContinuation.finish() + + // Consume events (should be empty) + for try await _ in session.events {} + + // Consume connection states + var states: [String] = [] + for await state in session.connectionState { + switch state { + case .connected: + states.append("connected") + case .reconnecting(let attempt, let max): + states.append("reconnecting(\(attempt)/\(max))") + case .disconnected: + states.append("disconnected") + } + } + #expect(states == ["connected", "reconnecting(1/3)", "connected"]) + } +}