From 514e049400b40be5c395f2b661f4c3a82b88c75d Mon Sep 17 00:00:00 2001 From: Tai Groot Date: Tue, 24 Feb 2026 18:12:49 +0000 Subject: [PATCH 1/6] feat(openai): add WebSocket mode for Responses API Add WebSocket transport support for the OpenAI Responses API, enabling lower-latency persistent connections for tool-call-heavy workflows. Key features: - wsTransport manages WebSocket connection lifecycle with automatic reconnection before the 60-minute connection limit - previous_response_id auto-chaining for incremental continuation - generate:false warmup support via GenerateWarmup provider option - Falls back to HTTP transparently on WebSocket connection failure - One in-flight response at a time per connection (mutex-protected) New provider options: - WithWebSocket() enables WebSocket mode (requires WithUseResponsesAPI) - PreviousResponseID on ResponsesProviderOptions for explicit chaining - GenerateWarmup on ResponsesProviderOptions for prefill/warmup The WebSocket events use the same JSON structure as HTTP SSE events, so both Generate() and Stream() reuse existing event parsing logic. No changes to the LanguageModel interface or consumer-facing API. --- providers/openai/openai.go | 16 +- providers/openai/responses_language_model.go | 40 +- providers/openai/responses_options.go | 30 +- providers/openai/responses_websocket.go | 223 ++++++++++ providers/openai/responses_websocket_model.go | 392 ++++++++++++++++++ providers/openai/responses_websocket_test.go | 325 +++++++++++++++ 6 files changed, 1002 insertions(+), 24 deletions(-) create mode 100644 providers/openai/responses_websocket.go create mode 100644 providers/openai/responses_websocket_model.go create mode 100644 providers/openai/responses_websocket_test.go diff --git a/providers/openai/openai.go b/providers/openai/openai.go index 7ca74b9c7..7690dd983 100644 --- a/providers/openai/openai.go +++ b/providers/openai/openai.go @@ -29,6 +29,7 @@ type options struct { project string name string useResponsesAPI bool + useWebSocket bool headers map[string]string client option.HTTPClient sdkOptions []option.RequestOption @@ -132,6 +133,15 @@ func WithUseResponsesAPI() Option { } } +// WithWebSocket enables WebSocket mode for the Responses API, providing lower-latency +// persistent connections for tool-call-heavy workflows. Requires WithUseResponsesAPI(). +// Falls back to HTTP if the WebSocket connection cannot be established. +func WithWebSocket() Option { + return func(o *options) { + o.useWebSocket = true + } +} + // WithObjectMode sets the object generation mode. func WithObjectMode(om fantasy.ObjectMode) Option { return func(o *options) { @@ -173,7 +183,11 @@ func (o *provider) LanguageModel(_ context.Context, modelID string) (fantasy.Lan if objectMode == fantasy.ObjectModeJSON { objectMode = fantasy.ObjectModeAuto } - return newResponsesLanguageModel(modelID, o.options.name, client, objectMode), nil + var ws *wsTransport + if o.options.useWebSocket { + ws = newWSTransport(o.options.baseURL, o.options.apiKey, o.options.headers) + } + return newResponsesLanguageModel(modelID, o.options.name, client, objectMode, ws), nil } o.options.languageModelOptions = append(o.options.languageModelOptions, WithLanguageModelObjectMode(o.options.objectMode)) diff --git a/providers/openai/responses_language_model.go b/providers/openai/responses_language_model.go index 39ee8e427..49ef22ac9 100644 --- a/providers/openai/responses_language_model.go +++ b/providers/openai/responses_language_model.go @@ -21,20 +21,22 @@ import ( const topLogprobsMax = 20 type responsesLanguageModel struct { - provider string - modelID string - client openai.Client - objectMode fantasy.ObjectMode + provider string + modelID string + client openai.Client + objectMode fantasy.ObjectMode + wsTransport *wsTransport } // newResponsesLanguageModel implements a responses api model // INFO: (kujtim) currently we do not support stored parameter we default it to false. -func newResponsesLanguageModel(modelID string, provider string, client openai.Client, objectMode fantasy.ObjectMode) responsesLanguageModel { +func newResponsesLanguageModel(modelID string, provider string, client openai.Client, objectMode fantasy.ObjectMode, ws *wsTransport) responsesLanguageModel { return responsesLanguageModel{ - modelID: modelID, - provider: provider, - client: client, - objectMode: objectMode, + modelID: modelID, + provider: provider, + client: client, + objectMode: objectMode, + wsTransport: ws, } } @@ -228,6 +230,9 @@ func (o responsesLanguageModel) prepareParams(call fantasy.Call) (*responses.Res if openaiOptions.SafetyIdentifier != nil { params.SafetyIdentifier = param.NewOpt(*openaiOptions.SafetyIdentifier) } + if openaiOptions.PreviousResponseID != nil { + params.PreviousResponseID = param.NewOpt(*openaiOptions.PreviousResponseID) + } if topLogprobs > 0 { params.TopLogprobs = param.NewOpt(int64(topLogprobs)) } @@ -668,6 +673,15 @@ func toResponsesTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice, opti func (o responsesLanguageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { params, warnings := o.prepareParams(call) + + if o.wsTransport != nil { + resp, err := o.generateViaWebSocket(ctx, params, warnings, call) + if err == nil { + return resp, nil + } + // Fall back to HTTP on WebSocket failure + } + response, err := o.client.Responses.New(ctx, *params) if err != nil { return nil, toProviderErr(err) @@ -806,6 +820,14 @@ func mapResponsesFinishReason(reason string, hasFunctionCall bool) fantasy.Finis func (o responsesLanguageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { params, warnings := o.prepareParams(call) + if o.wsTransport != nil { + resp, err := o.streamViaWebSocket(ctx, params, warnings, call) + if err == nil { + return resp, nil + } + // Fall back to HTTP on WebSocket failure + } + stream := o.client.Responses.NewStreaming(ctx, *params) finishReason := fantasy.FinishReasonUnknown diff --git a/providers/openai/responses_options.go b/providers/openai/responses_options.go index 7e5b1a4d0..64f8abdcc 100644 --- a/providers/openai/responses_options.go +++ b/providers/openai/responses_options.go @@ -97,20 +97,22 @@ const ( // ResponsesProviderOptions represents additional options for OpenAI Responses API. type ResponsesProviderOptions struct { - Include []IncludeType `json:"include"` - Instructions *string `json:"instructions"` - Logprobs any `json:"logprobs"` - MaxToolCalls *int64 `json:"max_tool_calls"` - Metadata map[string]any `json:"metadata"` - ParallelToolCalls *bool `json:"parallel_tool_calls"` - PromptCacheKey *string `json:"prompt_cache_key"` - ReasoningEffort *ReasoningEffort `json:"reasoning_effort"` - ReasoningSummary *string `json:"reasoning_summary"` - SafetyIdentifier *string `json:"safety_identifier"` - ServiceTier *ServiceTier `json:"service_tier"` - StrictJSONSchema *bool `json:"strict_json_schema"` - TextVerbosity *TextVerbosity `json:"text_verbosity"` - User *string `json:"user"` + Include []IncludeType `json:"include"` + Instructions *string `json:"instructions"` + Logprobs any `json:"logprobs"` + MaxToolCalls *int64 `json:"max_tool_calls"` + Metadata map[string]any `json:"metadata"` + ParallelToolCalls *bool `json:"parallel_tool_calls"` + PreviousResponseID *string `json:"previous_response_id"` + PromptCacheKey *string `json:"prompt_cache_key"` + ReasoningEffort *ReasoningEffort `json:"reasoning_effort"` + ReasoningSummary *string `json:"reasoning_summary"` + SafetyIdentifier *string `json:"safety_identifier"` + ServiceTier *ServiceTier `json:"service_tier"` + StrictJSONSchema *bool `json:"strict_json_schema"` + TextVerbosity *TextVerbosity `json:"text_verbosity"` + User *string `json:"user"` + GenerateWarmup *bool `json:"generate_warmup"` } // Options implements the ProviderOptions interface. diff --git a/providers/openai/responses_websocket.go b/providers/openai/responses_websocket.go new file mode 100644 index 000000000..ff218a443 --- /dev/null +++ b/providers/openai/responses_websocket.go @@ -0,0 +1,223 @@ +package openai + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "charm.land/fantasy" + "github.com/gorilla/websocket" +) + +// wsReconnectThreshold is how close to the 60-minute connection timeout we allow before reconnecting. +const wsReconnectThreshold = 55 * time.Minute + +// wsTransport manages a persistent WebSocket connection to the OpenAI Responses API. +type wsTransport struct { + mu sync.Mutex + conn *websocket.Conn + connectedAt time.Time + baseURL string + apiKey string + headers map[string]string + lastResponseID string +} + +// newWSTransport creates a new WebSocket transport for the OpenAI Responses API. +func newWSTransport(baseURL, apiKey string, headers map[string]string) *wsTransport { + return &wsTransport{ + baseURL: baseURL, + apiKey: apiKey, + headers: headers, + } +} + +// wsURL converts the base URL to a WebSocket URL. +func (ws *wsTransport) wsURL() string { + url := ws.baseURL + url = strings.Replace(url, "https://", "wss://", 1) + url = strings.Replace(url, "http://", "ws://", 1) + url = strings.TrimSuffix(url, "/") + // Remove trailing /v1 if present since we add /v1/responses + url = strings.TrimSuffix(url, "/v1") + return url + "/v1/responses" +} + +// connect establishes a WebSocket connection. +func (ws *wsTransport) connect(ctx context.Context) error { + header := http.Header{} + header.Set("Authorization", "Bearer "+ws.apiKey) + for key, value := range ws.headers { + header.Set(key, value) + } + + dialer := websocket.Dialer{ + HandshakeTimeout: 30 * time.Second, + } + + conn, _, err := dialer.DialContext(ctx, ws.wsURL(), header) + if err != nil { + return fmt.Errorf("websocket connect: %w", err) + } + + ws.conn = conn + ws.connectedAt = time.Now() + return nil +} + +// ensureConnected connects if not connected or reconnects if approaching the 60-minute limit. +func (ws *wsTransport) ensureConnected(ctx context.Context) error { + if ws.conn != nil && time.Since(ws.connectedAt) < wsReconnectThreshold { + return nil + } + + if ws.conn != nil { + ws.conn.Close() + ws.conn = nil + } + + return ws.connect(ctx) +} + +// Close closes the WebSocket connection. +func (ws *wsTransport) Close() error { + ws.mu.Lock() + defer ws.mu.Unlock() + + if ws.conn != nil { + err := ws.conn.Close() + ws.conn = nil + return err + } + return nil +} + +// responseCreateEvent wraps the response params in a WebSocket event envelope. +type responseCreateEvent struct { + Type string `json:"type"` + Body json.RawMessage `json:"-"` +} + +// MarshalJSON implements custom marshaling to flatten the body into the event. +func (e responseCreateEvent) MarshalJSON() ([]byte, error) { + // Start with the body fields and add the type field + var bodyMap map[string]json.RawMessage + if err := json.Unmarshal(e.Body, &bodyMap); err != nil { + return nil, fmt.Errorf("unmarshal body: %w", err) + } + typeBytes, err := json.Marshal(e.Type) + if err != nil { + return nil, fmt.Errorf("marshal type: %w", err) + } + bodyMap["type"] = typeBytes + return json.Marshal(bodyMap) +} + +// wsServerEvent represents a server-sent event from the WebSocket connection. +type wsServerEvent struct { + Type string `json:"type"` + Raw json.RawMessage `json:"-"` +} + +// sendResponseCreate sends a response.create event and returns a channel of raw server events. +// The caller must hold ws.mu. +func (ws *wsTransport) sendResponseCreate(ctx context.Context, body json.RawMessage) (chan wsServerEvent, error) { + if err := ws.ensureConnected(ctx); err != nil { + return nil, err + } + + event := responseCreateEvent{ + Type: "response.create", + Body: body, + } + + data, err := event.MarshalJSON() + if err != nil { + return nil, fmt.Errorf("marshal response.create: %w", err) + } + + if err := ws.conn.WriteMessage(websocket.TextMessage, data); err != nil { + return nil, fmt.Errorf("websocket write: %w", err) + } + + events := make(chan wsServerEvent, 64) + + go func() { + defer close(events) + for { + select { + case <-ctx.Done(): + return + default: + } + + _, message, err := ws.conn.ReadMessage() + if err != nil { + events <- wsServerEvent{ + Type: "error", + Raw: mustMarshal(map[string]string{"type": "error", "code": "websocket_read_error", "message": err.Error()}), + } + return + } + + var evt wsServerEvent + if err := json.Unmarshal(message, &evt); err != nil { + continue + } + evt.Raw = message + + events <- evt + + // Terminal events + if evt.Type == "response.completed" || evt.Type == "response.incomplete" || evt.Type == "response.failed" { + return + } + if evt.Type == "error" { + return + } + } + }() + + return events, nil +} + +func mustMarshal(v any) json.RawMessage { + data, _ := json.Marshal(v) + return data +} + +// applyWSOptions modifies the marshaled params JSON to add WebSocket-specific fields +// like previous_response_id (from transport state) and generate (for warmup). +func (ws *wsTransport) applyWSOptions(body json.RawMessage, call fantasy.Call) json.RawMessage { + var bodyMap map[string]json.RawMessage + if err := json.Unmarshal(body, &bodyMap); err != nil { + return body + } + + // Auto-chain with previous_response_id from transport state if not explicitly set + if _, hasPrevID := bodyMap["previous_response_id"]; !hasPrevID && ws.lastResponseID != "" { + prevIDBytes, _ := json.Marshal(ws.lastResponseID) + bodyMap["previous_response_id"] = prevIDBytes + } + + // Handle GenerateWarmup from provider options + var openaiOptions *ResponsesProviderOptions + if opts, ok := call.ProviderOptions[Name]; ok { + if typedOpts, ok := opts.(*ResponsesProviderOptions); ok { + openaiOptions = typedOpts + } + } + if openaiOptions != nil && openaiOptions.GenerateWarmup != nil && *openaiOptions.GenerateWarmup { + bodyMap["generate"] = json.RawMessage("false") + } + + result, err := json.Marshal(bodyMap) + if err != nil { + return body + } + return result +} diff --git a/providers/openai/responses_websocket_model.go b/providers/openai/responses_websocket_model.go new file mode 100644 index 000000000..66cafc63f --- /dev/null +++ b/providers/openai/responses_websocket_model.go @@ -0,0 +1,392 @@ +package openai + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "charm.land/fantasy" + "github.com/google/uuid" + "github.com/openai/openai-go/v2/responses" +) + +// generateViaWebSocket sends a response.create event over WebSocket and collects +// the full response from streaming events. +func (o responsesLanguageModel) generateViaWebSocket(ctx context.Context, params *responses.ResponseNewParams, warnings []fantasy.CallWarning, call fantasy.Call) (*fantasy.Response, error) { + o.wsTransport.mu.Lock() + defer o.wsTransport.mu.Unlock() + + body, err := json.Marshal(params) + if err != nil { + return nil, fmt.Errorf("marshal params: %w", err) + } + + body = o.wsTransport.applyWSOptions(body, call) + + events, err := o.wsTransport.sendResponseCreate(ctx, body) + if err != nil { + return nil, err + } + + var content []fantasy.Content + hasFunctionCall := false + var usage fantasy.Usage + var responseErr error + + for evt := range events { + var streamEvent responses.ResponseStreamEventUnion + if err := json.Unmarshal(evt.Raw, &streamEvent); err != nil { + continue + } + + switch evt.Type { + case "response.completed", "response.incomplete": + completed := streamEvent.AsResponseCompleted() + o.wsTransport.lastResponseID = completed.Response.ID + + // Build content from the completed response output + content = nil // Reset — use the final response + for _, outputItem := range completed.Response.Output { + switch outputItem.Type { + case "message": + for _, contentPart := range outputItem.Content { + if contentPart.Type == "output_text" { + content = append(content, fantasy.TextContent{ + Text: contentPart.Text, + }) + for _, annotation := range contentPart.Annotations { + switch annotation.Type { + case "url_citation": + content = append(content, fantasy.SourceContent{ + SourceType: fantasy.SourceTypeURL, + ID: uuid.NewString(), + URL: annotation.URL, + Title: annotation.Title, + }) + case "file_citation": + title := "Document" + if annotation.Filename != "" { + title = annotation.Filename + } + filename := annotation.Filename + if filename == "" { + filename = annotation.FileID + } + content = append(content, fantasy.SourceContent{ + SourceType: fantasy.SourceTypeDocument, + ID: uuid.NewString(), + MediaType: "text/plain", + Title: title, + Filename: filename, + }) + } + } + } + } + case "function_call": + hasFunctionCall = true + content = append(content, fantasy.ToolCallContent{ + ProviderExecuted: false, + ToolCallID: outputItem.CallID, + ToolName: outputItem.Name, + Input: outputItem.Arguments, + }) + case "reasoning": + metadata := &ResponsesReasoningMetadata{ + ItemID: outputItem.ID, + } + if outputItem.EncryptedContent != "" { + metadata.EncryptedContent = &outputItem.EncryptedContent + } + if len(outputItem.Summary) == 0 && metadata.EncryptedContent == nil { + continue + } + summaries := outputItem.Summary + if len(summaries) == 0 { + summaries = []responses.ResponseReasoningItemSummary{{Type: "summary_text", Text: ""}} + } + for _, s := range summaries { + metadata.Summary = append(metadata.Summary, s.Text) + } + content = append(content, fantasy.ReasoningContent{ + Text: strings.Join(metadata.Summary, "\n"), + ProviderMetadata: fantasy.ProviderMetadata{ + Name: metadata, + }, + }) + } + } + + usage = fantasy.Usage{ + InputTokens: completed.Response.Usage.InputTokens, + OutputTokens: completed.Response.Usage.OutputTokens, + TotalTokens: completed.Response.Usage.InputTokens + completed.Response.Usage.OutputTokens, + } + if completed.Response.Usage.OutputTokensDetails.ReasoningTokens != 0 { + usage.ReasoningTokens = completed.Response.Usage.OutputTokensDetails.ReasoningTokens + } + if completed.Response.Usage.InputTokensDetails.CachedTokens != 0 { + usage.CacheReadTokens = completed.Response.Usage.InputTokensDetails.CachedTokens + } + + case "error": + errorEvent := streamEvent.AsError() + if errorEvent.Code == "previous_response_not_found" { + o.wsTransport.lastResponseID = "" + return nil, fmt.Errorf("previous_response_not_found") + } + responseErr = fmt.Errorf("%s (code: %s)", errorEvent.Message, errorEvent.Code) + } + } + + if responseErr != nil { + return nil, responseErr + } + + finishReason := fantasy.FinishReasonStop + if hasFunctionCall { + finishReason = fantasy.FinishReasonToolCalls + } + + return &fantasy.Response{ + Content: content, + Usage: usage, + FinishReason: finishReason, + ProviderMetadata: fantasy.ProviderMetadata{}, + Warnings: warnings, + }, nil +} + +// streamViaWebSocket sends a response.create event over WebSocket and yields +// StreamParts from the server events. +func (o responsesLanguageModel) streamViaWebSocket(ctx context.Context, params *responses.ResponseNewParams, warnings []fantasy.CallWarning, call fantasy.Call) (fantasy.StreamResponse, error) { + o.wsTransport.mu.Lock() + + body, err := json.Marshal(params) + if err != nil { + o.wsTransport.mu.Unlock() + return nil, fmt.Errorf("marshal params: %w", err) + } + + body = o.wsTransport.applyWSOptions(body, call) + + events, err := o.wsTransport.sendResponseCreate(ctx, body) + if err != nil { + o.wsTransport.mu.Unlock() + return nil, err + } + + return func(yield func(fantasy.StreamPart) bool) { + defer o.wsTransport.mu.Unlock() + + if len(warnings) > 0 { + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeWarnings, + Warnings: warnings, + }) { + return + } + } + + finishReason := fantasy.FinishReasonUnknown + var usage fantasy.Usage + ongoingToolCalls := make(map[int64]*ongoingToolCall) + hasFunctionCall := false + activeReasoning := make(map[string]*reasoningState) + + for evt := range events { + var event responses.ResponseStreamEventUnion + if err := json.Unmarshal(evt.Raw, &event); err != nil { + continue + } + + switch evt.Type { + case "response.created": + _ = event.AsResponseCreated() + + case "response.output_item.added": + added := event.AsResponseOutputItemAdded() + switch added.Item.Type { + case "function_call": + ongoingToolCalls[added.OutputIndex] = &ongoingToolCall{ + toolName: added.Item.Name, + toolCallID: added.Item.CallID, + } + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputStart, + ID: added.Item.CallID, + ToolCallName: added.Item.Name, + }) { + return + } + case "message": + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextStart, + ID: added.Item.ID, + }) { + return + } + case "reasoning": + metadata := &ResponsesReasoningMetadata{ + ItemID: added.Item.ID, + Summary: []string{}, + } + if added.Item.EncryptedContent != "" { + metadata.EncryptedContent = &added.Item.EncryptedContent + } + activeReasoning[added.Item.ID] = &reasoningState{metadata: metadata} + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningStart, + ID: added.Item.ID, + ProviderMetadata: fantasy.ProviderMetadata{ + Name: metadata, + }, + }) { + return + } + } + + case "response.output_item.done": + done := event.AsResponseOutputItemDone() + switch done.Item.Type { + case "function_call": + tc := ongoingToolCalls[done.OutputIndex] + if tc != nil { + delete(ongoingToolCalls, done.OutputIndex) + hasFunctionCall = true + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputEnd, + ID: done.Item.CallID, + }) { + return + } + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolCall, + ID: done.Item.CallID, + ToolCallName: done.Item.Name, + ToolCallInput: done.Item.Arguments, + }) { + return + } + } + case "message": + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextEnd, + ID: done.Item.ID, + }) { + return + } + case "reasoning": + state := activeReasoning[done.Item.ID] + if state != nil { + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningEnd, + ID: done.Item.ID, + ProviderMetadata: fantasy.ProviderMetadata{ + Name: state.metadata, + }, + }) { + return + } + delete(activeReasoning, done.Item.ID) + } + } + + case "response.function_call_arguments.delta": + delta := event.AsResponseFunctionCallArgumentsDelta() + tc := ongoingToolCalls[delta.OutputIndex] + if tc != nil { + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputDelta, + ID: tc.toolCallID, + Delta: delta.Delta, + }) { + return + } + } + + case "response.output_text.delta": + textDelta := event.AsResponseOutputTextDelta() + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextDelta, + ID: textDelta.ItemID, + Delta: textDelta.Delta, + }) { + return + } + + case "response.reasoning_summary_part.added": + added := event.AsResponseReasoningSummaryPartAdded() + state := activeReasoning[added.ItemID] + if state != nil { + state.metadata.Summary = append(state.metadata.Summary, "") + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningDelta, + ID: added.ItemID, + Delta: "\n", + ProviderMetadata: fantasy.ProviderMetadata{ + Name: state.metadata, + }, + }) { + return + } + } + + case "response.reasoning_summary_text.delta": + textDelta := event.AsResponseReasoningSummaryTextDelta() + state := activeReasoning[textDelta.ItemID] + if state != nil { + if len(state.metadata.Summary)-1 >= int(textDelta.SummaryIndex) { + state.metadata.Summary[textDelta.SummaryIndex] += textDelta.Delta + } + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeReasoningDelta, + ID: textDelta.ItemID, + Delta: textDelta.Delta, + ProviderMetadata: fantasy.ProviderMetadata{ + Name: state.metadata, + }, + }) { + return + } + } + + case "response.completed", "response.incomplete": + completed := event.AsResponseCompleted() + o.wsTransport.lastResponseID = completed.Response.ID + finishReason = mapResponsesFinishReason(completed.Response.IncompleteDetails.Reason, hasFunctionCall) + usage = fantasy.Usage{ + InputTokens: completed.Response.Usage.InputTokens, + OutputTokens: completed.Response.Usage.OutputTokens, + TotalTokens: completed.Response.Usage.InputTokens + completed.Response.Usage.OutputTokens, + } + if completed.Response.Usage.OutputTokensDetails.ReasoningTokens != 0 { + usage.ReasoningTokens = completed.Response.Usage.OutputTokensDetails.ReasoningTokens + } + if completed.Response.Usage.InputTokensDetails.CachedTokens != 0 { + usage.CacheReadTokens = completed.Response.Usage.InputTokensDetails.CachedTokens + } + + case "error": + errorEvent := event.AsError() + if errorEvent.Code == "previous_response_not_found" { + o.wsTransport.lastResponseID = "" + } + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, + Error: fmt.Errorf("response error: %s (code: %s)", errorEvent.Message, errorEvent.Code), + }) { + return + } + return + } + } + + yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeFinish, + Usage: usage, + FinishReason: finishReason, + }) + }, nil +} diff --git a/providers/openai/responses_websocket_test.go b/providers/openai/responses_websocket_test.go new file mode 100644 index 000000000..6c6711029 --- /dev/null +++ b/providers/openai/responses_websocket_test.go @@ -0,0 +1,325 @@ +package openai + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "charm.land/fantasy" + "github.com/gorilla/websocket" +) + +// mockWSServer creates a test WebSocket server that sends predefined events. +func mockWSServer(t *testing.T, handler func(conn *websocket.Conn)) *httptest.Server { + t.Helper() + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + } + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Fatalf("upgrade: %v", err) + return + } + defer conn.Close() + handler(conn) + })) + return server +} + +func wsURLFromHTTP(httpURL string) string { + return strings.Replace(httpURL, "http://", "ws://", 1) +} + +func TestWSTransport_Connect(t *testing.T) { + server := mockWSServer(t, func(conn *websocket.Conn) { + // Just accept and hold connection + for { + _, _, err := conn.ReadMessage() + if err != nil { + return + } + } + }) + defer server.Close() + + ws := newWSTransport(wsURLFromHTTP(server.URL), "test-key", nil) + // Override wsURL to point at test server + ws.baseURL = wsURLFromHTTP(server.URL) + + err := ws.connect(context.Background()) + if err != nil { + t.Fatalf("connect: %v", err) + } + defer ws.Close() + + if ws.conn == nil { + t.Fatal("expected conn to be set") + } + if ws.connectedAt.IsZero() { + t.Fatal("expected connectedAt to be set") + } +} + +func TestWSTransport_EnsureConnected_Reconnect(t *testing.T) { + var connectCount int + var mu sync.Mutex + + server := mockWSServer(t, func(conn *websocket.Conn) { + mu.Lock() + connectCount++ + mu.Unlock() + for { + _, _, err := conn.ReadMessage() + if err != nil { + return + } + } + }) + defer server.Close() + + ws := newWSTransport(wsURLFromHTTP(server.URL), "test-key", nil) + ws.baseURL = wsURLFromHTTP(server.URL) + + // First connect + err := ws.ensureConnected(context.Background()) + if err != nil { + t.Fatalf("first connect: %v", err) + } + + // Should not reconnect (within threshold) + err = ws.ensureConnected(context.Background()) + if err != nil { + t.Fatalf("second connect: %v", err) + } + + mu.Lock() + if connectCount != 1 { + t.Fatalf("expected 1 connection, got %d", connectCount) + } + mu.Unlock() + + // Simulate expired connection by backdating connectedAt + ws.connectedAt = time.Now().Add(-56 * time.Minute) + + err = ws.ensureConnected(context.Background()) + if err != nil { + t.Fatalf("reconnect: %v", err) + } + defer ws.Close() + + // Wait a moment for the server to register the new connection + time.Sleep(50 * time.Millisecond) + + mu.Lock() + if connectCount != 2 { + t.Fatalf("expected 2 connections after reconnect, got %d", connectCount) + } + mu.Unlock() +} + +func TestWSTransport_SendResponseCreate(t *testing.T) { + server := mockWSServer(t, func(conn *websocket.Conn) { + // Read the response.create message + _, message, err := conn.ReadMessage() + if err != nil { + t.Errorf("read: %v", err) + return + } + + var evt map[string]json.RawMessage + if err := json.Unmarshal(message, &evt); err != nil { + t.Errorf("unmarshal: %v", err) + return + } + + var eventType string + json.Unmarshal(evt["type"], &eventType) + if eventType != "response.create" { + t.Errorf("expected type response.create, got %s", eventType) + return + } + + // Send response events + events := []string{ + `{"type":"response.created","response":{"id":"resp_123","status":"in_progress"}}`, + `{"type":"response.output_item.added","output_index":0,"item":{"id":"item_1","type":"message","role":"assistant","content":[]}}`, + `{"type":"response.output_text.delta","output_index":0,"content_index":0,"item_id":"item_1","delta":"Hello"}`, + `{"type":"response.output_text.delta","output_index":0,"content_index":0,"item_id":"item_1","delta":" world"}`, + `{"type":"response.completed","response":{"id":"resp_123","status":"completed","output":[{"id":"item_1","type":"message","role":"assistant","content":[{"type":"output_text","text":"Hello world"}]}],"usage":{"input_tokens":10,"output_tokens":5}}}`, + } + for _, event := range events { + if err := conn.WriteMessage(websocket.TextMessage, []byte(event)); err != nil { + return + } + } + }) + defer server.Close() + + ws := newWSTransport(wsURLFromHTTP(server.URL), "test-key", nil) + ws.baseURL = wsURLFromHTTP(server.URL) + + ws.mu.Lock() + defer ws.mu.Unlock() + + body := json.RawMessage(`{"model":"gpt-4o","input":[]}`) + events, err := ws.sendResponseCreate(context.Background(), body) + if err != nil { + t.Fatalf("sendResponseCreate: %v", err) + } + + var eventTypes []string + for evt := range events { + eventTypes = append(eventTypes, evt.Type) + } + + expected := []string{ + "response.created", + "response.output_item.added", + "response.output_text.delta", + "response.output_text.delta", + "response.completed", + } + + if len(eventTypes) != len(expected) { + t.Fatalf("expected %d events, got %d: %v", len(expected), len(eventTypes), eventTypes) + } + + for i, eventType := range eventTypes { + if eventType != expected[i] { + t.Errorf("event %d: expected %s, got %s", i, expected[i], eventType) + } + } +} + +func TestWSTransport_PreviousResponseIDChaining(t *testing.T) { + ws := newWSTransport("wss://api.openai.com/v1", "test-key", nil) + ws.lastResponseID = "resp_prev_123" + + call := fantasy.Call{ + ProviderOptions: fantasy.ProviderOptions{}, + } + + body := json.RawMessage(`{"model":"gpt-4o","input":[]}`) + result := ws.applyWSOptions(body, call) + + var resultMap map[string]json.RawMessage + if err := json.Unmarshal(result, &resultMap); err != nil { + t.Fatalf("unmarshal result: %v", err) + } + + prevIDRaw, ok := resultMap["previous_response_id"] + if !ok { + t.Fatal("expected previous_response_id in result") + } + + var prevID string + json.Unmarshal(prevIDRaw, &prevID) + if prevID != "resp_prev_123" { + t.Errorf("expected resp_prev_123, got %s", prevID) + } +} + +func TestWSTransport_GenerateWarmup(t *testing.T) { + ws := newWSTransport("wss://api.openai.com/v1", "test-key", nil) + + warmup := true + call := fantasy.Call{ + ProviderOptions: fantasy.ProviderOptions{ + Name: &ResponsesProviderOptions{ + GenerateWarmup: &warmup, + }, + }, + } + + body := json.RawMessage(`{"model":"gpt-4o","input":[]}`) + result := ws.applyWSOptions(body, call) + + var resultMap map[string]json.RawMessage + if err := json.Unmarshal(result, &resultMap); err != nil { + t.Fatalf("unmarshal result: %v", err) + } + + genRaw, ok := resultMap["generate"] + if !ok { + t.Fatal("expected generate in result") + } + + if string(genRaw) != "false" { + t.Errorf("expected generate=false, got %s", string(genRaw)) + } +} + +func TestWSTransport_ExplicitPreviousResponseIDOverridesAuto(t *testing.T) { + ws := newWSTransport("wss://api.openai.com/v1", "test-key", nil) + ws.lastResponseID = "resp_auto_123" + + // Set explicit previous_response_id in the body + body := json.RawMessage(`{"model":"gpt-4o","input":[],"previous_response_id":"resp_explicit_456"}`) + call := fantasy.Call{} + result := ws.applyWSOptions(body, call) + + var resultMap map[string]json.RawMessage + if err := json.Unmarshal(result, &resultMap); err != nil { + t.Fatalf("unmarshal result: %v", err) + } + + var prevID string + json.Unmarshal(resultMap["previous_response_id"], &prevID) + if prevID != "resp_explicit_456" { + t.Errorf("expected explicit ID resp_explicit_456, got %s", prevID) + } +} + +func TestWSTransport_FallbackToHTTP(t *testing.T) { + // Create a provider with WebSocket enabled but no server running + provider, err := New( + WithAPIKey("test-key"), + WithBaseURL("https://localhost:1"), + WithUseResponsesAPI(), + WithWebSocket(), + ) + if err != nil { + t.Fatalf("new provider: %v", err) + } + + model, err := provider.LanguageModel(context.Background(), "gpt-4o") + if err != nil { + t.Fatalf("language model: %v", err) + } + + // The model should be a responsesLanguageModel with wsTransport set + rlm, ok := model.(responsesLanguageModel) + if !ok { + t.Fatal("expected responsesLanguageModel") + } + if rlm.wsTransport == nil { + t.Fatal("expected wsTransport to be set") + } +} + +func TestWSURL(t *testing.T) { + tests := []struct { + baseURL string + expected string + }{ + {"https://api.openai.com/v1", "wss://api.openai.com/v1/responses"}, + {"https://custom.api.com/v1", "wss://custom.api.com/v1/responses"}, + {"http://localhost:8080/v1", "ws://localhost:8080/v1/responses"}, + {"https://api.openai.com/v1/", "wss://api.openai.com/v1/responses"}, + } + + for _, tt := range tests { + ws := newWSTransport(tt.baseURL, "key", nil) + ws.baseURL = tt.baseURL + got := ws.wsURL() + if got != tt.expected { + t.Errorf("wsURL(%s) = %s, want %s", tt.baseURL, got, tt.expected) + } + } +} From e6ec3d0bcf87007777bb243d1a0a7779dfadc46d Mon Sep 17 00:00:00 2001 From: Tai Groot Date: Tue, 24 Feb 2026 18:23:20 +0000 Subject: [PATCH 2/6] feat(ws): send only incremental input when chaining with previous_response_id When using WebSocket mode with previous_response_id chaining, the server already has the prior conversation context. Previously we sent the full input array every time, which was redundant and incorrect per the spec. Now wsTransport tracks lastInputLen (the number of input items sent in the last successful request). On subsequent calls with previous_response_id, only new items are sent. Additionally, function_call items are filtered out since the server generated those as part of its own response output. On chain breaks (previous_response_not_found errors), lastInputLen resets to 0 so the next call sends the full prompt. --- providers/openai/responses_websocket.go | 79 ++++++++++++++++- providers/openai/responses_websocket_model.go | 10 ++- providers/openai/responses_websocket_test.go | 87 ++++++++++++++++++- 3 files changed, 167 insertions(+), 9 deletions(-) diff --git a/providers/openai/responses_websocket.go b/providers/openai/responses_websocket.go index ff218a443..a3cd07544 100644 --- a/providers/openai/responses_websocket.go +++ b/providers/openai/responses_websocket.go @@ -25,6 +25,7 @@ type wsTransport struct { apiKey string headers map[string]string lastResponseID string + lastInputLen int // number of input items sent in the last successful request } // newWSTransport creates a new WebSocket transport for the OpenAI Responses API. @@ -190,18 +191,88 @@ func mustMarshal(v any) json.RawMessage { return data } +// extractIncrementalInput returns only the new input items that the server hasn't +// seen yet. When chaining with previous_response_id, the server already has the +// prior context, so we only send function_call_output items and new user messages. +// Items of type "function_call" are filtered out because the server generated those +// as part of its own response output. +func (ws *wsTransport) extractIncrementalInput(fullInput json.RawMessage) (json.RawMessage, int) { + var items []json.RawMessage + if err := json.Unmarshal(fullInput, &items); err != nil { + return fullInput, 0 + } + + fullLen := len(items) + + if ws.lastResponseID == "" || ws.lastInputLen == 0 { + return fullInput, fullLen + } + + if len(items) <= ws.lastInputLen { + return fullInput, fullLen + } + + // Take only items appended since the last request. + newItems := items[ws.lastInputLen:] + + // Filter out function_call items — the server already has these from its + // own response output; sending them again would be redundant. + var incremental []json.RawMessage + for _, item := range newItems { + var parsed map[string]json.RawMessage + if err := json.Unmarshal(item, &parsed); err != nil { + incremental = append(incremental, item) + continue + } + if typeField, ok := parsed["type"]; ok { + var itemType string + if err := json.Unmarshal(typeField, &itemType); err == nil && itemType == "function_call" { + continue + } + } + incremental = append(incremental, item) + } + + result, err := json.Marshal(incremental) + if err != nil { + return fullInput, fullLen + } + return result, fullLen +} + // applyWSOptions modifies the marshaled params JSON to add WebSocket-specific fields // like previous_response_id (from transport state) and generate (for warmup). -func (ws *wsTransport) applyWSOptions(body json.RawMessage, call fantasy.Call) json.RawMessage { +// It returns the modified body and the full input item count (before any trimming) +// so callers can update lastInputLen after a successful response. +func (ws *wsTransport) applyWSOptions(body json.RawMessage, call fantasy.Call) (json.RawMessage, int) { var bodyMap map[string]json.RawMessage if err := json.Unmarshal(body, &bodyMap); err != nil { - return body + return body, 0 } + var fullInputLen int + // Auto-chain with previous_response_id from transport state if not explicitly set + usingPrevID := false if _, hasPrevID := bodyMap["previous_response_id"]; !hasPrevID && ws.lastResponseID != "" { prevIDBytes, _ := json.Marshal(ws.lastResponseID) bodyMap["previous_response_id"] = prevIDBytes + usingPrevID = true + } else if _, hasPrevID := bodyMap["previous_response_id"]; hasPrevID { + usingPrevID = true + } + + // When chaining, send only incremental input items. + if inputField, hasInput := bodyMap["input"]; hasInput { + if usingPrevID && ws.lastInputLen > 0 { + bodyMap["input"], fullInputLen = ws.extractIncrementalInput(inputField) + } else { + // Count full input items for tracking. + var items []json.RawMessage + if err := json.Unmarshal(inputField, &items); err == nil { + fullInputLen = len(items) + } + } } // Handle GenerateWarmup from provider options @@ -217,7 +288,7 @@ func (ws *wsTransport) applyWSOptions(body json.RawMessage, call fantasy.Call) j result, err := json.Marshal(bodyMap) if err != nil { - return body + return body, fullInputLen } - return result + return result, fullInputLen } diff --git a/providers/openai/responses_websocket_model.go b/providers/openai/responses_websocket_model.go index 66cafc63f..3ef453091 100644 --- a/providers/openai/responses_websocket_model.go +++ b/providers/openai/responses_websocket_model.go @@ -22,7 +22,8 @@ func (o responsesLanguageModel) generateViaWebSocket(ctx context.Context, params return nil, fmt.Errorf("marshal params: %w", err) } - body = o.wsTransport.applyWSOptions(body, call) + var fullInputLen int + body, fullInputLen = o.wsTransport.applyWSOptions(body, call) events, err := o.wsTransport.sendResponseCreate(ctx, body) if err != nil { @@ -44,6 +45,7 @@ func (o responsesLanguageModel) generateViaWebSocket(ctx context.Context, params case "response.completed", "response.incomplete": completed := streamEvent.AsResponseCompleted() o.wsTransport.lastResponseID = completed.Response.ID + o.wsTransport.lastInputLen = fullInputLen // Build content from the completed response output content = nil // Reset — use the final response @@ -134,6 +136,7 @@ func (o responsesLanguageModel) generateViaWebSocket(ctx context.Context, params errorEvent := streamEvent.AsError() if errorEvent.Code == "previous_response_not_found" { o.wsTransport.lastResponseID = "" + o.wsTransport.lastInputLen = 0 return nil, fmt.Errorf("previous_response_not_found") } responseErr = fmt.Errorf("%s (code: %s)", errorEvent.Message, errorEvent.Code) @@ -169,7 +172,8 @@ func (o responsesLanguageModel) streamViaWebSocket(ctx context.Context, params * return nil, fmt.Errorf("marshal params: %w", err) } - body = o.wsTransport.applyWSOptions(body, call) + var fullInputLen int + body, fullInputLen = o.wsTransport.applyWSOptions(body, call) events, err := o.wsTransport.sendResponseCreate(ctx, body) if err != nil { @@ -355,6 +359,7 @@ func (o responsesLanguageModel) streamViaWebSocket(ctx context.Context, params * case "response.completed", "response.incomplete": completed := event.AsResponseCompleted() o.wsTransport.lastResponseID = completed.Response.ID + o.wsTransport.lastInputLen = fullInputLen finishReason = mapResponsesFinishReason(completed.Response.IncompleteDetails.Reason, hasFunctionCall) usage = fantasy.Usage{ InputTokens: completed.Response.Usage.InputTokens, @@ -372,6 +377,7 @@ func (o responsesLanguageModel) streamViaWebSocket(ctx context.Context, params * errorEvent := event.AsError() if errorEvent.Code == "previous_response_not_found" { o.wsTransport.lastResponseID = "" + o.wsTransport.lastInputLen = 0 } if !yield(fantasy.StreamPart{ Type: fantasy.StreamPartTypeError, diff --git a/providers/openai/responses_websocket_test.go b/providers/openai/responses_websocket_test.go index 6c6711029..7851e4f94 100644 --- a/providers/openai/responses_websocket_test.go +++ b/providers/openai/responses_websocket_test.go @@ -206,7 +206,7 @@ func TestWSTransport_PreviousResponseIDChaining(t *testing.T) { } body := json.RawMessage(`{"model":"gpt-4o","input":[]}`) - result := ws.applyWSOptions(body, call) + result, _ := ws.applyWSOptions(body, call) var resultMap map[string]json.RawMessage if err := json.Unmarshal(result, &resultMap); err != nil { @@ -238,7 +238,7 @@ func TestWSTransport_GenerateWarmup(t *testing.T) { } body := json.RawMessage(`{"model":"gpt-4o","input":[]}`) - result := ws.applyWSOptions(body, call) + result, _ := ws.applyWSOptions(body, call) var resultMap map[string]json.RawMessage if err := json.Unmarshal(result, &resultMap); err != nil { @@ -262,7 +262,7 @@ func TestWSTransport_ExplicitPreviousResponseIDOverridesAuto(t *testing.T) { // Set explicit previous_response_id in the body body := json.RawMessage(`{"model":"gpt-4o","input":[],"previous_response_id":"resp_explicit_456"}`) call := fantasy.Call{} - result := ws.applyWSOptions(body, call) + result, _ := ws.applyWSOptions(body, call) var resultMap map[string]json.RawMessage if err := json.Unmarshal(result, &resultMap); err != nil { @@ -303,6 +303,87 @@ func TestWSTransport_FallbackToHTTP(t *testing.T) { } } +func TestWSTransport_IncrementalInput(t *testing.T) { + ws := newWSTransport("wss://api.openai.com/v1", "test-key", nil) + ws.lastResponseID = "resp_001" + ws.lastInputLen = 3 + + // Simulate full input: 3 old items + 2 new (1 function_call + 1 function_call_output). + // The function_call should be filtered out since the server generated it. + fullInput := json.RawMessage(`[ + {"type":"message","role":"system","content":"sys"}, + {"type":"message","role":"user","content":"hello"}, + {"type":"message","role":"assistant","content":"hi"}, + {"type":"function_call","call_id":"call_1","name":"search","arguments":"{}"}, + {"type":"function_call_output","call_id":"call_1","output":"result"} + ]`) + + got, fullLen := ws.extractIncrementalInput(fullInput) + if fullLen != 5 { + t.Fatalf("expected fullLen=5, got %d", fullLen) + } + + var items []map[string]interface{} + if err := json.Unmarshal(got, &items); err != nil { + t.Fatalf("unmarshal result: %v", err) + } + + if len(items) != 1 { + t.Fatalf("expected 1 incremental item, got %d: %s", len(items), string(got)) + } + if items[0]["type"] != "function_call_output" { + t.Errorf("expected function_call_output, got %v", items[0]["type"]) + } +} + +func TestWSTransport_IncrementalInput_FirstCall(t *testing.T) { + ws := newWSTransport("wss://api.openai.com/v1", "test-key", nil) + // No lastResponseID — should return full input. + + fullInput := json.RawMessage(`[{"type":"message","role":"user","content":"hi"}]`) + got, fullLen := ws.extractIncrementalInput(fullInput) + if fullLen != 1 { + t.Fatalf("expected fullLen=1, got %d", fullLen) + } + if string(got) != string(fullInput) { + t.Errorf("expected full input returned on first call") + } +} + +func TestWSTransport_IncrementalInput_AppliedViaApplyWSOptions(t *testing.T) { + ws := newWSTransport("wss://api.openai.com/v1", "test-key", nil) + ws.lastResponseID = "resp_001" + ws.lastInputLen = 2 + + body := json.RawMessage(`{"model":"gpt-4o","input":[ + {"type":"message","role":"user","content":"hello"}, + {"type":"message","role":"assistant","content":"hi"}, + {"type":"function_call_output","call_id":"call_1","output":"42"} + ]}`) + + result, fullLen := ws.applyWSOptions(body, fantasy.Call{}) + if fullLen != 3 { + t.Fatalf("expected fullLen=3, got %d", fullLen) + } + + var resultMap map[string]json.RawMessage + if err := json.Unmarshal(result, &resultMap); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + var items []map[string]interface{} + if err := json.Unmarshal(resultMap["input"], &items); err != nil { + t.Fatalf("unmarshal input: %v", err) + } + + if len(items) != 1 { + t.Fatalf("expected 1 incremental input item, got %d", len(items)) + } + if items[0]["type"] != "function_call_output" { + t.Errorf("expected function_call_output, got %v", items[0]["type"]) + } +} + func TestWSURL(t *testing.T) { tests := []struct { baseURL string From 0f2292dc156c45bfb9e9f5e01fe5c7e8f90a71fb Mon Sep 17 00:00:00 2001 From: Tai Groot Date: Tue, 24 Feb 2026 18:26:39 +0000 Subject: [PATCH 3/6] fix(ws): handle response.failed events, fix goroutine leak, use mapResponsesFinishReason - Handle response.failed events in both generateViaWebSocket and streamViaWebSocket (previously would return nil error with empty content) - Fix read goroutine leak on context cancellation by setting a read deadline when ctx is done, allowing ReadMessage to unblock - Use mapResponsesFinishReason in generateViaWebSocket to match the HTTP path's handling of incomplete_details reasons --- providers/openai/responses_websocket.go | 19 ++++++++++++++---- providers/openai/responses_websocket_model.go | 20 +++++++++++++++---- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/providers/openai/responses_websocket.go b/providers/openai/responses_websocket.go index a3cd07544..a51ddd7a7 100644 --- a/providers/openai/responses_websocket.go +++ b/providers/openai/responses_websocket.go @@ -147,17 +147,28 @@ func (ws *wsTransport) sendResponseCreate(ctx context.Context, body json.RawMess events := make(chan wsServerEvent, 64) + conn := ws.conn go func() { defer close(events) - for { + + // Set a read deadline when the context is cancelled to unblock ReadMessage. + done := make(chan struct{}) + defer close(done) + go func() { select { case <-ctx.Done(): - return - default: + conn.SetReadDeadline(time.Now()) + case <-done: } + }() - _, message, err := ws.conn.ReadMessage() + for { + _, message, err := conn.ReadMessage() if err != nil { + // Don't emit an error event if the context was cancelled. + if ctx.Err() != nil { + return + } events <- wsServerEvent{ Type: "error", Raw: mustMarshal(map[string]string{"type": "error", "code": "websocket_read_error", "message": err.Error()}), diff --git a/providers/openai/responses_websocket_model.go b/providers/openai/responses_websocket_model.go index 3ef453091..1d33ea676 100644 --- a/providers/openai/responses_websocket_model.go +++ b/providers/openai/responses_websocket_model.go @@ -132,6 +132,11 @@ func (o responsesLanguageModel) generateViaWebSocket(ctx context.Context, params usage.CacheReadTokens = completed.Response.Usage.InputTokensDetails.CachedTokens } + case "response.failed": + completed := streamEvent.AsResponseCompleted() + responseErr = fmt.Errorf("response failed: %s (code: %s)", + completed.Response.Error.Message, completed.Response.Error.Code) + case "error": errorEvent := streamEvent.AsError() if errorEvent.Code == "previous_response_not_found" { @@ -147,10 +152,7 @@ func (o responsesLanguageModel) generateViaWebSocket(ctx context.Context, params return nil, responseErr } - finishReason := fantasy.FinishReasonStop - if hasFunctionCall { - finishReason = fantasy.FinishReasonToolCalls - } + finishReason := mapResponsesFinishReason("", hasFunctionCall) return &fantasy.Response{ Content: content, @@ -373,6 +375,16 @@ func (o responsesLanguageModel) streamViaWebSocket(ctx context.Context, params * usage.CacheReadTokens = completed.Response.Usage.InputTokensDetails.CachedTokens } + case "response.failed": + completed := event.AsResponseCompleted() + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, + Error: fmt.Errorf("response failed: %s (code: %s)", completed.Response.Error.Message, completed.Response.Error.Code), + }) { + return + } + return + case "error": errorEvent := event.AsError() if errorEvent.Code == "previous_response_not_found" { From 0fa62ddd871cd4d89f5e6ac0b6e13cbc6eb87a23 Mon Sep 17 00:00:00 2001 From: Tai Groot Date: Tue, 24 Feb 2026 20:48:15 +0000 Subject: [PATCH 4/6] fix(openai): close response body from websocket dial and check conn.Close error --- providers/openai/responses_websocket.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/providers/openai/responses_websocket.go b/providers/openai/responses_websocket.go index a51ddd7a7..ecfcf8a9b 100644 --- a/providers/openai/responses_websocket.go +++ b/providers/openai/responses_websocket.go @@ -60,7 +60,10 @@ func (ws *wsTransport) connect(ctx context.Context) error { HandshakeTimeout: 30 * time.Second, } - conn, _, err := dialer.DialContext(ctx, ws.wsURL(), header) + conn, resp, err := dialer.DialContext(ctx, ws.wsURL(), header) + if resp != nil && resp.Body != nil { + resp.Body.Close() + } if err != nil { return fmt.Errorf("websocket connect: %w", err) } @@ -77,7 +80,7 @@ func (ws *wsTransport) ensureConnected(ctx context.Context) error { } if ws.conn != nil { - ws.conn.Close() + _ = ws.conn.Close() ws.conn = nil } From 146da14a8cff26a073a1d017a8c16bb326ee7ed3 Mon Sep 17 00:00:00 2001 From: Tai Groot Date: Tue, 24 Feb 2026 23:02:37 +0000 Subject: [PATCH 5/6] fix(openai): check error returns for resp.Body.Close and SetReadDeadline --- providers/openai/responses_websocket.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/providers/openai/responses_websocket.go b/providers/openai/responses_websocket.go index ecfcf8a9b..ea68a45e7 100644 --- a/providers/openai/responses_websocket.go +++ b/providers/openai/responses_websocket.go @@ -62,7 +62,7 @@ func (ws *wsTransport) connect(ctx context.Context) error { conn, resp, err := dialer.DialContext(ctx, ws.wsURL(), header) if resp != nil && resp.Body != nil { - resp.Body.Close() + _ = resp.Body.Close() } if err != nil { return fmt.Errorf("websocket connect: %w", err) @@ -160,7 +160,7 @@ func (ws *wsTransport) sendResponseCreate(ctx context.Context, body json.RawMess go func() { select { case <-ctx.Done(): - conn.SetReadDeadline(time.Now()) + _ = conn.SetReadDeadline(time.Now()) case <-done: } }() From 7a27ae7b5c1b5a1228a1cfc51fdf59b44cd0524b Mon Sep 17 00:00:00 2001 From: Tai Groot Date: Wed, 25 Feb 2026 15:58:32 -0500 Subject: [PATCH 6/6] add reset options --- providers/openai/responses_options.go | 1 + providers/openai/responses_websocket.go | 20 +++++++++++++------- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/providers/openai/responses_options.go b/providers/openai/responses_options.go index 64f8abdcc..2799a0b89 100644 --- a/providers/openai/responses_options.go +++ b/providers/openai/responses_options.go @@ -107,6 +107,7 @@ type ResponsesProviderOptions struct { PromptCacheKey *string `json:"prompt_cache_key"` ReasoningEffort *ReasoningEffort `json:"reasoning_effort"` ReasoningSummary *string `json:"reasoning_summary"` + ResetChain *bool `json:"reset_chain"` SafetyIdentifier *string `json:"safety_identifier"` ServiceTier *ServiceTier `json:"service_tier"` StrictJSONSchema *bool `json:"strict_json_schema"` diff --git a/providers/openai/responses_websocket.go b/providers/openai/responses_websocket.go index ea68a45e7..8b5e2e854 100644 --- a/providers/openai/responses_websocket.go +++ b/providers/openai/responses_websocket.go @@ -264,6 +264,18 @@ func (ws *wsTransport) applyWSOptions(body json.RawMessage, call fantasy.Call) ( return body, 0 } + // Handle ResetChain from provider options - must be checked before auto-chaining. + var openaiOptions *ResponsesProviderOptions + if opts, ok := call.ProviderOptions[Name]; ok { + if typedOpts, ok := opts.(*ResponsesProviderOptions); ok { + openaiOptions = typedOpts + } + } + if openaiOptions != nil && openaiOptions.ResetChain != nil && *openaiOptions.ResetChain { + ws.lastResponseID = "" + ws.lastInputLen = 0 + } + var fullInputLen int // Auto-chain with previous_response_id from transport state if not explicitly set @@ -289,13 +301,7 @@ func (ws *wsTransport) applyWSOptions(body json.RawMessage, call fantasy.Call) ( } } - // Handle GenerateWarmup from provider options - var openaiOptions *ResponsesProviderOptions - if opts, ok := call.ProviderOptions[Name]; ok { - if typedOpts, ok := opts.(*ResponsesProviderOptions); ok { - openaiOptions = typedOpts - } - } + // Handle GenerateWarmup from provider options. if openaiOptions != nil && openaiOptions.GenerateWarmup != nil && *openaiOptions.GenerateWarmup { bodyMap["generate"] = json.RawMessage("false") }