diff --git a/Taskfile.yaml b/Taskfile.yaml index 34455b902..6ae12c518 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -37,7 +37,9 @@ tasks: - sh: "[ $(git status --porcelain=2 | wc -l) = 0 ]" msg: "Git is dirty" cmds: - - git commit --allow-empty -m "{{.NEXT}}" + - echo {{trimPrefix "v" .NEXT}} > version.txt + - git add version.txt + - git commit -m "{{.NEXT}}" - git tag --annotate --sign -m "{{.NEXT}}" {{.NEXT}} {{.CLI_ARGS}} - echo "Pushing {{.NEXT}}..." - git push origin main --follow-tags diff --git a/agent.go b/agent.go index 426beff16..6d2d62dfb 100644 --- a/agent.go +++ b/agent.go @@ -138,6 +138,7 @@ type agentSettings struct { presencePenalty *float64 frequencyPenalty *float64 headers map[string]string + userAgent string providerOptions ProviderOptions // TODO: add support for provider tools @@ -448,6 +449,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err FrequencyPenalty: opts.FrequencyPenalty, Tools: preparedTools, ToolChoice: &stepToolChoice, + UserAgent: a.settings.userAgent, ProviderOptions: opts.ProviderOptions, }) }) @@ -829,6 +831,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, FrequencyPenalty: call.FrequencyPenalty, Tools: preparedTools, ToolChoice: &stepToolChoice, + UserAgent: a.settings.userAgent, ProviderOptions: call.ProviderOptions, } @@ -1418,6 +1421,14 @@ func WithHeaders(headers map[string]string) AgentOption { } } +// WithUserAgent sets the User-Agent header for the agent. This overrides any +// provider-level User-Agent setting. +func WithUserAgent(ua string) AgentOption { + return func(s *agentSettings) { + s.userAgent = ua + } +} + // WithProviderOptions sets the provider options for the agent. func WithProviderOptions(providerOptions ProviderOptions) AgentOption { return func(s *agentSettings) { diff --git a/agent_useragent_test.go b/agent_useragent_test.go new file mode 100644 index 000000000..95cc81d8e --- /dev/null +++ b/agent_useragent_test.go @@ -0,0 +1,71 @@ +package fantasy + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAgent_WithUserAgent_PropagatesOnGenerate(t *testing.T) { + t.Parallel() + + var capturedCall Call + model := &mockLanguageModel{ + generateFunc: func(_ context.Context, call Call) (*Response, error) { + capturedCall = call + return &Response{ + Content: []Content{TextContent{Text: "ok"}}, + FinishReason: FinishReasonStop, + }, nil + }, + } + + agent := NewAgent(model, WithUserAgent("MyApp/2.0")) + _, err := agent.Generate(context.Background(), AgentCall{Prompt: "hi"}) + require.NoError(t, err) + assert.Equal(t, "MyApp/2.0", capturedCall.UserAgent) +} + +func TestAgent_WithUserAgent_PropagatesOnStream(t *testing.T) { + t.Parallel() + + var capturedCall Call + model := &mockLanguageModel{ + streamFunc: func(_ context.Context, call Call) (StreamResponse, error) { + capturedCall = call + return func(yield func(StreamPart) bool) { + yield(StreamPart{ + Type: StreamPartTypeFinish, + FinishReason: FinishReasonStop, + }) + }, nil + }, + } + + agent := NewAgent(model, WithUserAgent("StreamApp/1.0")) + _, err := agent.Stream(context.Background(), AgentStreamCall{Prompt: "hi"}) + require.NoError(t, err) + assert.Equal(t, "StreamApp/1.0", capturedCall.UserAgent) +} + +func TestAgent_NoUA_OmitsCallLevelFields(t *testing.T) { + t.Parallel() + + var capturedCall Call + model := &mockLanguageModel{ + generateFunc: func(_ context.Context, call Call) (*Response, error) { + capturedCall = call + return &Response{ + Content: []Content{TextContent{Text: "ok"}}, + FinishReason: FinishReasonStop, + }, nil + }, + } + + agent := NewAgent(model) + _, err := agent.Generate(context.Background(), AgentCall{Prompt: "hi"}) + require.NoError(t, err) + assert.Empty(t, capturedCall.UserAgent) +} diff --git a/model.go b/model.go index 4d1c3e31b..16980da10 100644 --- a/model.go +++ b/model.go @@ -218,6 +218,9 @@ type Call struct { Tools []Tool `json:"tools"` ToolChoice *ToolChoice `json:"tool_choice"` + // UserAgent overrides the provider-level User-Agent header for this call. + UserAgent string `json:"-"` + // for provider specific options, the key is the provider id ProviderOptions ProviderOptions `json:"provider_options"` } diff --git a/object.go b/object.go index 4b8aed369..3cbcbd3eb 100644 --- a/object.go +++ b/object.go @@ -41,6 +41,9 @@ type ObjectCall struct { PresencePenalty *float64 FrequencyPenalty *float64 + // UserAgent overrides the provider-level User-Agent header for this call. + UserAgent string `json:"-"` + ProviderOptions ProviderOptions RepairText schema.ObjectRepairFunc diff --git a/object/object.go b/object/object.go index c62b1cbec..16c77d1cf 100644 --- a/object/object.go +++ b/object/object.go @@ -120,6 +120,7 @@ func GenerateWithTool( TopK: call.TopK, PresencePenalty: call.PresencePenalty, FrequencyPenalty: call.FrequencyPenalty, + UserAgent: call.UserAgent, ProviderOptions: call.ProviderOptions, }) if err != nil { @@ -212,6 +213,7 @@ func GenerateWithText( TopK: call.TopK, PresencePenalty: call.PresencePenalty, FrequencyPenalty: call.FrequencyPenalty, + UserAgent: call.UserAgent, ProviderOptions: call.ProviderOptions, }) if err != nil { @@ -294,6 +296,7 @@ func StreamWithTool( TopK: call.TopK, PresencePenalty: call.PresencePenalty, FrequencyPenalty: call.FrequencyPenalty, + UserAgent: call.UserAgent, ProviderOptions: call.ProviderOptions, }) if err != nil { @@ -503,6 +506,7 @@ func StreamWithText( TopK: call.TopK, PresencePenalty: call.PresencePenalty, FrequencyPenalty: call.FrequencyPenalty, + UserAgent: call.UserAgent, ProviderOptions: call.ProviderOptions, }) if err != nil { diff --git a/providers/anthropic/anthropic.go b/providers/anthropic/anthropic.go index 4b593408d..ca09ab3f8 100644 --- a/providers/anthropic/anthropic.go +++ b/providers/anthropic/anthropic.go @@ -14,6 +14,7 @@ import ( "charm.land/fantasy" "charm.land/fantasy/object" + "charm.land/fantasy/providers/internal/httpheaders" "github.com/aws/aws-sdk-go-v2/config" "github.com/charmbracelet/anthropic-sdk-go" "github.com/charmbracelet/anthropic-sdk-go/bedrock" @@ -31,11 +32,12 @@ const ( ) type options struct { - baseURL string - apiKey string - name string - headers map[string]string - client option.HTTPClient + baseURL string + apiKey string + name string + headers map[string]string + userAgent string + client option.HTTPClient vertexProject string vertexLocation string @@ -125,6 +127,14 @@ func WithHTTPClient(client option.HTTPClient) Option { } } +// WithUserAgent sets an explicit User-Agent header, overriding the default and any +// value set via WithHeaders. +func WithUserAgent(ua string) Option { + return func(o *options) { + o.userAgent = ua + } +} + // WithObjectMode sets the object generation mode. func WithObjectMode(om fantasy.ObjectMode) Option { return func(o *options) { @@ -146,7 +156,9 @@ func (a *provider) LanguageModel(ctx context.Context, modelID string) (fantasy.L if a.options.baseURL != "" { clientOptions = append(clientOptions, option.WithBaseURL(a.options.baseURL)) } - for key, value := range a.options.headers { + defaultUA := httpheaders.DefaultUserAgent(fantasy.Version) + resolved := httpheaders.ResolveHeaders(a.options.headers, a.options.userAgent, defaultUA) + for key, value := range resolved { clientOptions = append(clientOptions, option.WithHeader(key, value)) } if a.options.client != nil { @@ -771,7 +783,7 @@ func (a languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantas if err != nil { return nil, err } - response, err := a.client.Messages.New(ctx, *params) + response, err := a.client.Messages.New(ctx, *params, callUARequestOptions(call)...) if err != nil { return nil, toProviderErr(err) } @@ -849,7 +861,7 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S return nil, err } - stream := a.client.Messages.NewStreaming(ctx, *params) + stream := a.client.Messages.NewStreaming(ctx, *params, callUARequestOptions(call)...) acc := anthropic.Message{} return func(yield func(fantasy.StreamPart) bool) { if len(warnings) > 0 { diff --git a/providers/anthropic/call_useragent.go b/providers/anthropic/call_useragent.go new file mode 100644 index 000000000..4aaffcf87 --- /dev/null +++ b/providers/anthropic/call_useragent.go @@ -0,0 +1,14 @@ +package anthropic + +import ( + "charm.land/fantasy" + "charm.land/fantasy/providers/internal/httpheaders" + "github.com/charmbracelet/anthropic-sdk-go/option" +) + +func callUARequestOptions(call fantasy.Call) []option.RequestOption { + if ua, ok := httpheaders.CallUserAgent(call.UserAgent); ok { + return []option.RequestOption{option.WithHeader("User-Agent", ua)} + } + return nil +} diff --git a/providers/anthropic/useragent_test.go b/providers/anthropic/useragent_test.go new file mode 100644 index 000000000..6ae7a4ec1 --- /dev/null +++ b/providers/anthropic/useragent_test.go @@ -0,0 +1,73 @@ +package anthropic + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "charm.land/fantasy" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUserAgent(t *testing.T) { + t.Parallel() + + newUAServer := func() (*httptest.Server, *[]map[string]string) { + var captured []map[string]string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := make(map[string]string) + for k, v := range r.Header { + if len(v) > 0 { + h[k] = v[0] + } + } + captured = append(captured, h) + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(mockAnthropicGenerateResponse()) + })) + return server, &captured + } + + prompt := fantasy.Prompt{ + { + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{fantasy.TextPart{Text: "Hi"}}, + }, + } + + t.Run("default UA applied", func(t *testing.T) { + t.Parallel() + server, captured := newUAServer() + defer server.Close() + + p, err := New(WithAPIKey("k"), WithBaseURL(server.URL)) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "claude-sonnet-4-20250514") + _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt}) + + require.Len(t, *captured, 1) + assert.Equal(t, "Charm Fantasy/"+fantasy.Version, (*captured)[0]["User-Agent"]) + }) + + t.Run("WithUserAgent wins over both", func(t *testing.T) { + t.Parallel() + server, captured := newUAServer() + defer server.Close() + + p, err := New( + WithAPIKey("k"), + WithBaseURL(server.URL), + WithHeaders(map[string]string{"User-Agent": "from-headers"}), + WithUserAgent("explicit-ua"), + ) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "claude-sonnet-4-20250514") + _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt}) + + require.Len(t, *captured, 1) + assert.Equal(t, "explicit-ua", (*captured)[0]["User-Agent"]) + }) +} diff --git a/providers/azure/azure.go b/providers/azure/azure.go index 68adc2f7a..a68df7251 100644 --- a/providers/azure/azure.go +++ b/providers/azure/azure.go @@ -109,6 +109,14 @@ func WithHTTPClient(client option.HTTPClient) Option { } } +// WithUserAgent sets an explicit User-Agent header, overriding the default and any +// value set via WithHeaders. +func WithUserAgent(ua string) Option { + return func(o *options) { + o.openaiOptions = append(o.openaiOptions, openai.WithUserAgent(ua)) + } +} + // WithUseResponsesAPI configures the provider to use the responses API for models that support it. func WithUseResponsesAPI() Option { return func(o *options) { diff --git a/providers/azure/useragent_test.go b/providers/azure/useragent_test.go new file mode 100644 index 000000000..c190f5db7 --- /dev/null +++ b/providers/azure/useragent_test.go @@ -0,0 +1,111 @@ +package azure + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "charm.land/fantasy" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUserAgent(t *testing.T) { + t.Parallel() + + newUAServer := func() (*httptest.Server, *[]map[string]string) { + var captured []map[string]string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := make(map[string]string) + for k, v := range r.Header { + if len(v) > 0 { + h[k] = v[0] + } + } + captured = append(captured, h) + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(mockOpenAIResponse()) + })) + return server, &captured + } + + prompt := fantasy.Prompt{ + { + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{fantasy.TextPart{Text: "Hi"}}, + }, + } + + t.Run("default UA applied", func(t *testing.T) { + t.Parallel() + server, captured := newUAServer() + defer server.Close() + + p, err := New(WithAPIKey("k"), WithBaseURL(server.URL)) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "gpt-4") + _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt}) + + require.Len(t, *captured, 1) + assert.Equal(t, "Charm Fantasy/"+fantasy.Version, (*captured)[0]["User-Agent"]) + }) + + t.Run("WithUserAgent wins over default", func(t *testing.T) { + t.Parallel() + server, captured := newUAServer() + defer server.Close() + + p, err := New(WithAPIKey("k"), WithBaseURL(server.URL), WithUserAgent("explicit-ua")) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "gpt-4") + _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt}) + + require.Len(t, *captured, 1) + assert.Equal(t, "explicit-ua", (*captured)[0]["User-Agent"]) + }) + + t.Run("WithUserAgent wins over WithHeaders", func(t *testing.T) { + t.Parallel() + server, captured := newUAServer() + defer server.Close() + + p, err := New( + WithAPIKey("k"), + WithBaseURL(server.URL), + WithHeaders(map[string]string{"User-Agent": "from-headers"}), + WithUserAgent("explicit-ua"), + ) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "gpt-4") + _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt}) + + require.Len(t, *captured, 1) + assert.Equal(t, "explicit-ua", (*captured)[0]["User-Agent"]) + }) +} + +func mockOpenAIResponse() map[string]any { + return map[string]any{ + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 1711115037, + "model": "gpt-4", + "choices": []map[string]any{ + { + "index": 0, + "message": map[string]any{ + "role": "assistant", + "content": "Hi there", + }, + "finish_reason": "stop", + }, + }, + "usage": map[string]any{ + "prompt_tokens": 4, + "total_tokens": 6, + "completion_tokens": 2, + }, + } +} diff --git a/providers/bedrock/bedrock.go b/providers/bedrock/bedrock.go index 215021c18..c8889398c 100644 --- a/providers/bedrock/bedrock.go +++ b/providers/bedrock/bedrock.go @@ -57,6 +57,14 @@ func WithHTTPClient(client option.HTTPClient) Option { } } +// WithUserAgent sets an explicit User-Agent header, overriding the default and any +// value set via WithHeaders. +func WithUserAgent(ua string) Option { + return func(o *options) { + o.anthropicOptions = append(o.anthropicOptions, anthropic.WithUserAgent(ua)) + } +} + // WithSkipAuth configures whether to skip authentication for the Bedrock provider. func WithSkipAuth(skipAuth bool) Option { return func(o *options) { diff --git a/providers/bedrock/useragent_test.go b/providers/bedrock/useragent_test.go new file mode 100644 index 000000000..d6935d6dc --- /dev/null +++ b/providers/bedrock/useragent_test.go @@ -0,0 +1,154 @@ +package bedrock + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "charm.land/fantasy" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUserAgent(t *testing.T) { + t.Parallel() + + newUAServer := func() (*httptest.Server, *[]map[string]string) { + var captured []map[string]string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := make(map[string]string) + for k, v := range r.Header { + if len(v) > 0 { + h[k] = v[0] + } + } + captured = append(captured, h) + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(mockAnthropicResponse()) + })) + return server, &captured + } + + prompt := fantasy.Prompt{ + { + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{fantasy.TextPart{Text: "Hi"}}, + }, + } + + t.Run("default UA applied", func(t *testing.T) { + t.Parallel() + server, captured := newUAServer() + defer server.Close() + + p, err := New( + WithAPIKey("k"), + WithSkipAuth(true), + WithHTTPClient(&http.Client{Transport: redirectTransport(server.URL)}), + ) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "us.anthropic.claude-sonnet-4-20250514-v1:0") + _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt}) + + require.Len(t, *captured, 1) + assert.Equal(t, "Charm Fantasy/"+fantasy.Version, (*captured)[0]["User-Agent"]) + }) + + t.Run("WithUserAgent wins over default", func(t *testing.T) { + t.Parallel() + server, captured := newUAServer() + defer server.Close() + + p, err := New( + WithAPIKey("k"), + WithSkipAuth(true), + WithHTTPClient(&http.Client{Transport: redirectTransport(server.URL)}), + WithUserAgent("explicit-ua"), + ) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "us.anthropic.claude-sonnet-4-20250514-v1:0") + _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt}) + + require.Len(t, *captured, 1) + assert.Equal(t, "explicit-ua", (*captured)[0]["User-Agent"]) + }) + + t.Run("WithUserAgent wins over WithHeaders", func(t *testing.T) { + t.Parallel() + server, captured := newUAServer() + defer server.Close() + + p, err := New( + WithAPIKey("k"), + WithSkipAuth(true), + WithHTTPClient(&http.Client{Transport: redirectTransport(server.URL)}), + WithHeaders(map[string]string{"User-Agent": "from-headers"}), + WithUserAgent("explicit-ua"), + ) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "us.anthropic.claude-sonnet-4-20250514-v1:0") + _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt}) + + require.Len(t, *captured, 1) + assert.Equal(t, "explicit-ua", (*captured)[0]["User-Agent"]) + }) +} + +type redirectRoundTripper struct { + target string +} + +func redirectTransport(target string) *redirectRoundTripper { + return &redirectRoundTripper{target: target} +} + +func (rt *redirectRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + req.URL.Scheme = "http" + req.URL.Host = (&redirectRoundTripper{target: rt.target}).host() + return http.DefaultTransport.RoundTrip(req) +} + +func (rt *redirectRoundTripper) host() string { + u := rt.target + if len(u) > 7 && u[:7] == "http://" { + return u[7:] + } + if len(u) > 8 && u[:8] == "https://" { + return u[8:] + } + return u +} + +func mockAnthropicResponse() map[string]any { + return map[string]any{ + "id": "msg_01Test", + "type": "message", + "role": "assistant", + "model": "claude-sonnet-4-20250514", + "content": []any{ + map[string]any{ + "type": "text", + "text": "Hi there", + }, + }, + "stop_reason": "end_turn", + "stop_sequence": "", + "usage": map[string]any{ + "cache_creation": map[string]any{ + "ephemeral_1h_input_tokens": 0, + "ephemeral_5m_input_tokens": 0, + }, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + "input_tokens": 5, + "output_tokens": 2, + "server_tool_use": map[string]any{ + "web_search_requests": 0, + }, + "service_tier": "standard", + }, + } +} diff --git a/providers/google/call_useragent.go b/providers/google/call_useragent.go new file mode 100644 index 000000000..a3521dd75 --- /dev/null +++ b/providers/google/call_useragent.go @@ -0,0 +1,53 @@ +package google + +import ( + "context" + "net/http" + + "charm.land/fantasy" + "charm.land/fantasy/providers/internal/httpheaders" +) + +type callUAKey struct{} + +func withCallUA(ctx context.Context, call fantasy.Call) context.Context { + if ua, ok := httpheaders.CallUserAgent(call.UserAgent); ok { + return context.WithValue(ctx, callUAKey{}, ua) + } + return ctx +} + +func withObjectCallUA(ctx context.Context, call fantasy.ObjectCall) context.Context { + if ua, ok := httpheaders.CallUserAgent(call.UserAgent); ok { + return context.WithValue(ctx, callUAKey{}, ua) + } + return ctx +} + +func wrapHTTPClient(c *http.Client) *http.Client { + if c == nil { + c = http.DefaultClient + } + transport := c.Transport + if transport == nil { + transport = http.DefaultTransport + } + return &http.Client{ + Transport: &uaTransport{base: transport}, + CheckRedirect: c.CheckRedirect, + Jar: c.Jar, + Timeout: c.Timeout, + } +} + +type uaTransport struct { + base http.RoundTripper +} + +func (t *uaTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if ua, ok := req.Context().Value(callUAKey{}).(string); ok && ua != "" { + req = req.Clone(req.Context()) + req.Header.Set("User-Agent", ua) + } + return t.base.RoundTrip(req) +} diff --git a/providers/google/google.go b/providers/google/google.go index eedc4237b..5d3f05152 100644 --- a/providers/google/google.go +++ b/providers/google/google.go @@ -14,6 +14,7 @@ import ( "charm.land/fantasy" "charm.land/fantasy/object" "charm.land/fantasy/providers/anthropic" + "charm.land/fantasy/providers/internal/httpheaders" "charm.land/fantasy/schema" "cloud.google.com/go/auth" "github.com/charmbracelet/x/exp/slice" @@ -36,6 +37,7 @@ type options struct { name string baseURL string headers map[string]string + userAgent string client *http.Client backend genai.Backend project string @@ -132,6 +134,14 @@ func WithToolCallIDFunc(f ToolCallIDFunc) Option { } } +// WithUserAgent sets an explicit User-Agent header, overriding the default and any +// value set via WithHeaders. +func WithUserAgent(ua string) Option { + return func(o *options) { + o.userAgent = ua + } +} + // WithObjectMode sets the object generation mode for the Google provider. func WithObjectMode(om fantasy.ObjectMode) Option { return func(o *options) { @@ -154,11 +164,15 @@ type languageModel struct { // LanguageModel implements fantasy.Provider. func (a *provider) LanguageModel(ctx context.Context, modelID string) (fantasy.LanguageModel, error) { if strings.Contains(modelID, "anthropic") || strings.Contains(modelID, "claude") { - p, err := anthropic.New( + anthropicOpts := []anthropic.Option{ anthropic.WithVertex(a.options.project, a.options.location), anthropic.WithHTTPClient(a.options.client), anthropic.WithSkipAuth(a.options.skipAuth), - ) + } + if a.options.userAgent != "" { + anthropicOpts = append(anthropicOpts, anthropic.WithUserAgent(a.options.userAgent)) + } + p, err := anthropic.New(anthropicOpts...) if err != nil { return nil, err } @@ -166,7 +180,7 @@ func (a *provider) LanguageModel(ctx context.Context, modelID string) (fantasy.L } cc := &genai.ClientConfig{ - HTTPClient: a.options.client, + HTTPClient: wrapHTTPClient(a.options.client), Backend: a.options.backend, APIKey: a.options.apiKey, Project: a.options.project, @@ -180,15 +194,16 @@ func (a *provider) LanguageModel(ctx context.Context, modelID string) (fantasy.L } } - if a.options.baseURL != "" || len(a.options.headers) > 0 { - headers := http.Header{} - for k, v := range a.options.headers { - headers.Add(k, v) - } - cc.HTTPOptions = genai.HTTPOptions{ - BaseURL: a.options.baseURL, - Headers: headers, - } + defaultUA := httpheaders.DefaultUserAgent(fantasy.Version) + resolved := httpheaders.ResolveHeaders(a.options.headers, a.options.userAgent, defaultUA) + + headers := http.Header{} + for k, v := range resolved { + headers.Set(k, v) + } + cc.HTTPOptions = genai.HTTPOptions{ + BaseURL: a.options.baseURL, + Headers: headers, } client, err := genai.NewClient(ctx, cc) if err != nil { @@ -530,6 +545,7 @@ func toGooglePrompt(prompt fantasy.Prompt) (*genai.Content, []*genai.Content, [] // Generate implements fantasy.LanguageModel. func (g *languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { + ctx = withCallUA(ctx, call) config, contents, warnings, err := g.prepareParams(call) if err != nil { return nil, err @@ -565,6 +581,7 @@ func (g *languageModel) Provider() string { // Stream implements fantasy.LanguageModel. func (g *languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { + ctx = withCallUA(ctx, call) config, contents, warnings, err := g.prepareParams(call) if err != nil { return nil, err @@ -891,6 +908,7 @@ func (g *languageModel) StreamObject(ctx context.Context, call fantasy.ObjectCal } func (g *languageModel) generateObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { + ctx = withObjectCallUA(ctx, call) // Convert our Schema to Google's JSON Schema format jsonSchemaMap := schema.ToMap(call.Schema) @@ -973,6 +991,7 @@ func (g *languageModel) generateObjectWithJSONMode(ctx context.Context, call fan } func (g *languageModel) streamObjectWithJSONMode(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) { + ctx = withObjectCallUA(ctx, call) // Convert our Schema to Google's JSON Schema format jsonSchemaMap := schema.ToMap(call.Schema) diff --git a/providers/google/useragent_test.go b/providers/google/useragent_test.go new file mode 100644 index 000000000..1494ff9a8 --- /dev/null +++ b/providers/google/useragent_test.go @@ -0,0 +1,146 @@ +package google + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "charm.land/fantasy" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUserAgent(t *testing.T) { + t.Parallel() + + newUAServer := func() (*httptest.Server, *[]map[string]string) { + var captured []map[string]string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := make(map[string]string) + for k, v := range r.Header { + if len(v) > 0 { + h[k] = v[0] + } + } + captured = append(captured, h) + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "candidates": []map[string]any{ + { + "content": map[string]any{ + "role": "model", + "parts": []map[string]any{ + {"text": "Hello"}, + }, + }, + "finishReason": "STOP", + }, + }, + "usageMetadata": map[string]any{ + "promptTokenCount": 5, + "candidatesTokenCount": 2, + "totalTokenCount": 7, + }, + }) + })) + return server, &captured + } + + prompt := fantasy.Prompt{ + { + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{fantasy.TextPart{Text: "Hi"}}, + }, + } + + findUA := func(captured *[]map[string]string, want string) bool { + for _, h := range *captured { + if ua, ok := h["User-Agent"]; ok && ua == want { + return true + } + } + return false + } + + t.Run("default UA applied", func(t *testing.T) { + t.Parallel() + server, captured := newUAServer() + defer server.Close() + + p, err := New( + WithVertex("test-project", "us-central1"), + WithBaseURL(server.URL), + WithSkipAuth(true), + ) + require.NoError(t, err) + model, err := p.LanguageModel(t.Context(), "gemini-2.0-flash") + require.NoError(t, err) + _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt}) + + require.NotEmpty(t, *captured) + assert.True(t, findUA(captured, "Charm Fantasy/"+fantasy.Version)) + }) + + t.Run("WithUserAgent wins over default", func(t *testing.T) { + t.Parallel() + server, captured := newUAServer() + defer server.Close() + + p, err := New( + WithVertex("test-project", "us-central1"), + WithBaseURL(server.URL), + WithSkipAuth(true), + WithUserAgent("explicit-ua"), + ) + require.NoError(t, err) + model, err := p.LanguageModel(t.Context(), "gemini-2.0-flash") + require.NoError(t, err) + _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt}) + + require.NotEmpty(t, *captured) + assert.True(t, findUA(captured, "explicit-ua")) + }) + + t.Run("WithHeaders User-Agent wins over default", func(t *testing.T) { + t.Parallel() + server, captured := newUAServer() + defer server.Close() + + p, err := New( + WithVertex("test-project", "us-central1"), + WithBaseURL(server.URL), + WithSkipAuth(true), + WithHeaders(map[string]string{"User-Agent": "custom-from-headers"}), + ) + require.NoError(t, err) + model, err := p.LanguageModel(t.Context(), "gemini-2.0-flash") + require.NoError(t, err) + _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt}) + + require.NotEmpty(t, *captured) + assert.True(t, findUA(captured, "custom-from-headers")) + }) + + t.Run("WithUserAgent wins over WithHeaders", func(t *testing.T) { + t.Parallel() + server, captured := newUAServer() + defer server.Close() + + p, err := New( + WithVertex("test-project", "us-central1"), + WithBaseURL(server.URL), + WithSkipAuth(true), + WithHeaders(map[string]string{"User-Agent": "from-headers"}), + WithUserAgent("explicit-ua"), + ) + require.NoError(t, err) + model, err := p.LanguageModel(t.Context(), "gemini-2.0-flash") + require.NoError(t, err) + _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt}) + + require.NotEmpty(t, *captured) + assert.True(t, findUA(captured, "explicit-ua")) + }) +} diff --git a/providers/internal/httpheaders/httpheaders.go b/providers/internal/httpheaders/httpheaders.go new file mode 100644 index 000000000..0bda5f569 --- /dev/null +++ b/providers/internal/httpheaders/httpheaders.go @@ -0,0 +1,57 @@ +// Package httpheaders provides shared User-Agent resolution for all HTTP-based providers. +package httpheaders + +import "strings" + +// DefaultUserAgent returns the default User-Agent string for the SDK. +// The result is "Charm Fantasy/". +func DefaultUserAgent(version string) string { + return "Charm Fantasy/" + version +} + +// ResolveHeaders returns a new header map, with a User-Agent field. +// +// Setting the value via WithUserAgent() takes precedence, however the user +// agent can also be set via HTTP headers (i.e. WithHeaders()). Otherwise, the +// default user agent will be used, i.e. Charm Fantasy/0.11.0. +// +// Also note that the input map is never mutated. +func ResolveHeaders(headers map[string]string, explicitUA, defaultUA string) map[string]string { + out := make(map[string]string, len(headers)+1) + var uaKeys []string + + for k, v := range headers { + out[k] = v + if strings.EqualFold(k, "User-Agent") { + uaKeys = append(uaKeys, k) + } + } + + switch { + case explicitUA != "": + for _, k := range uaKeys { + delete(out, k) + } + out["User-Agent"] = explicitUA + case len(uaKeys) > 0: + val := out[uaKeys[0]] + for _, k := range uaKeys { + delete(out, k) + } + out["User-Agent"] = val + default: + out["User-Agent"] = defaultUA + } + + return out +} + +// CallUserAgent resolves the User-Agent for a single API call. It returns the +// resolved UA string and true if a per-call override should be applied, or +// empty string and false if the client-level UA should be used as-is. +func CallUserAgent(callUA string) (string, bool) { + if callUA != "" { + return callUA, true + } + return "", false +} diff --git a/providers/internal/httpheaders/httpheaders_test.go b/providers/internal/httpheaders/httpheaders_test.go new file mode 100644 index 000000000..d40158ab0 --- /dev/null +++ b/providers/internal/httpheaders/httpheaders_test.go @@ -0,0 +1,146 @@ +package httpheaders + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDefaultUserAgent(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + version string + want string + }{ + {name: "basic version", version: "0.11.0", want: "Charm Fantasy/0.11.0"}, + {name: "another version", version: "1.0.0", want: "Charm Fantasy/1.0.0"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := DefaultUserAgent(tt.version) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestResolveHeaders_Precedence(t *testing.T) { + t.Parallel() + + t.Run("explicit UA wins over headers and default", func(t *testing.T) { + t.Parallel() + headers := map[string]string{"User-Agent": "from-headers"} + got := ResolveHeaders(headers, "explicit-ua", "default-ua") + assert.Equal(t, "explicit-ua", got["User-Agent"]) + }) + + t.Run("header UA wins over default", func(t *testing.T) { + t.Parallel() + headers := map[string]string{"User-Agent": "from-headers"} + got := ResolveHeaders(headers, "", "default-ua") + assert.Equal(t, "from-headers", got["User-Agent"]) + }) + + t.Run("default UA used when nothing else set", func(t *testing.T) { + t.Parallel() + got := ResolveHeaders(nil, "", "default-ua") + assert.Equal(t, "default-ua", got["User-Agent"]) + }) + + t.Run("explicit UA wins over case-insensitive header key", func(t *testing.T) { + t.Parallel() + headers := map[string]string{"user-agent": "from-headers"} + got := ResolveHeaders(headers, "explicit-ua", "default-ua") + assert.Equal(t, "explicit-ua", got["User-Agent"]) + _, hasLower := got["user-agent"] + assert.False(t, hasLower, "old case-insensitive key should be removed") + }) + + t.Run("case-insensitive header key canonicalized when no explicit UA", func(t *testing.T) { + t.Parallel() + headers := map[string]string{"user-agent": "from-headers"} + got := ResolveHeaders(headers, "", "default-ua") + assert.Equal(t, "from-headers", got["User-Agent"]) + _, hasLower := got["user-agent"] + assert.False(t, hasLower, "non-canonical key should be removed") + }) +} + +func TestResolveHeaders_NoMutation(t *testing.T) { + t.Parallel() + + original := map[string]string{"X-Custom": "value"} + _ = ResolveHeaders(original, "explicit", "default") + + _, hasUA := original["User-Agent"] + require.False(t, hasUA, "input map must not be mutated") + assert.Equal(t, "value", original["X-Custom"]) +} + +func TestResolveHeaders_PreservesOtherHeaders(t *testing.T) { + t.Parallel() + + headers := map[string]string{ + "X-Custom": "custom-value", + "Authorization": "Bearer token", + } + got := ResolveHeaders(headers, "", "Charm Fantasy/1.0.0") + assert.Equal(t, "custom-value", got["X-Custom"]) + assert.Equal(t, "Bearer token", got["Authorization"]) + assert.Equal(t, "Charm Fantasy/1.0.0", got["User-Agent"]) +} + +func TestResolveHeaders_DuplicateCaseInsensitiveKeys(t *testing.T) { + t.Parallel() + + t.Run("explicit UA removes all variants", func(t *testing.T) { + t.Parallel() + headers := map[string]string{ + "User-Agent": "canonical", + "user-agent": "lowercase", + } + got := ResolveHeaders(headers, "explicit", "default") + assert.Equal(t, "explicit", got["User-Agent"]) + _, hasLower := got["user-agent"] + assert.False(t, hasLower, "all case-insensitive UA keys must be removed") + }) + + t.Run("no explicit UA collapses to single canonical key", func(t *testing.T) { + t.Parallel() + headers := map[string]string{ + "User-Agent": "canonical", + "user-agent": "lowercase", + } + got := ResolveHeaders(headers, "", "default") + _, hasLower := got["user-agent"] + hasCanonical := got["User-Agent"] + assert.False(t, hasLower, "non-canonical key should be removed") + assert.NotEmpty(t, hasCanonical, "canonical User-Agent key must exist") + }) +} + +func TestCallUserAgent(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + callUA string + wantUA string + wantOK bool + }{ + {name: "no override", callUA: "", wantUA: "", wantOK: false}, + {name: "explicit UA", callUA: "MyAgent/1.0", wantUA: "MyAgent/1.0", wantOK: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ua, ok := CallUserAgent(tt.callUA) + assert.Equal(t, tt.wantOK, ok) + assert.Equal(t, tt.wantUA, ua) + }) + } +} diff --git a/providers/openai/call_useragent.go b/providers/openai/call_useragent.go new file mode 100644 index 000000000..4c7f2a4f7 --- /dev/null +++ b/providers/openai/call_useragent.go @@ -0,0 +1,25 @@ +package openai + +import ( + "charm.land/fantasy" + "charm.land/fantasy/providers/internal/httpheaders" + "github.com/openai/openai-go/v2/option" +) + +// callUARequestOptions returns per-request options that override the +// client-level User-Agent header when the Call carries agent-level UA settings. +func callUARequestOptions(call fantasy.Call) []option.RequestOption { + if ua, ok := httpheaders.CallUserAgent(call.UserAgent); ok { + return []option.RequestOption{option.WithHeader("User-Agent", ua)} + } + return nil +} + +// objectCallUARequestOptions returns per-request options that override the +// client-level User-Agent header when the ObjectCall carries agent-level UA settings. +func objectCallUARequestOptions(call fantasy.ObjectCall) []option.RequestOption { + if ua, ok := httpheaders.CallUserAgent(call.UserAgent); ok { + return []option.RequestOption{option.WithHeader("User-Agent", ua)} + } + return nil +} diff --git a/providers/openai/language_model.go b/providers/openai/language_model.go index 9df357ac8..ae3d87649 100644 --- a/providers/openai/language_model.go +++ b/providers/openai/language_model.go @@ -246,7 +246,7 @@ func (o languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantas if err != nil { return nil, err } - response, err := o.client.Chat.Completions.New(ctx, *params) + response, err := o.client.Chat.Completions.New(ctx, *params, callUARequestOptions(call)...) if err != nil { return nil, toProviderErr(err) } @@ -314,7 +314,7 @@ func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S IncludeUsage: openai.Bool(true), } - stream := o.client.Chat.Completions.NewStreaming(ctx, *params) + stream := o.client.Chat.Completions.NewStreaming(ctx, *params, callUARequestOptions(call)...) isActiveText := false toolCalls := make(map[int64]streamToolCall) @@ -733,11 +733,10 @@ func (o languageModel) generateObjectWithJSONMode(ctx context.Context, call fant }, } - response, err := o.client.Chat.Completions.New(ctx, *params) + response, err := o.client.Chat.Completions.New(ctx, *params, objectCallUARequestOptions(call)...) if err != nil { return nil, toProviderErr(err) } - if len(response.Choices) == 0 { usage, _ := o.usageFunc(*response) return nil, &fantasy.NoObjectGeneratedError{ @@ -818,7 +817,7 @@ func (o languageModel) streamObjectWithJSONMode(ctx context.Context, call fantas IncludeUsage: openai.Bool(true), } - stream := o.client.Chat.Completions.NewStreaming(ctx, *params) + stream := o.client.Chat.Completions.NewStreaming(ctx, *params, objectCallUARequestOptions(call)...) return func(yield func(fantasy.ObjectStreamPart) bool) { if len(warnings) > 0 { diff --git a/providers/openai/openai.go b/providers/openai/openai.go index 7ca74b9c7..928f3dd28 100644 --- a/providers/openai/openai.go +++ b/providers/openai/openai.go @@ -7,6 +7,7 @@ import ( "maps" "charm.land/fantasy" + "charm.land/fantasy/providers/internal/httpheaders" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/option" ) @@ -30,6 +31,7 @@ type options struct { name string useResponsesAPI bool headers map[string]string + userAgent string client option.HTTPClient sdkOptions []option.RequestOption objectMode fantasy.ObjectMode @@ -132,6 +134,14 @@ func WithUseResponsesAPI() Option { } } +// WithUserAgent sets an explicit User-Agent header, overriding the default and any +// value set via WithHeaders. +func WithUserAgent(ua string) Option { + return func(o *options) { + o.userAgent = ua + } +} + // WithObjectMode sets the object generation mode. func WithObjectMode(om fantasy.ObjectMode) Option { return func(o *options) { @@ -155,7 +165,9 @@ func (o *provider) LanguageModel(_ context.Context, modelID string) (fantasy.Lan openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL)) } - for key, value := range o.options.headers { + defaultUA := httpheaders.DefaultUserAgent(fantasy.Version) + resolved := httpheaders.ResolveHeaders(o.options.headers, o.options.userAgent, defaultUA) + for key, value := range resolved { openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value)) } diff --git a/providers/openai/openai_test.go b/providers/openai/openai_test.go index 7592b18b2..3e56aeb09 100644 --- a/providers/openai/openai_test.go +++ b/providers/openai/openai_test.go @@ -12,6 +12,7 @@ import ( "charm.land/fantasy" "github.com/openai/openai-go/v2/packages/param" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -3302,3 +3303,147 @@ func TestParseContextTooLargeError(t *testing.T) { }) } } + +func TestUserAgent(t *testing.T) { + t.Parallel() + + t.Run("default UA applied", func(t *testing.T) { + t.Parallel() + + server := newMockServer() + defer server.close() + server.prepareJSONResponse(map[string]any{}) + + p, err := New(WithAPIKey("k"), WithBaseURL(server.server.URL)) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "gpt-4") + _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: testPrompt}) + + require.Len(t, server.calls, 1) + assert.Equal(t, "Charm Fantasy/"+fantasy.Version, server.calls[0].headers["User-Agent"]) + }) + + t.Run("WithHeaders User-Agent wins over default", func(t *testing.T) { + t.Parallel() + + server := newMockServer() + defer server.close() + server.prepareJSONResponse(map[string]any{}) + + p, err := New(WithAPIKey("k"), WithBaseURL(server.server.URL), WithHeaders(map[string]string{"User-Agent": "custom-from-headers"})) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "gpt-4") + _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: testPrompt}) + + require.Len(t, server.calls, 1) + assert.Equal(t, "custom-from-headers", server.calls[0].headers["User-Agent"]) + }) + + t.Run("WithUserAgent wins over both", func(t *testing.T) { + t.Parallel() + + server := newMockServer() + defer server.close() + server.prepareJSONResponse(map[string]any{}) + + p, err := New( + WithAPIKey("k"), + WithBaseURL(server.server.URL), + WithHeaders(map[string]string{"User-Agent": "from-headers"}), + WithUserAgent("explicit-ua"), + ) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "gpt-4") + _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: testPrompt}) + + require.Len(t, server.calls, 1) + assert.Equal(t, "explicit-ua", server.calls[0].headers["User-Agent"]) + }) + + t.Run("Call.UserAgent overrides provider WithHeaders UA", func(t *testing.T) { + t.Parallel() + + server := newMockServer() + defer server.close() + server.prepareJSONResponse(map[string]any{}) + + p, err := New( + WithAPIKey("k"), + WithBaseURL(server.server.URL), + WithHeaders(map[string]string{"User-Agent": "header-ua"}), + ) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "gpt-4") + _, _ = model.Generate(t.Context(), fantasy.Call{ + Prompt: testPrompt, + UserAgent: "call-level-ua", + }) + + require.Len(t, server.calls, 1) + assert.Equal(t, "call-level-ua", server.calls[0].headers["User-Agent"]) + }) + + t.Run("no Call UA falls through to provider UA", func(t *testing.T) { + t.Parallel() + + server := newMockServer() + defer server.close() + server.prepareJSONResponse(map[string]any{}) + + p, err := New( + WithAPIKey("k"), + WithBaseURL(server.server.URL), + WithUserAgent("provider-ua"), + ) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "gpt-4") + _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: testPrompt}) + + require.Len(t, server.calls, 1) + assert.Equal(t, "provider-ua", server.calls[0].headers["User-Agent"]) + }) + + t.Run("agent WithUserAgent overrides provider UA end-to-end", func(t *testing.T) { + t.Parallel() + + server := newMockServer() + defer server.close() + server.prepareJSONResponse(map[string]any{}) + + p, err := New( + WithAPIKey("k"), + WithBaseURL(server.server.URL), + WithUserAgent("provider-ua"), + ) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "gpt-4") + + agent := fantasy.NewAgent(model, fantasy.WithUserAgent("agent-ua")) + _, _ = agent.Generate(t.Context(), fantasy.AgentCall{Prompt: "hi"}) + + require.Len(t, server.calls, 1) + assert.Equal(t, "agent-ua", server.calls[0].headers["User-Agent"]) + }) + + t.Run("agent without UA falls through to provider UA end-to-end", func(t *testing.T) { + t.Parallel() + + server := newMockServer() + defer server.close() + server.prepareJSONResponse(map[string]any{}) + + p, err := New( + WithAPIKey("k"), + WithBaseURL(server.server.URL), + WithUserAgent("provider-ua"), + ) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "gpt-4") + + agent := fantasy.NewAgent(model) + _, _ = agent.Generate(t.Context(), fantasy.AgentCall{Prompt: "hi"}) + + require.Len(t, server.calls, 1) + assert.Equal(t, "provider-ua", server.calls[0].headers["User-Agent"]) + }) +} diff --git a/providers/openai/responses_language_model.go b/providers/openai/responses_language_model.go index 39ee8e427..090117de0 100644 --- a/providers/openai/responses_language_model.go +++ b/providers/openai/responses_language_model.go @@ -668,7 +668,7 @@ func toResponsesTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice, opti func (o responsesLanguageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { params, warnings := o.prepareParams(call) - response, err := o.client.Responses.New(ctx, *params) + response, err := o.client.Responses.New(ctx, *params, callUARequestOptions(call)...) if err != nil { return nil, toProviderErr(err) } @@ -806,7 +806,7 @@ func mapResponsesFinishReason(reason string, hasFunctionCall bool) fantasy.Finis func (o responsesLanguageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { params, warnings := o.prepareParams(call) - stream := o.client.Responses.NewStreaming(ctx, *params) + stream := o.client.Responses.NewStreaming(ctx, *params, callUARequestOptions(call)...) finishReason := fantasy.FinishReasonUnknown var usage fantasy.Usage @@ -1106,7 +1106,7 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context, } // Make request - response, err := o.client.Responses.New(ctx, *params) + response, err := o.client.Responses.New(ctx, *params, objectCallUARequestOptions(call)...) if err != nil { return nil, toProviderErr(err) } @@ -1216,7 +1216,7 @@ func (o responsesLanguageModel) streamObjectWithJSONMode(ctx context.Context, ca Format: responses.ResponseFormatTextConfigParamOfJSONSchema(schemaName, jsonSchemaMap), } - stream := o.client.Responses.NewStreaming(ctx, *params) + stream := o.client.Responses.NewStreaming(ctx, *params, objectCallUARequestOptions(call)...) return func(yield func(fantasy.ObjectStreamPart) bool) { if len(warnings) > 0 { diff --git a/providers/openaicompat/openaicompat.go b/providers/openaicompat/openaicompat.go index 3595a6e42..102967634 100644 --- a/providers/openaicompat/openaicompat.go +++ b/providers/openaicompat/openaicompat.go @@ -108,6 +108,14 @@ func WithObjectMode(om fantasy.ObjectMode) Option { } } +// WithUserAgent sets an explicit User-Agent header, overriding the default and any +// value set via WithHeaders. +func WithUserAgent(ua string) Option { + return func(o *options) { + o.openaiOptions = append(o.openaiOptions, openai.WithUserAgent(ua)) + } +} + // WithUseResponsesAPI configures the provider to use the responses API for models that support it. func WithUseResponsesAPI() Option { return func(o *options) { diff --git a/providers/openrouter/openrouter.go b/providers/openrouter/openrouter.go index bd0e700af..0b3e51d24 100644 --- a/providers/openrouter/openrouter.go +++ b/providers/openrouter/openrouter.go @@ -89,6 +89,14 @@ func WithHTTPClient(client option.HTTPClient) Option { } } +// WithUserAgent sets an explicit User-Agent header, overriding the default and any +// value set via WithHeaders. +func WithUserAgent(ua string) Option { + return func(o *options) { + o.openaiOptions = append(o.openaiOptions, openai.WithUserAgent(ua)) + } +} + // WithObjectMode sets the object generation mode for the OpenRouter provider. // Supported modes: ObjectModeTool, ObjectModeText. // ObjectModeAuto and ObjectModeJSON are automatically converted to ObjectModeTool diff --git a/providers/openrouter/useragent_test.go b/providers/openrouter/useragent_test.go new file mode 100644 index 000000000..15ff2d5e4 --- /dev/null +++ b/providers/openrouter/useragent_test.go @@ -0,0 +1,118 @@ +package openrouter + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "charm.land/fantasy" + "charm.land/fantasy/providers/openai" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUserAgent(t *testing.T) { + t.Parallel() + + newUAServer := func() (*httptest.Server, *[]map[string]string) { + var captured []map[string]string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := make(map[string]string) + for k, v := range r.Header { + if len(v) > 0 { + h[k] = v[0] + } + } + captured = append(captured, h) + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(mockOpenAIResponse()) + })) + return server, &captured + } + + withBaseURL := func(url string) Option { + return func(o *options) { + o.openaiOptions = append(o.openaiOptions, openai.WithBaseURL(url)) + } + } + + prompt := fantasy.Prompt{ + { + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{fantasy.TextPart{Text: "Hi"}}, + }, + } + + t.Run("default UA applied", func(t *testing.T) { + t.Parallel() + server, captured := newUAServer() + defer server.Close() + + p, err := New(WithAPIKey("k"), withBaseURL(server.URL)) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "openai/gpt-4") + _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt}) + + require.Len(t, *captured, 1) + assert.Equal(t, "Charm Fantasy/"+fantasy.Version, (*captured)[0]["User-Agent"]) + }) + + t.Run("WithUserAgent wins over default", func(t *testing.T) { + t.Parallel() + server, captured := newUAServer() + defer server.Close() + + p, err := New(WithAPIKey("k"), withBaseURL(server.URL), WithUserAgent("explicit-ua")) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "openai/gpt-4") + _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt}) + + require.Len(t, *captured, 1) + assert.Equal(t, "explicit-ua", (*captured)[0]["User-Agent"]) + }) + + t.Run("WithUserAgent wins over WithHeaders", func(t *testing.T) { + t.Parallel() + server, captured := newUAServer() + defer server.Close() + + p, err := New( + WithAPIKey("k"), + withBaseURL(server.URL), + WithHeaders(map[string]string{"User-Agent": "from-headers"}), + WithUserAgent("explicit-ua"), + ) + require.NoError(t, err) + model, _ := p.LanguageModel(t.Context(), "openai/gpt-4") + _, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt}) + + require.Len(t, *captured, 1) + assert.Equal(t, "explicit-ua", (*captured)[0]["User-Agent"]) + }) +} + +func mockOpenAIResponse() map[string]any { + return map[string]any{ + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 1711115037, + "model": "openai/gpt-4", + "choices": []map[string]any{ + { + "index": 0, + "message": map[string]any{ + "role": "assistant", + "content": "Hi there", + }, + "finish_reason": "stop", + }, + }, + "usage": map[string]any{ + "prompt_tokens": 4, + "total_tokens": 6, + "completion_tokens": 2, + }, + } +} diff --git a/providers/vercel/vercel.go b/providers/vercel/vercel.go index af87db5eb..43712c1ac 100644 --- a/providers/vercel/vercel.go +++ b/providers/vercel/vercel.go @@ -96,6 +96,14 @@ func WithHTTPClient(client option.HTTPClient) Option { } } +// WithUserAgent sets an explicit User-Agent header, overriding the default and any +// value set via WithHeaders. +func WithUserAgent(ua string) Option { + return func(o *options) { + o.openaiOptions = append(o.openaiOptions, openai.WithUserAgent(ua)) + } +} + // WithSDKOptions sets the SDK options for the Vercel provider. func WithSDKOptions(opts ...option.RequestOption) Option { return func(o *options) { diff --git a/version.go b/version.go new file mode 100644 index 000000000..540d1d6de --- /dev/null +++ b/version.go @@ -0,0 +1,12 @@ +package fantasy + +import ( + _ "embed" + "strings" +) + +//go:embed version.txt +var version string + +// Version is the SDK version, read from version.txt. +var Version = strings.TrimSpace(version) diff --git a/version.txt b/version.txt new file mode 100644 index 000000000..d9df1bbc0 --- /dev/null +++ b/version.txt @@ -0,0 +1 @@ +0.11.0