Skip to content
4 changes: 3 additions & 1 deletion Taskfile.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ tasks:
- sh: "[ $(git status --porcelain=2 | wc -l) = 0 ]"
msg: "Git is dirty"
cmds:
- git commit --allow-empty -m "{{.NEXT}}"
- echo {{trimPrefix "v" .NEXT}} > version.txt
- git add version.txt
- git commit -m "{{.NEXT}}"
- git tag --annotate --sign -m "{{.NEXT}}" {{.NEXT}} {{.CLI_ARGS}}
- echo "Pushing {{.NEXT}}..."
- git push origin main --follow-tags
11 changes: 11 additions & 0 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ type agentSettings struct {
presencePenalty *float64
frequencyPenalty *float64
headers map[string]string
userAgent string
providerOptions ProviderOptions

// TODO: add support for provider tools
Expand Down Expand Up @@ -448,6 +449,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
FrequencyPenalty: opts.FrequencyPenalty,
Tools: preparedTools,
ToolChoice: &stepToolChoice,
UserAgent: a.settings.userAgent,
ProviderOptions: opts.ProviderOptions,
})
})
Expand Down Expand Up @@ -829,6 +831,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
FrequencyPenalty: call.FrequencyPenalty,
Tools: preparedTools,
ToolChoice: &stepToolChoice,
UserAgent: a.settings.userAgent,
ProviderOptions: call.ProviderOptions,
}

Expand Down Expand Up @@ -1418,6 +1421,14 @@ func WithHeaders(headers map[string]string) AgentOption {
}
}

// WithUserAgent sets the User-Agent header for the agent. This overrides any
// provider-level User-Agent setting.
func WithUserAgent(ua string) AgentOption {
return func(s *agentSettings) {
s.userAgent = ua
}
}

// WithProviderOptions sets the provider options for the agent.
func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
return func(s *agentSettings) {
Expand Down
71 changes: 71 additions & 0 deletions agent_useragent_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package fantasy

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestAgent_WithUserAgent_PropagatesOnGenerate(t *testing.T) {
t.Parallel()

var capturedCall Call
model := &mockLanguageModel{
generateFunc: func(_ context.Context, call Call) (*Response, error) {
capturedCall = call
return &Response{
Content: []Content{TextContent{Text: "ok"}},
FinishReason: FinishReasonStop,
}, nil
},
}

agent := NewAgent(model, WithUserAgent("MyApp/2.0"))
_, err := agent.Generate(context.Background(), AgentCall{Prompt: "hi"})
require.NoError(t, err)
assert.Equal(t, "MyApp/2.0", capturedCall.UserAgent)
}

func TestAgent_WithUserAgent_PropagatesOnStream(t *testing.T) {
t.Parallel()

var capturedCall Call
model := &mockLanguageModel{
streamFunc: func(_ context.Context, call Call) (StreamResponse, error) {
capturedCall = call
return func(yield func(StreamPart) bool) {
yield(StreamPart{
Type: StreamPartTypeFinish,
FinishReason: FinishReasonStop,
})
}, nil
},
}

agent := NewAgent(model, WithUserAgent("StreamApp/1.0"))
_, err := agent.Stream(context.Background(), AgentStreamCall{Prompt: "hi"})
require.NoError(t, err)
assert.Equal(t, "StreamApp/1.0", capturedCall.UserAgent)
}

func TestAgent_NoUA_OmitsCallLevelFields(t *testing.T) {
t.Parallel()

var capturedCall Call
model := &mockLanguageModel{
generateFunc: func(_ context.Context, call Call) (*Response, error) {
capturedCall = call
return &Response{
Content: []Content{TextContent{Text: "ok"}},
FinishReason: FinishReasonStop,
}, nil
},
}

agent := NewAgent(model)
_, err := agent.Generate(context.Background(), AgentCall{Prompt: "hi"})
require.NoError(t, err)
assert.Empty(t, capturedCall.UserAgent)
}
3 changes: 3 additions & 0 deletions model.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,9 @@ type Call struct {
Tools []Tool `json:"tools"`
ToolChoice *ToolChoice `json:"tool_choice"`

// UserAgent overrides the provider-level User-Agent header for this call.
UserAgent string `json:"-"`

// for provider specific options, the key is the provider id
ProviderOptions ProviderOptions `json:"provider_options"`
}
Expand Down
3 changes: 3 additions & 0 deletions object.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ type ObjectCall struct {
PresencePenalty *float64
FrequencyPenalty *float64

// UserAgent overrides the provider-level User-Agent header for this call.
UserAgent string `json:"-"`

ProviderOptions ProviderOptions

RepairText schema.ObjectRepairFunc
Expand Down
4 changes: 4 additions & 0 deletions object/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ func GenerateWithTool(
TopK: call.TopK,
PresencePenalty: call.PresencePenalty,
FrequencyPenalty: call.FrequencyPenalty,
UserAgent: call.UserAgent,
ProviderOptions: call.ProviderOptions,
})
if err != nil {
Expand Down Expand Up @@ -212,6 +213,7 @@ func GenerateWithText(
TopK: call.TopK,
PresencePenalty: call.PresencePenalty,
FrequencyPenalty: call.FrequencyPenalty,
UserAgent: call.UserAgent,
ProviderOptions: call.ProviderOptions,
})
if err != nil {
Expand Down Expand Up @@ -294,6 +296,7 @@ func StreamWithTool(
TopK: call.TopK,
PresencePenalty: call.PresencePenalty,
FrequencyPenalty: call.FrequencyPenalty,
UserAgent: call.UserAgent,
ProviderOptions: call.ProviderOptions,
})
if err != nil {
Expand Down Expand Up @@ -503,6 +506,7 @@ func StreamWithText(
TopK: call.TopK,
PresencePenalty: call.PresencePenalty,
FrequencyPenalty: call.FrequencyPenalty,
UserAgent: call.UserAgent,
ProviderOptions: call.ProviderOptions,
})
if err != nil {
Expand Down
28 changes: 20 additions & 8 deletions providers/anthropic/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

"charm.land/fantasy"
"charm.land/fantasy/object"
"charm.land/fantasy/providers/internal/httpheaders"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/charmbracelet/anthropic-sdk-go"
"github.com/charmbracelet/anthropic-sdk-go/bedrock"
Expand All @@ -31,11 +32,12 @@ const (
)

type options struct {
baseURL string
apiKey string
name string
headers map[string]string
client option.HTTPClient
baseURL string
apiKey string
name string
headers map[string]string
userAgent string
client option.HTTPClient

vertexProject string
vertexLocation string
Expand Down Expand Up @@ -125,6 +127,14 @@ func WithHTTPClient(client option.HTTPClient) Option {
}
}

// WithUserAgent sets an explicit User-Agent header, overriding the default and any
// value set via WithHeaders.
func WithUserAgent(ua string) Option {
return func(o *options) {
o.userAgent = ua
}
}

// WithObjectMode sets the object generation mode.
func WithObjectMode(om fantasy.ObjectMode) Option {
return func(o *options) {
Expand All @@ -146,7 +156,9 @@ func (a *provider) LanguageModel(ctx context.Context, modelID string) (fantasy.L
if a.options.baseURL != "" {
clientOptions = append(clientOptions, option.WithBaseURL(a.options.baseURL))
}
for key, value := range a.options.headers {
defaultUA := httpheaders.DefaultUserAgent(fantasy.Version)
resolved := httpheaders.ResolveHeaders(a.options.headers, a.options.userAgent, defaultUA)
for key, value := range resolved {
clientOptions = append(clientOptions, option.WithHeader(key, value))
}
if a.options.client != nil {
Expand Down Expand Up @@ -771,7 +783,7 @@ func (a languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantas
if err != nil {
return nil, err
}
response, err := a.client.Messages.New(ctx, *params)
response, err := a.client.Messages.New(ctx, *params, callUARequestOptions(call)...)
if err != nil {
return nil, toProviderErr(err)
}
Expand Down Expand Up @@ -849,7 +861,7 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S
return nil, err
}

stream := a.client.Messages.NewStreaming(ctx, *params)
stream := a.client.Messages.NewStreaming(ctx, *params, callUARequestOptions(call)...)
acc := anthropic.Message{}
return func(yield func(fantasy.StreamPart) bool) {
if len(warnings) > 0 {
Expand Down
14 changes: 14 additions & 0 deletions providers/anthropic/call_useragent.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package anthropic

import (
"charm.land/fantasy"
"charm.land/fantasy/providers/internal/httpheaders"
"github.com/charmbracelet/anthropic-sdk-go/option"
)

func callUARequestOptions(call fantasy.Call) []option.RequestOption {
if ua, ok := httpheaders.CallUserAgent(call.UserAgent); ok {
return []option.RequestOption{option.WithHeader("User-Agent", ua)}
}
return nil
}
73 changes: 73 additions & 0 deletions providers/anthropic/useragent_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package anthropic

import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"

"charm.land/fantasy"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestUserAgent(t *testing.T) {
t.Parallel()

newUAServer := func() (*httptest.Server, *[]map[string]string) {
var captured []map[string]string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := make(map[string]string)
for k, v := range r.Header {
if len(v) > 0 {
h[k] = v[0]
}
}
captured = append(captured, h)

w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(mockAnthropicGenerateResponse())
}))
return server, &captured
}

prompt := fantasy.Prompt{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "Hi"}},
},
}

t.Run("default UA applied", func(t *testing.T) {
t.Parallel()
server, captured := newUAServer()
defer server.Close()

p, err := New(WithAPIKey("k"), WithBaseURL(server.URL))
require.NoError(t, err)
model, _ := p.LanguageModel(t.Context(), "claude-sonnet-4-20250514")
_, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt})

require.Len(t, *captured, 1)
assert.Equal(t, "Charm Fantasy/"+fantasy.Version, (*captured)[0]["User-Agent"])
})

t.Run("WithUserAgent wins over both", func(t *testing.T) {
t.Parallel()
server, captured := newUAServer()
defer server.Close()

p, err := New(
WithAPIKey("k"),
WithBaseURL(server.URL),
WithHeaders(map[string]string{"User-Agent": "from-headers"}),
WithUserAgent("explicit-ua"),
)
require.NoError(t, err)
model, _ := p.LanguageModel(t.Context(), "claude-sonnet-4-20250514")
_, _ = model.Generate(t.Context(), fantasy.Call{Prompt: prompt})

require.Len(t, *captured, 1)
assert.Equal(t, "explicit-ua", (*captured)[0]["User-Agent"])
})
}
8 changes: 8 additions & 0 deletions providers/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,14 @@ func WithHTTPClient(client option.HTTPClient) Option {
}
}

// WithUserAgent sets an explicit User-Agent header, overriding the default and any
// value set via WithHeaders.
func WithUserAgent(ua string) Option {
return func(o *options) {
o.openaiOptions = append(o.openaiOptions, openai.WithUserAgent(ua))
}
}

// WithUseResponsesAPI configures the provider to use the responses API for models that support it.
func WithUseResponsesAPI() Option {
return func(o *options) {
Expand Down
Loading
Loading