diff --git a/providers/openai/openai.go b/providers/openai/openai.go index 7ca74b9c7..7690dd983 100644 --- a/providers/openai/openai.go +++ b/providers/openai/openai.go @@ -29,6 +29,7 @@ type options struct { project string name string useResponsesAPI bool + useWebSocket bool headers map[string]string client option.HTTPClient sdkOptions []option.RequestOption @@ -132,6 +133,15 @@ func WithUseResponsesAPI() Option { } } +// WithWebSocket enables WebSocket mode for the Responses API, providing lower-latency +// persistent connections for tool-call-heavy workflows. Requires WithUseResponsesAPI(). +// Falls back to HTTP if the WebSocket connection cannot be established. +func WithWebSocket() Option { + return func(o *options) { + o.useWebSocket = true + } +} + // WithObjectMode sets the object generation mode. func WithObjectMode(om fantasy.ObjectMode) Option { return func(o *options) { @@ -173,7 +183,11 @@ func (o *provider) LanguageModel(_ context.Context, modelID string) (fantasy.Lan if objectMode == fantasy.ObjectModeJSON { objectMode = fantasy.ObjectModeAuto } - return newResponsesLanguageModel(modelID, o.options.name, client, objectMode), nil + var ws *wsTransport + if o.options.useWebSocket { + ws = newWSTransport(o.options.baseURL, o.options.apiKey, o.options.headers) + } + return newResponsesLanguageModel(modelID, o.options.name, client, objectMode, ws), nil } o.options.languageModelOptions = append(o.options.languageModelOptions, WithLanguageModelObjectMode(o.options.objectMode)) diff --git a/providers/openai/responses_language_model.go b/providers/openai/responses_language_model.go index 39ee8e427..49ef22ac9 100644 --- a/providers/openai/responses_language_model.go +++ b/providers/openai/responses_language_model.go @@ -21,20 +21,22 @@ import ( const topLogprobsMax = 20 type responsesLanguageModel struct { - provider string - modelID string - client openai.Client - objectMode fantasy.ObjectMode + provider string + modelID string + client openai.Client + objectMode fantasy.ObjectMode + wsTransport *wsTransport } // newResponsesLanguageModel implements a responses api model // INFO: (kujtim) currently we do not support stored parameter we default it to false. -func newResponsesLanguageModel(modelID string, provider string, client openai.Client, objectMode fantasy.ObjectMode) responsesLanguageModel { +func newResponsesLanguageModel(modelID string, provider string, client openai.Client, objectMode fantasy.ObjectMode, ws *wsTransport) responsesLanguageModel { return responsesLanguageModel{ - modelID: modelID, - provider: provider, - client: client, - objectMode: objectMode, + modelID: modelID, + provider: provider, + client: client, + objectMode: objectMode, + wsTransport: ws, } } @@ -228,6 +230,9 @@ func (o responsesLanguageModel) prepareParams(call fantasy.Call) (*responses.Res if openaiOptions.SafetyIdentifier != nil { params.SafetyIdentifier = param.NewOpt(*openaiOptions.SafetyIdentifier) } + if openaiOptions.PreviousResponseID != nil { + params.PreviousResponseID = param.NewOpt(*openaiOptions.PreviousResponseID) + } if topLogprobs > 0 { params.TopLogprobs = param.NewOpt(int64(topLogprobs)) } @@ -668,6 +673,15 @@ func toResponsesTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice, opti func (o responsesLanguageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { params, warnings := o.prepareParams(call) + + if o.wsTransport != nil { + resp, err := o.generateViaWebSocket(ctx, params, warnings, call) + if err == nil { + return resp, nil + } + // Fall back to HTTP on WebSocket failure + } + response, err := o.client.Responses.New(ctx, *params) if err != nil { return nil, toProviderErr(err) @@ -806,6 +820,14 @@ func mapResponsesFinishReason(reason string, hasFunctionCall bool) fantasy.Finis func (o responsesLanguageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { params, warnings := o.prepareParams(call) + if o.wsTransport != nil { + resp, err := o.streamViaWebSocket(ctx, params, warnings, call) + if err == nil { + return resp, nil + } + // Fall back to HTTP on WebSocket failure + } + stream := o.client.Responses.NewStreaming(ctx, *params) finishReason := fantasy.FinishReasonUnknown diff --git a/providers/openai/responses_options.go b/providers/openai/responses_options.go index 7e5b1a4d0..2799a0b89 100644 --- a/providers/openai/responses_options.go +++ b/providers/openai/responses_options.go @@ -97,20 +97,23 @@ const ( // ResponsesProviderOptions represents additional options for OpenAI Responses API. type ResponsesProviderOptions struct { - Include []IncludeType `json:"include"` - Instructions *string `json:"instructions"` - Logprobs any `json:"logprobs"` - MaxToolCalls *int64 `json:"max_tool_calls"` - Metadata map[string]any `json:"metadata"` - ParallelToolCalls *bool `json:"parallel_tool_calls"` - PromptCacheKey *string `json:"prompt_cache_key"` - ReasoningEffort *ReasoningEffort `json:"reasoning_effort"` - ReasoningSummary *string `json:"reasoning_summary"` - SafetyIdentifier *string `json:"safety_identifier"` - ServiceTier *ServiceTier `json:"service_tier"` - StrictJSONSchema *bool `json:"strict_json_schema"` - TextVerbosity *TextVerbosity `json:"text_verbosity"` - User *string `json:"user"` + Include []IncludeType `json:"include"` + Instructions *string `json:"instructions"` + Logprobs any `json:"logprobs"` + MaxToolCalls *int64 `json:"max_tool_calls"` + Metadata map[string]any `json:"metadata"` + ParallelToolCalls *bool `json:"parallel_tool_calls"` + PreviousResponseID *string `json:"previous_response_id"` + PromptCacheKey *string `json:"prompt_cache_key"` + ReasoningEffort *ReasoningEffort `json:"reasoning_effort"` + ReasoningSummary *string `json:"reasoning_summary"` + ResetChain *bool `json:"reset_chain"` + SafetyIdentifier *string `json:"safety_identifier"` + ServiceTier *ServiceTier `json:"service_tier"` + StrictJSONSchema *bool `json:"strict_json_schema"` + TextVerbosity *TextVerbosity `json:"text_verbosity"` + User *string `json:"user"` + GenerateWarmup *bool `json:"generate_warmup"` } // Options implements the ProviderOptions interface. diff --git a/providers/openai/responses_websocket.go b/providers/openai/responses_websocket.go new file mode 100644 index 000000000..8b5e2e854 --- /dev/null +++ b/providers/openai/responses_websocket.go @@ -0,0 +1,314 @@ +package openai + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "charm.land/fantasy" + "github.com/gorilla/websocket" +) + +// wsReconnectThreshold is how close to the 60-minute connection timeout we allow before reconnecting. +const wsReconnectThreshold = 55 * time.Minute + +// wsTransport manages a persistent WebSocket connection to the OpenAI Responses API. +type wsTransport struct { + mu sync.Mutex + conn *websocket.Conn + connectedAt time.Time + baseURL string + apiKey string + headers map[string]string + lastResponseID string + lastInputLen int // number of input items sent in the last successful request +} + +// newWSTransport creates a new WebSocket transport for the OpenAI Responses API. +func newWSTransport(baseURL, apiKey string, headers map[string]string) *wsTransport { + return &wsTransport{ + baseURL: baseURL, + apiKey: apiKey, + headers: headers, + } +} + +// wsURL converts the base URL to a WebSocket URL. +func (ws *wsTransport) wsURL() string { + url := ws.baseURL + url = strings.Replace(url, "https://", "wss://", 1) + url = strings.Replace(url, "http://", "ws://", 1) + url = strings.TrimSuffix(url, "/") + // Remove trailing /v1 if present since we add /v1/responses + url = strings.TrimSuffix(url, "/v1") + return url + "/v1/responses" +} + +// connect establishes a WebSocket connection. +func (ws *wsTransport) connect(ctx context.Context) error { + header := http.Header{} + header.Set("Authorization", "Bearer "+ws.apiKey) + for key, value := range ws.headers { + header.Set(key, value) + } + + dialer := websocket.Dialer{ + HandshakeTimeout: 30 * time.Second, + } + + conn, resp, err := dialer.DialContext(ctx, ws.wsURL(), header) + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + if err != nil { + return fmt.Errorf("websocket connect: %w", err) + } + + ws.conn = conn + ws.connectedAt = time.Now() + return nil +} + +// ensureConnected connects if not connected or reconnects if approaching the 60-minute limit. +func (ws *wsTransport) ensureConnected(ctx context.Context) error { + if ws.conn != nil && time.Since(ws.connectedAt) < wsReconnectThreshold { + return nil + } + + if ws.conn != nil { + _ = ws.conn.Close() + ws.conn = nil + } + + return ws.connect(ctx) +} + +// Close closes the WebSocket connection. +func (ws *wsTransport) Close() error { + ws.mu.Lock() + defer ws.mu.Unlock() + + if ws.conn != nil { + err := ws.conn.Close() + ws.conn = nil + return err + } + return nil +} + +// responseCreateEvent wraps the response params in a WebSocket event envelope. +type responseCreateEvent struct { + Type string `json:"type"` + Body json.RawMessage `json:"-"` +} + +// MarshalJSON implements custom marshaling to flatten the body into the event. +func (e responseCreateEvent) MarshalJSON() ([]byte, error) { + // Start with the body fields and add the type field + var bodyMap map[string]json.RawMessage + if err := json.Unmarshal(e.Body, &bodyMap); err != nil { + return nil, fmt.Errorf("unmarshal body: %w", err) + } + typeBytes, err := json.Marshal(e.Type) + if err != nil { + return nil, fmt.Errorf("marshal type: %w", err) + } + bodyMap["type"] = typeBytes + return json.Marshal(bodyMap) +} + +// wsServerEvent represents a server-sent event from the WebSocket connection. +type wsServerEvent struct { + Type string `json:"type"` + Raw json.RawMessage `json:"-"` +} + +// sendResponseCreate sends a response.create event and returns a channel of raw server events. +// The caller must hold ws.mu. +func (ws *wsTransport) sendResponseCreate(ctx context.Context, body json.RawMessage) (chan wsServerEvent, error) { + if err := ws.ensureConnected(ctx); err != nil { + return nil, err + } + + event := responseCreateEvent{ + Type: "response.create", + Body: body, + } + + data, err := event.MarshalJSON() + if err != nil { + return nil, fmt.Errorf("marshal response.create: %w", err) + } + + if err := ws.conn.WriteMessage(websocket.TextMessage, data); err != nil { + return nil, fmt.Errorf("websocket write: %w", err) + } + + events := make(chan wsServerEvent, 64) + + conn := ws.conn + go func() { + defer close(events) + + // Set a read deadline when the context is cancelled to unblock ReadMessage. + done := make(chan struct{}) + defer close(done) + go func() { + select { + case <-ctx.Done(): + _ = conn.SetReadDeadline(time.Now()) + case <-done: + } + }() + + for { + _, message, err := conn.ReadMessage() + if err != nil { + // Don't emit an error event if the context was cancelled. + if ctx.Err() != nil { + return + } + events <- wsServerEvent{ + Type: "error", + Raw: mustMarshal(map[string]string{"type": "error", "code": "websocket_read_error", "message": err.Error()}), + } + return + } + + var evt wsServerEvent + if err := json.Unmarshal(message, &evt); err != nil { + continue + } + evt.Raw = message + + events <- evt + + // Terminal events + if evt.Type == "response.completed" || evt.Type == "response.incomplete" || evt.Type == "response.failed" { + return + } + if evt.Type == "error" { + return + } + } + }() + + return events, nil +} + +func mustMarshal(v any) json.RawMessage { + data, _ := json.Marshal(v) + return data +} + +// extractIncrementalInput returns only the new input items that the server hasn't +// seen yet. When chaining with previous_response_id, the server already has the +// prior context, so we only send function_call_output items and new user messages. +// Items of type "function_call" are filtered out because the server generated those +// as part of its own response output. +func (ws *wsTransport) extractIncrementalInput(fullInput json.RawMessage) (json.RawMessage, int) { + var items []json.RawMessage + if err := json.Unmarshal(fullInput, &items); err != nil { + return fullInput, 0 + } + + fullLen := len(items) + + if ws.lastResponseID == "" || ws.lastInputLen == 0 { + return fullInput, fullLen + } + + if len(items) <= ws.lastInputLen { + return fullInput, fullLen + } + + // Take only items appended since the last request. + newItems := items[ws.lastInputLen:] + + // Filter out function_call items — the server already has these from its + // own response output; sending them again would be redundant. + var incremental []json.RawMessage + for _, item := range newItems { + var parsed map[string]json.RawMessage + if err := json.Unmarshal(item, &parsed); err != nil { + incremental = append(incremental, item) + continue + } + if typeField, ok := parsed["type"]; ok { + var itemType string + if err := json.Unmarshal(typeField, &itemType); err == nil && itemType == "function_call" { + continue + } + } + incremental = append(incremental, item) + } + + result, err := json.Marshal(incremental) + if err != nil { + return fullInput, fullLen + } + return result, fullLen +} + +// applyWSOptions modifies the marshaled params JSON to add WebSocket-specific fields +// like previous_response_id (from transport state) and generate (for warmup). +// It returns the modified body and the full input item count (before any trimming) +// so callers can update lastInputLen after a successful response. +func (ws *wsTransport) applyWSOptions(body json.RawMessage, call fantasy.Call) (json.RawMessage, int) { + var bodyMap map[string]json.RawMessage + if err := json.Unmarshal(body, &bodyMap); err != nil { + return body, 0 + } + + // Handle ResetChain from provider options - must be checked before auto-chaining. + var openaiOptions *ResponsesProviderOptions + if opts, ok := call.ProviderOptions[Name]; ok { + if typedOpts, ok := opts.(*ResponsesProviderOptions); ok { + openaiOptions = typedOpts + } + } + if openaiOptions != nil && openaiOptions.ResetChain != nil && *openaiOptions.ResetChain { + ws.lastResponseID = "" + ws.lastInputLen = 0 + } + + var fullInputLen int + + // Auto-chain with previous_response_id from transport state if not explicitly set + usingPrevID := false + if _, hasPrevID := bodyMap["previous_response_id"]; !hasPrevID && ws.lastResponseID != "" { + prevIDBytes, _ := json.Marshal(ws.lastResponseID) + bodyMap["previous_response_id"] = prevIDBytes + usingPrevID = true + } else if _, hasPrevID := bodyMap["previous_response_id"]; hasPrevID { + usingPrevID = true + } + + // When chaining, send only incremental input items. + if inputField, hasInput := bodyMap["input"]; hasInput { + if usingPrevID && ws.lastInputLen > 0 { + bodyMap["input"], fullInputLen = ws.extractIncrementalInput(inputField) + } else { + // Count full input items for tracking. + var items []json.RawMessage + if err := json.Unmarshal(inputField, &items); err == nil { + fullInputLen = len(items) + } + } + } + + // Handle GenerateWarmup from provider options. + if openaiOptions != nil && openaiOptions.GenerateWarmup != nil && *openaiOptions.GenerateWarmup { + bodyMap["generate"] = json.RawMessage("false") + } + + result, err := json.Marshal(bodyMap) + if err != nil { + return body, fullInputLen + } + return result, fullInputLen +} diff --git a/providers/openai/responses_websocket_model.go b/providers/openai/responses_websocket_model.go new file mode 100644 index 000000000..1d33ea676 --- /dev/null +++ b/providers/openai/responses_websocket_model.go @@ -0,0 +1,410 @@ +package openai + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "charm.land/fantasy" + "github.com/google/uuid" + "github.com/openai/openai-go/v2/responses" +) + +// generateViaWebSocket sends a response.create event over WebSocket and collects +// the full response from streaming events. +func (o responsesLanguageModel) generateViaWebSocket(ctx context.Context, params *responses.ResponseNewParams, warnings []fantasy.CallWarning, call fantasy.Call) (*fantasy.Response, error) { + o.wsTransport.mu.Lock() + defer o.wsTransport.mu.Unlock() + + body, err := json.Marshal(params) + if err != nil { + return nil, fmt.Errorf("marshal params: %w", err) + } + + var fullInputLen int + body, fullInputLen = o.wsTransport.applyWSOptions(body, call) + + events, err := o.wsTransport.sendResponseCreate(ctx, body) + if err != nil { + return nil, err + } + + var content []fantasy.Content + hasFunctionCall := false + var usage fantasy.Usage + var responseErr error + + for evt := range events { + var streamEvent responses.ResponseStreamEventUnion + if err := json.Unmarshal(evt.Raw, &streamEvent); err != nil { + continue + } + + switch evt.Type { + case "response.completed", "response.incomplete": + completed := streamEvent.AsResponseCompleted() + o.wsTransport.lastResponseID = completed.Response.ID + o.wsTransport.lastInputLen = fullInputLen + + // Build content from the completed response output + content = nil // Reset — use the final response + for _, outputItem := range completed.Response.Output { + switch outputItem.Type { + case "message": + for _, contentPart := range outputItem.Content { + if contentPart.Type == "output_text" { + content = append(content, fantasy.TextContent{ + Text: contentPart.Text, + }) + for _, annotation := range contentPart.Annotations { + switch annotation.Type { + case "url_citation": + content = append(content, fantasy.SourceContent{ + SourceType: fantasy.SourceTypeURL, + ID: uuid.NewString(), + URL: annotation.URL, + Title: annotation.Title, + }) + case "file_citation": + title := "Document" + if annotation.Filename != "" { + title = annotation.Filename + } + filename := annotation.Filename + if filename == "" { + filename = annotation.FileID + } + content = append(content, fantasy.SourceContent{ + SourceType: fantasy.SourceTypeDocument, + ID: uuid.NewString(), + MediaType: "text/plain", + Title: title, + Filename: filename, + }) + } + } + } + } + case "function_call": + hasFunctionCall = true + content = append(content, fantasy.ToolCallContent{ + ProviderExecuted: false, + ToolCallID: outputItem.CallID, + ToolName: outputItem.Name, + Input: outputItem.Arguments, + }) + case "reasoning": + metadata := &ResponsesReasoningMetadata{ + ItemID: outputItem.ID, + } + if outputItem.EncryptedContent != "" { + metadata.EncryptedContent = &outputItem.EncryptedContent + } + if len(outputItem.Summary) == 0 && metadata.EncryptedContent == nil { + continue + } + summaries := outputItem.Summary + if len(summaries) == 0 { + summaries = []responses.ResponseReasoningItemSummary{{Type: "summary_text", Text: ""}} + } + for _, s := range summaries { + metadata.Summary = append(metadata.Summary, s.Text) + } + content = append(content, fantasy.ReasoningContent{ + Text: strings.Join(metadata.Summary, "\n"), + ProviderMetadata: fantasy.ProviderMetadata{ + Name: metadata, + }, + }) + } + } + + usage = fantasy.Usage{ + InputTokens: completed.Response.Usage.InputTokens, + OutputTokens: completed.Response.Usage.OutputTokens, + TotalTokens: completed.Response.Usage.InputTokens + completed.Response.Usage.OutputTokens, + } + if completed.Response.Usage.OutputTokensDetails.ReasoningTokens != 0 { + usage.ReasoningTokens = completed.Response.Usage.OutputTokensDetails.ReasoningTokens + } + if completed.Response.Usage.InputTokensDetails.CachedTokens != 0 { + usage.CacheReadTokens = completed.Response.Usage.InputTokensDetails.CachedTokens + } + + case "response.failed": + completed := streamEvent.AsResponseCompleted() + responseErr = fmt.Errorf("response failed: %s (code: %s)", + completed.Response.Error.Message, completed.Response.Error.Code) + + case "error": + errorEvent := streamEvent.AsError() + if errorEvent.Code == "previous_response_not_found" { + o.wsTransport.lastResponseID = "" + o.wsTransport.lastInputLen = 0 + return nil, fmt.Errorf("previous_response_not_found") + } + responseErr = fmt.Errorf("%s (code: %s)", errorEvent.Message, errorEvent.Code) + } + } + + if responseErr != nil { + return nil, responseErr + } + + finishReason := mapResponsesFinishReason("", hasFunctionCall) + + return &fantasy.Response{ + Content: content, + Usage: usage, + FinishReason: finishReason, + ProviderMetadata: fantasy.ProviderMetadata{}, + Warnings: warnings, + }, nil +} + +// streamViaWebSocket sends a response.create event over WebSocket and yields +// StreamParts from the server events. +func (o responsesLanguageModel) streamViaWebSocket(ctx context.Context, params *responses.ResponseNewParams, warnings []fantasy.CallWarning, call fantasy.Call) (fantasy.StreamResponse, error) { + o.wsTransport.mu.Lock() + + body, err := json.Marshal(params) + if err != nil { + o.wsTransport.mu.Unlock() + return nil, fmt.Errorf("marshal params: %w", err) + } + + var fullInputLen int + body, fullInputLen = o.wsTransport.applyWSOptions(body, call) + + events, err := o.wsTransport.sendResponseCreate(ctx, body) + if err != nil { + o.wsTransport.mu.Unlock() + return nil, err + } + + return func(yield func(fantasy.StreamPart) bool) { + defer o.wsTransport.mu.Unlock() + + if len(warnings) > 0 { + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeWarnings, + Warnings: warnings, + }) { + return + } + } + + finishReason := fantasy.FinishReasonUnknown + var usage fantasy.Usage + ongoingToolCalls := make(map[int64]*ongoingToolCall) + hasFunctionCall := false + activeReasoning := make(map[string]*reasoningState) + + for evt := range events { + var event responses.ResponseStreamEventUnion + if err := json.Unmarshal(evt.Raw, &event); err != nil { + continue + } + + switch evt.Type { + case "response.created": + _ = event.AsResponseCreated() + + case "response.output_item.added": + added := event.AsResponseOutputItemAdded() + switch added.Item.Type { + case "function_call": + ongoingToolCalls[added.OutputIndex] = &ongoingToolCall{ + toolName: added.Item.Name, + toolCallID: added.Item.CallID, + } + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputStart, + ID: added.Item.CallID, + ToolCallName: added.Item.Name, + }) { + return + } + case "message": + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextStart, + ID: added.Item.ID, + }) { + return + } + case "reasoning": + metadata := &ResponsesReasoningMetadata{ + ItemID: added.Item.ID, + Summary: []string{}, + } + if added.Item.EncryptedContent != "" { + metadata.EncryptedContent = &added.Item.EncryptedContent + } + activeReasoning[added.Item.ID] = &reasoningState{metadata: metadata} + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningStart, + ID: added.Item.ID, + ProviderMetadata: fantasy.ProviderMetadata{ + Name: metadata, + }, + }) { + return + } + } + + case "response.output_item.done": + done := event.AsResponseOutputItemDone() + switch done.Item.Type { + case "function_call": + tc := ongoingToolCalls[done.OutputIndex] + if tc != nil { + delete(ongoingToolCalls, done.OutputIndex) + hasFunctionCall = true + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputEnd, + ID: done.Item.CallID, + }) { + return + } + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolCall, + ID: done.Item.CallID, + ToolCallName: done.Item.Name, + ToolCallInput: done.Item.Arguments, + }) { + return + } + } + case "message": + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextEnd, + ID: done.Item.ID, + }) { + return + } + case "reasoning": + state := activeReasoning[done.Item.ID] + if state != nil { + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningEnd, + ID: done.Item.ID, + ProviderMetadata: fantasy.ProviderMetadata{ + Name: state.metadata, + }, + }) { + return + } + delete(activeReasoning, done.Item.ID) + } + } + + case "response.function_call_arguments.delta": + delta := event.AsResponseFunctionCallArgumentsDelta() + tc := ongoingToolCalls[delta.OutputIndex] + if tc != nil { + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputDelta, + ID: tc.toolCallID, + Delta: delta.Delta, + }) { + return + } + } + + case "response.output_text.delta": + textDelta := event.AsResponseOutputTextDelta() + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextDelta, + ID: textDelta.ItemID, + Delta: textDelta.Delta, + }) { + return + } + + case "response.reasoning_summary_part.added": + added := event.AsResponseReasoningSummaryPartAdded() + state := activeReasoning[added.ItemID] + if state != nil { + state.metadata.Summary = append(state.metadata.Summary, "") + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningDelta, + ID: added.ItemID, + Delta: "\n", + ProviderMetadata: fantasy.ProviderMetadata{ + Name: state.metadata, + }, + }) { + return + } + } + + case "response.reasoning_summary_text.delta": + textDelta := event.AsResponseReasoningSummaryTextDelta() + state := activeReasoning[textDelta.ItemID] + if state != nil { + if len(state.metadata.Summary)-1 >= int(textDelta.SummaryIndex) { + state.metadata.Summary[textDelta.SummaryIndex] += textDelta.Delta + } + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningDelta, + ID: textDelta.ItemID, + Delta: textDelta.Delta, + ProviderMetadata: fantasy.ProviderMetadata{ + Name: state.metadata, + }, + }) { + return + } + } + + case "response.completed", "response.incomplete": + completed := event.AsResponseCompleted() + o.wsTransport.lastResponseID = completed.Response.ID + o.wsTransport.lastInputLen = fullInputLen + finishReason = mapResponsesFinishReason(completed.Response.IncompleteDetails.Reason, hasFunctionCall) + usage = fantasy.Usage{ + InputTokens: completed.Response.Usage.InputTokens, + OutputTokens: completed.Response.Usage.OutputTokens, + TotalTokens: completed.Response.Usage.InputTokens + completed.Response.Usage.OutputTokens, + } + if completed.Response.Usage.OutputTokensDetails.ReasoningTokens != 0 { + usage.ReasoningTokens = completed.Response.Usage.OutputTokensDetails.ReasoningTokens + } + if completed.Response.Usage.InputTokensDetails.CachedTokens != 0 { + usage.CacheReadTokens = completed.Response.Usage.InputTokensDetails.CachedTokens + } + + case "response.failed": + completed := event.AsResponseCompleted() + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, + Error: fmt.Errorf("response failed: %s (code: %s)", completed.Response.Error.Message, completed.Response.Error.Code), + }) { + return + } + return + + case "error": + errorEvent := event.AsError() + if errorEvent.Code == "previous_response_not_found" { + o.wsTransport.lastResponseID = "" + o.wsTransport.lastInputLen = 0 + } + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, + Error: fmt.Errorf("response error: %s (code: %s)", errorEvent.Message, errorEvent.Code), + }) { + return + } + return + } + } + + yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeFinish, + Usage: usage, + FinishReason: finishReason, + }) + }, nil +} diff --git a/providers/openai/responses_websocket_test.go b/providers/openai/responses_websocket_test.go new file mode 100644 index 000000000..7851e4f94 --- /dev/null +++ b/providers/openai/responses_websocket_test.go @@ -0,0 +1,406 @@ +package openai + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "charm.land/fantasy" + "github.com/gorilla/websocket" +) + +// mockWSServer creates a test WebSocket server that sends predefined events. +func mockWSServer(t *testing.T, handler func(conn *websocket.Conn)) *httptest.Server { + t.Helper() + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + } + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Fatalf("upgrade: %v", err) + return + } + defer conn.Close() + handler(conn) + })) + return server +} + +func wsURLFromHTTP(httpURL string) string { + return strings.Replace(httpURL, "http://", "ws://", 1) +} + +func TestWSTransport_Connect(t *testing.T) { + server := mockWSServer(t, func(conn *websocket.Conn) { + // Just accept and hold connection + for { + _, _, err := conn.ReadMessage() + if err != nil { + return + } + } + }) + defer server.Close() + + ws := newWSTransport(wsURLFromHTTP(server.URL), "test-key", nil) + // Override wsURL to point at test server + ws.baseURL = wsURLFromHTTP(server.URL) + + err := ws.connect(context.Background()) + if err != nil { + t.Fatalf("connect: %v", err) + } + defer ws.Close() + + if ws.conn == nil { + t.Fatal("expected conn to be set") + } + if ws.connectedAt.IsZero() { + t.Fatal("expected connectedAt to be set") + } +} + +func TestWSTransport_EnsureConnected_Reconnect(t *testing.T) { + var connectCount int + var mu sync.Mutex + + server := mockWSServer(t, func(conn *websocket.Conn) { + mu.Lock() + connectCount++ + mu.Unlock() + for { + _, _, err := conn.ReadMessage() + if err != nil { + return + } + } + }) + defer server.Close() + + ws := newWSTransport(wsURLFromHTTP(server.URL), "test-key", nil) + ws.baseURL = wsURLFromHTTP(server.URL) + + // First connect + err := ws.ensureConnected(context.Background()) + if err != nil { + t.Fatalf("first connect: %v", err) + } + + // Should not reconnect (within threshold) + err = ws.ensureConnected(context.Background()) + if err != nil { + t.Fatalf("second connect: %v", err) + } + + mu.Lock() + if connectCount != 1 { + t.Fatalf("expected 1 connection, got %d", connectCount) + } + mu.Unlock() + + // Simulate expired connection by backdating connectedAt + ws.connectedAt = time.Now().Add(-56 * time.Minute) + + err = ws.ensureConnected(context.Background()) + if err != nil { + t.Fatalf("reconnect: %v", err) + } + defer ws.Close() + + // Wait a moment for the server to register the new connection + time.Sleep(50 * time.Millisecond) + + mu.Lock() + if connectCount != 2 { + t.Fatalf("expected 2 connections after reconnect, got %d", connectCount) + } + mu.Unlock() +} + +func TestWSTransport_SendResponseCreate(t *testing.T) { + server := mockWSServer(t, func(conn *websocket.Conn) { + // Read the response.create message + _, message, err := conn.ReadMessage() + if err != nil { + t.Errorf("read: %v", err) + return + } + + var evt map[string]json.RawMessage + if err := json.Unmarshal(message, &evt); err != nil { + t.Errorf("unmarshal: %v", err) + return + } + + var eventType string + json.Unmarshal(evt["type"], &eventType) + if eventType != "response.create" { + t.Errorf("expected type response.create, got %s", eventType) + return + } + + // Send response events + events := []string{ + `{"type":"response.created","response":{"id":"resp_123","status":"in_progress"}}`, + `{"type":"response.output_item.added","output_index":0,"item":{"id":"item_1","type":"message","role":"assistant","content":[]}}`, + `{"type":"response.output_text.delta","output_index":0,"content_index":0,"item_id":"item_1","delta":"Hello"}`, + `{"type":"response.output_text.delta","output_index":0,"content_index":0,"item_id":"item_1","delta":" world"}`, + `{"type":"response.completed","response":{"id":"resp_123","status":"completed","output":[{"id":"item_1","type":"message","role":"assistant","content":[{"type":"output_text","text":"Hello world"}]}],"usage":{"input_tokens":10,"output_tokens":5}}}`, + } + for _, event := range events { + if err := conn.WriteMessage(websocket.TextMessage, []byte(event)); err != nil { + return + } + } + }) + defer server.Close() + + ws := newWSTransport(wsURLFromHTTP(server.URL), "test-key", nil) + ws.baseURL = wsURLFromHTTP(server.URL) + + ws.mu.Lock() + defer ws.mu.Unlock() + + body := json.RawMessage(`{"model":"gpt-4o","input":[]}`) + events, err := ws.sendResponseCreate(context.Background(), body) + if err != nil { + t.Fatalf("sendResponseCreate: %v", err) + } + + var eventTypes []string + for evt := range events { + eventTypes = append(eventTypes, evt.Type) + } + + expected := []string{ + "response.created", + "response.output_item.added", + "response.output_text.delta", + "response.output_text.delta", + "response.completed", + } + + if len(eventTypes) != len(expected) { + t.Fatalf("expected %d events, got %d: %v", len(expected), len(eventTypes), eventTypes) + } + + for i, eventType := range eventTypes { + if eventType != expected[i] { + t.Errorf("event %d: expected %s, got %s", i, expected[i], eventType) + } + } +} + +func TestWSTransport_PreviousResponseIDChaining(t *testing.T) { + ws := newWSTransport("wss://api.openai.com/v1", "test-key", nil) + ws.lastResponseID = "resp_prev_123" + + call := fantasy.Call{ + ProviderOptions: fantasy.ProviderOptions{}, + } + + body := json.RawMessage(`{"model":"gpt-4o","input":[]}`) + result, _ := ws.applyWSOptions(body, call) + + var resultMap map[string]json.RawMessage + if err := json.Unmarshal(result, &resultMap); err != nil { + t.Fatalf("unmarshal result: %v", err) + } + + prevIDRaw, ok := resultMap["previous_response_id"] + if !ok { + t.Fatal("expected previous_response_id in result") + } + + var prevID string + json.Unmarshal(prevIDRaw, &prevID) + if prevID != "resp_prev_123" { + t.Errorf("expected resp_prev_123, got %s", prevID) + } +} + +func TestWSTransport_GenerateWarmup(t *testing.T) { + ws := newWSTransport("wss://api.openai.com/v1", "test-key", nil) + + warmup := true + call := fantasy.Call{ + ProviderOptions: fantasy.ProviderOptions{ + Name: &ResponsesProviderOptions{ + GenerateWarmup: &warmup, + }, + }, + } + + body := json.RawMessage(`{"model":"gpt-4o","input":[]}`) + result, _ := ws.applyWSOptions(body, call) + + var resultMap map[string]json.RawMessage + if err := json.Unmarshal(result, &resultMap); err != nil { + t.Fatalf("unmarshal result: %v", err) + } + + genRaw, ok := resultMap["generate"] + if !ok { + t.Fatal("expected generate in result") + } + + if string(genRaw) != "false" { + t.Errorf("expected generate=false, got %s", string(genRaw)) + } +} + +func TestWSTransport_ExplicitPreviousResponseIDOverridesAuto(t *testing.T) { + ws := newWSTransport("wss://api.openai.com/v1", "test-key", nil) + ws.lastResponseID = "resp_auto_123" + + // Set explicit previous_response_id in the body + body := json.RawMessage(`{"model":"gpt-4o","input":[],"previous_response_id":"resp_explicit_456"}`) + call := fantasy.Call{} + result, _ := ws.applyWSOptions(body, call) + + var resultMap map[string]json.RawMessage + if err := json.Unmarshal(result, &resultMap); err != nil { + t.Fatalf("unmarshal result: %v", err) + } + + var prevID string + json.Unmarshal(resultMap["previous_response_id"], &prevID) + if prevID != "resp_explicit_456" { + t.Errorf("expected explicit ID resp_explicit_456, got %s", prevID) + } +} + +func TestWSTransport_FallbackToHTTP(t *testing.T) { + // Create a provider with WebSocket enabled but no server running + provider, err := New( + WithAPIKey("test-key"), + WithBaseURL("https://localhost:1"), + WithUseResponsesAPI(), + WithWebSocket(), + ) + if err != nil { + t.Fatalf("new provider: %v", err) + } + + model, err := provider.LanguageModel(context.Background(), "gpt-4o") + if err != nil { + t.Fatalf("language model: %v", err) + } + + // The model should be a responsesLanguageModel with wsTransport set + rlm, ok := model.(responsesLanguageModel) + if !ok { + t.Fatal("expected responsesLanguageModel") + } + if rlm.wsTransport == nil { + t.Fatal("expected wsTransport to be set") + } +} + +func TestWSTransport_IncrementalInput(t *testing.T) { + ws := newWSTransport("wss://api.openai.com/v1", "test-key", nil) + ws.lastResponseID = "resp_001" + ws.lastInputLen = 3 + + // Simulate full input: 3 old items + 2 new (1 function_call + 1 function_call_output). + // The function_call should be filtered out since the server generated it. + fullInput := json.RawMessage(`[ + {"type":"message","role":"system","content":"sys"}, + {"type":"message","role":"user","content":"hello"}, + {"type":"message","role":"assistant","content":"hi"}, + {"type":"function_call","call_id":"call_1","name":"search","arguments":"{}"}, + {"type":"function_call_output","call_id":"call_1","output":"result"} + ]`) + + got, fullLen := ws.extractIncrementalInput(fullInput) + if fullLen != 5 { + t.Fatalf("expected fullLen=5, got %d", fullLen) + } + + var items []map[string]interface{} + if err := json.Unmarshal(got, &items); err != nil { + t.Fatalf("unmarshal result: %v", err) + } + + if len(items) != 1 { + t.Fatalf("expected 1 incremental item, got %d: %s", len(items), string(got)) + } + if items[0]["type"] != "function_call_output" { + t.Errorf("expected function_call_output, got %v", items[0]["type"]) + } +} + +func TestWSTransport_IncrementalInput_FirstCall(t *testing.T) { + ws := newWSTransport("wss://api.openai.com/v1", "test-key", nil) + // No lastResponseID — should return full input. + + fullInput := json.RawMessage(`[{"type":"message","role":"user","content":"hi"}]`) + got, fullLen := ws.extractIncrementalInput(fullInput) + if fullLen != 1 { + t.Fatalf("expected fullLen=1, got %d", fullLen) + } + if string(got) != string(fullInput) { + t.Errorf("expected full input returned on first call") + } +} + +func TestWSTransport_IncrementalInput_AppliedViaApplyWSOptions(t *testing.T) { + ws := newWSTransport("wss://api.openai.com/v1", "test-key", nil) + ws.lastResponseID = "resp_001" + ws.lastInputLen = 2 + + body := json.RawMessage(`{"model":"gpt-4o","input":[ + {"type":"message","role":"user","content":"hello"}, + {"type":"message","role":"assistant","content":"hi"}, + {"type":"function_call_output","call_id":"call_1","output":"42"} + ]}`) + + result, fullLen := ws.applyWSOptions(body, fantasy.Call{}) + if fullLen != 3 { + t.Fatalf("expected fullLen=3, got %d", fullLen) + } + + var resultMap map[string]json.RawMessage + if err := json.Unmarshal(result, &resultMap); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + var items []map[string]interface{} + if err := json.Unmarshal(resultMap["input"], &items); err != nil { + t.Fatalf("unmarshal input: %v", err) + } + + if len(items) != 1 { + t.Fatalf("expected 1 incremental input item, got %d", len(items)) + } + if items[0]["type"] != "function_call_output" { + t.Errorf("expected function_call_output, got %v", items[0]["type"]) + } +} + +func TestWSURL(t *testing.T) { + tests := []struct { + baseURL string + expected string + }{ + {"https://api.openai.com/v1", "wss://api.openai.com/v1/responses"}, + {"https://custom.api.com/v1", "wss://custom.api.com/v1/responses"}, + {"http://localhost:8080/v1", "ws://localhost:8080/v1/responses"}, + {"https://api.openai.com/v1/", "wss://api.openai.com/v1/responses"}, + } + + for _, tt := range tests { + ws := newWSTransport(tt.baseURL, "key", nil) + ws.baseURL = tt.baseURL + got := ws.wsURL() + if got != tt.expected { + t.Errorf("wsURL(%s) = %s, want %s", tt.baseURL, got, tt.expected) + } + } +}