Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion providers/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type options struct {
project string
name string
useResponsesAPI bool
useWebSocket bool
headers map[string]string
client option.HTTPClient
sdkOptions []option.RequestOption
Expand Down Expand Up @@ -132,6 +133,15 @@ func WithUseResponsesAPI() Option {
}
}

// WithWebSocket enables WebSocket mode for the Responses API, providing lower-latency
// persistent connections for tool-call-heavy workflows. Requires WithUseResponsesAPI();
// the flag has no effect on its own — the wsTransport is only constructed in
// LanguageModel when the Responses API path is taken.
// Falls back to HTTP if the WebSocket connection cannot be established.
func WithWebSocket() Option {
	return func(o *options) {
		// Record the preference only; connection setup is deferred until a
		// language model is created, so option order does not matter.
		o.useWebSocket = true
	}
}

// WithObjectMode sets the object generation mode.
func WithObjectMode(om fantasy.ObjectMode) Option {
return func(o *options) {
Expand Down Expand Up @@ -173,7 +183,11 @@ func (o *provider) LanguageModel(_ context.Context, modelID string) (fantasy.Lan
if objectMode == fantasy.ObjectModeJSON {
objectMode = fantasy.ObjectModeAuto
}
return newResponsesLanguageModel(modelID, o.options.name, client, objectMode), nil
var ws *wsTransport
if o.options.useWebSocket {
ws = newWSTransport(o.options.baseURL, o.options.apiKey, o.options.headers)
}
return newResponsesLanguageModel(modelID, o.options.name, client, objectMode, ws), nil
}

o.options.languageModelOptions = append(o.options.languageModelOptions, WithLanguageModelObjectMode(o.options.objectMode))
Expand Down
40 changes: 31 additions & 9 deletions providers/openai/responses_language_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,22 @@ import (
const topLogprobsMax = 20

type responsesLanguageModel struct {
provider string
modelID string
client openai.Client
objectMode fantasy.ObjectMode
provider string
modelID string
client openai.Client
objectMode fantasy.ObjectMode
wsTransport *wsTransport
}

// newResponsesLanguageModel implements a responses api model
// INFO: (kujtim) currently we do not support stored parameter we default it to false.
func newResponsesLanguageModel(modelID string, provider string, client openai.Client, objectMode fantasy.ObjectMode) responsesLanguageModel {
func newResponsesLanguageModel(modelID string, provider string, client openai.Client, objectMode fantasy.ObjectMode, ws *wsTransport) responsesLanguageModel {
return responsesLanguageModel{
modelID: modelID,
provider: provider,
client: client,
objectMode: objectMode,
modelID: modelID,
provider: provider,
client: client,
objectMode: objectMode,
wsTransport: ws,
}
}

Expand Down Expand Up @@ -228,6 +230,9 @@ func (o responsesLanguageModel) prepareParams(call fantasy.Call) (*responses.Res
if openaiOptions.SafetyIdentifier != nil {
params.SafetyIdentifier = param.NewOpt(*openaiOptions.SafetyIdentifier)
}
if openaiOptions.PreviousResponseID != nil {
params.PreviousResponseID = param.NewOpt(*openaiOptions.PreviousResponseID)
}
if topLogprobs > 0 {
params.TopLogprobs = param.NewOpt(int64(topLogprobs))
}
Expand Down Expand Up @@ -668,6 +673,15 @@ func toResponsesTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice, opti

func (o responsesLanguageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
params, warnings := o.prepareParams(call)

if o.wsTransport != nil {
resp, err := o.generateViaWebSocket(ctx, params, warnings, call)
if err == nil {
return resp, nil
}
// Fall back to HTTP on WebSocket failure
}

response, err := o.client.Responses.New(ctx, *params)
if err != nil {
return nil, toProviderErr(err)
Expand Down Expand Up @@ -806,6 +820,14 @@ func mapResponsesFinishReason(reason string, hasFunctionCall bool) fantasy.Finis
func (o responsesLanguageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
params, warnings := o.prepareParams(call)

if o.wsTransport != nil {
resp, err := o.streamViaWebSocket(ctx, params, warnings, call)
if err == nil {
return resp, nil
}
// Fall back to HTTP on WebSocket failure
}

stream := o.client.Responses.NewStreaming(ctx, *params)

finishReason := fantasy.FinishReasonUnknown
Expand Down
31 changes: 17 additions & 14 deletions providers/openai/responses_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,20 +97,23 @@ const (

// ResponsesProviderOptions represents additional options for OpenAI Responses API.
type ResponsesProviderOptions struct {
Include []IncludeType `json:"include"`
Instructions *string `json:"instructions"`
Logprobs any `json:"logprobs"`
MaxToolCalls *int64 `json:"max_tool_calls"`
Metadata map[string]any `json:"metadata"`
ParallelToolCalls *bool `json:"parallel_tool_calls"`
PromptCacheKey *string `json:"prompt_cache_key"`
ReasoningEffort *ReasoningEffort `json:"reasoning_effort"`
ReasoningSummary *string `json:"reasoning_summary"`
SafetyIdentifier *string `json:"safety_identifier"`
ServiceTier *ServiceTier `json:"service_tier"`
StrictJSONSchema *bool `json:"strict_json_schema"`
TextVerbosity *TextVerbosity `json:"text_verbosity"`
User *string `json:"user"`
Include []IncludeType `json:"include"`
Instructions *string `json:"instructions"`
Logprobs any `json:"logprobs"`
MaxToolCalls *int64 `json:"max_tool_calls"`
Metadata map[string]any `json:"metadata"`
ParallelToolCalls *bool `json:"parallel_tool_calls"`
PreviousResponseID *string `json:"previous_response_id"`
PromptCacheKey *string `json:"prompt_cache_key"`
ReasoningEffort *ReasoningEffort `json:"reasoning_effort"`
ReasoningSummary *string `json:"reasoning_summary"`
ResetChain *bool `json:"reset_chain"`
SafetyIdentifier *string `json:"safety_identifier"`
ServiceTier *ServiceTier `json:"service_tier"`
StrictJSONSchema *bool `json:"strict_json_schema"`
TextVerbosity *TextVerbosity `json:"text_verbosity"`
User *string `json:"user"`
GenerateWarmup *bool `json:"generate_warmup"`
}

// Options implements the ProviderOptions interface.
Expand Down
Loading
Loading