From adea7eee96ef565dc802377297791ebcbc0c48ce Mon Sep 17 00:00:00 2001 From: Aero Date: Sun, 28 Dec 2025 14:39:37 +0800 Subject: [PATCH 1/6] feat(xiaomi): add Xiaomi provider implementation and tests --- providers/xiaomi/xiaomi.go | 94 +++++++++++++++++++++++++++++ providers/xiaomi/xiaomi_test.go | 102 ++++++++++++++++++++++++++++++++ 2 files changed, 196 insertions(+) create mode 100644 providers/xiaomi/xiaomi.go create mode 100644 providers/xiaomi/xiaomi_test.go diff --git a/providers/xiaomi/xiaomi.go b/providers/xiaomi/xiaomi.go new file mode 100644 index 000000000..65fbb75e3 --- /dev/null +++ b/providers/xiaomi/xiaomi.go @@ -0,0 +1,94 @@ +// Package xiaomi provides a fantasy.Provider for Xiaomi API. +package xiaomi + +import ( + "net/http" + + "charm.land/fantasy" + "charm.land/fantasy/providers/openaicompat" + openaisdk "github.com/openai/openai-go/v2/option" +) + +const ( + // Name is the provider type name for Xiaomi. + Name = "xiaomi" +) + +type options struct { + baseURL string + apiKey string + headers map[string]string + httpClient *http.Client + extraBody map[string]any +} + +// Option configures the Xiaomi provider. +type Option = func(*options) + +// WithBaseURL sets the base URL for the Xiaomi provider. +func WithBaseURL(baseURL string) Option { + return func(o *options) { + o.baseURL = baseURL + } +} + +// WithAPIKey sets the API key for the Xiaomi provider. +func WithAPIKey(apiKey string) Option { + return func(o *options) { + o.apiKey = apiKey + } +} + +// WithHeaders sets the headers for the Xiaomi provider. +func WithHeaders(headers map[string]string) Option { + return func(o *options) { + o.headers = headers + } +} + +// WithHTTPClient sets the HTTP client for the Xiaomi provider. +func WithHTTPClient(httpClient *http.Client) Option { + return func(o *options) { + o.httpClient = httpClient + } +} + +// WithExtraBody sets the extra body parameters for the Xiaomi provider. +func WithExtraBody(extraBody map[string]any) Option { + return func(o *options) { + o.extraBody = extraBody + } +} + +// New creates a new Xiaomi provider. +func New(opts ...Option) (fantasy.Provider, error) { + o := options{ + baseURL: "https://api.xiaomimimo.com/v1", + headers: make(map[string]string), + extraBody: make(map[string]any), + } + for _, opt := range opts { + opt(&o) + } + + // Build OpenAI-compatible provider with Xiaomi-specific configuration + openaiOpts := []openaicompat.Option{ + openaicompat.WithBaseURL(o.baseURL), + openaicompat.WithAPIKey(o.apiKey), + } + + if len(o.headers) > 0 { + openaiOpts = append(openaiOpts, openaicompat.WithHeaders(o.headers)) + } + + if o.httpClient != nil { + openaiOpts = append(openaiOpts, openaicompat.WithHTTPClient(o.httpClient)) + } + + // Xiaomi thinking logic is handled via extraBody passed to WithSDKOptions + for k, v := range o.extraBody { + openaiOpts = append(openaiOpts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(k, v))) + } + + return openaicompat.New(openaiOpts...) +} \ No newline at end of file diff --git a/providers/xiaomi/xiaomi_test.go b/providers/xiaomi/xiaomi_test.go new file mode 100644 index 000000000..08a084269 --- /dev/null +++ b/providers/xiaomi/xiaomi_test.go @@ -0,0 +1,102 @@ +package xiaomi + +import ( + "testing" + + "charm.land/fantasy" +) + +func TestNew(t *testing.T) { + tests := []struct { + name string + opts []Option + wantErr bool + }{ + { + name: "default options", + opts: []Option{}, + wantErr: false, + }, + { + name: "with custom base URL", + opts: []Option{ + WithBaseURL("https://custom.xiaomi.com/v1"), + }, + wantErr: false, + }, + { + name: "with API key", + opts: []Option{ + WithAPIKey("test-api-key"), + }, + wantErr: false, + }, + { + name: "with headers", + opts: []Option{ + WithHeaders(map[string]string{ + "X-Custom-Header": "value", + }), + }, + wantErr: false, + }, + { + name: "with extra body", + opts: []Option{ + WithExtraBody(map[string]any{ + "thinking": map[string]any{ + "type": "enabled", + }, + }), + }, + wantErr: false, + }, + { + name: "with all options", + opts: []Option{ + WithBaseURL("https://custom.xiaomi.com/v1"), + WithAPIKey("test-api-key"), + WithHeaders(map[string]string{ + "X-Custom-Header": "value", + }), + WithExtraBody(map[string]any{ + "thinking": map[string]any{ + "type": "enabled", + }, + }), + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := New(tt.opts...) + if (err != nil) != tt.wantErr { + t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + if provider == nil { + t.Error("New() returned nil provider") + } + } + }) + } +} + +func TestName(t *testing.T) { + if Name != "xiaomi" { + t.Errorf("Expected Name to be 'xiaomi', got '%s'", Name) + } +} + +func TestProviderImplementsInterface(t *testing.T) { + provider, err := New() + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + + // Verify it implements the Provider interface + var _ fantasy.Provider = provider +} \ No newline at end of file From 0d011edacf69afe04b1f8eeeda03f2331e7b2164 Mon Sep 17 00:00:00 2001 From: Aero Date: Sun, 28 Dec 2025 15:02:03 +0800 Subject: [PATCH 2/6] feat(iflow): implement iFlow provider and corresponding tests --- providers/iflow/iflow.go | 116 +++++++++++++++++++++++++++++ providers/iflow/iflow_test.go | 136 ++++++++++++++++++++++++++++++++++ 2 files changed, 252 insertions(+) create mode 100644 providers/iflow/iflow.go create mode 100644 providers/iflow/iflow_test.go diff --git a/providers/iflow/iflow.go b/providers/iflow/iflow.go new file mode 100644 index 000000000..721770a3d --- /dev/null +++ b/providers/iflow/iflow.go @@ -0,0 +1,116 @@ +// Package iflow provides a fantasy.Provider for iFlow API. +package iflow + +import ( + "bytes" + "encoding/json" + "io" + "maps" + "net/http" + + "charm.land/fantasy" + "charm.land/fantasy/providers/openaicompat" +) + +const ( + // Name is the provider type name for iFlow. + Name = "iflow" +) + +type options struct { + baseURL string + apiKey string + headers map[string]string + httpClient *http.Client +} + +// Option configures the iFlow provider. +type Option = func(*options) + +type iflowTransport struct { + base http.RoundTripper +} + +func (t *iflowTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if req.Body == nil || req.Body == http.NoBody || req.Method != http.MethodPost { + return t.base.RoundTrip(req) + } + + // iFlow doesn't like max_tokens in the payload + body, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + _ = req.Body.Close() + + var payload map[string]any + if err := json.Unmarshal(body, &payload); err == nil { + // Some providers/models fail if max_tokens is present + delete(payload, "max_tokens") + delete(payload, "max_token") + body, _ = json.Marshal(payload) + } + + req.Body = io.NopCloser(bytes.NewReader(body)) + req.ContentLength = int64(len(body)) + return t.base.RoundTrip(req) +} + +// New creates a new iFlow provider. +// iFlow is based on OpenAI-compatible API but requires special User-Agent header. +func New(opts ...Option) (fantasy.Provider, error) { + o := options{ + baseURL: "https://apis.iflow.cn/v1", + headers: make(map[string]string), + } + for _, opt := range opts { + opt(&o) + } + + // iFlow requires "iFlow-Cli" User-Agent for premium models + o.headers["User-Agent"] = "iFlow-Cli" + + // Build OpenAI-compatible provider with iFlow-specific configuration + openaiOpts := []openaicompat.Option{ + openaicompat.WithBaseURL(o.baseURL), + openaicompat.WithAPIKey(o.apiKey), + } + + if len(o.headers) > 0 { + openaiOpts = append(openaiOpts, openaicompat.WithHeaders(o.headers)) + } + + httpClient := o.httpClient + if httpClient == nil { + httpClient = &http.Client{} + } + baseTransport := httpClient.Transport + if baseTransport == nil { + baseTransport = http.DefaultTransport + } + httpClient = &http.Client{ + Transport: &iflowTransport{base: baseTransport}, + Timeout: httpClient.Timeout, + } + openaiOpts = append(openaiOpts, openaicompat.WithHTTPClient(httpClient)) + + return openaicompat.New(openaiOpts...) +} + +// WithBaseURL sets the base URL. +func WithBaseURL(url string) Option { return func(o *options) { o.baseURL = url } } + +// WithAPIKey sets the API key. +func WithAPIKey(key string) Option { return func(o *options) { o.apiKey = key } } + +// WithHeaders sets custom headers. +func WithHeaders(headers map[string]string) Option { + return func(o *options) { + maps.Copy(o.headers, headers) + } +} + +// WithHTTPClient sets a custom HTTP client. +func WithHTTPClient(client *http.Client) Option { + return func(o *options) { o.httpClient = client } +} \ No newline at end of file diff --git a/providers/iflow/iflow_test.go b/providers/iflow/iflow_test.go new file mode 100644 index 000000000..f2b89d8bc --- /dev/null +++ b/providers/iflow/iflow_test.go @@ -0,0 +1,136 @@ +package iflow + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "charm.land/fantasy" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNew(t *testing.T) { + tests := []struct { + name string + opts []Option + wantErr bool + }{ + { + name: "default options", + opts: []Option{}, + wantErr: false, + }, + { + name: "with custom base URL", + opts: []Option{ + WithBaseURL("https://custom.iflow.com/v1"), + }, + wantErr: false, + }, + { + name: "with API key", + opts: []Option{ + WithAPIKey("test-api-key"), + }, + wantErr: false, + }, + { + name: "with headers", + opts: []Option{ + WithHeaders(map[string]string{ + "X-Custom-Header": "value", + }), + }, + wantErr: false, + }, + { + name: "with HTTP client", + opts: []Option{ + WithHTTPClient(&http.Client{}), + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := New(tt.opts...) + if (err != nil) != tt.wantErr { + t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + if provider == nil { + t.Error("New() returned nil provider") + } + } + }) + } +} + +func TestName(t *testing.T) { + if Name != "iflow" { + t.Errorf("Expected Name to be 'iflow', got '%s'", Name) + } +} + +func TestIFlowTransport(t *testing.T) { + // Create a mock server to receive the request + var capturedBody map[string]any + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatal(err) + } + if err := json.Unmarshal(body, &capturedBody); err != nil { + t.Fatal(err) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"choices":[{"message":{"content":"hello"}}]}`)) + })) + defer server.Close() + + // Create the transport + transport := &iflowTransport{ + base: http.DefaultTransport, + } + + // Create a request with max_tokens and max_token + payload := map[string]any{ + "model": "test-model", + "messages": []any{map[string]any{"role": "user", "content": "hi"}}, + "max_tokens": 100, + "max_token": 100, + } + body, err := json.Marshal(payload) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, server.URL, bytes.NewReader(body)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + + // Send the request through the transport + client := &http.Client{Transport: transport} + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Verify that max_tokens and max_token were removed + assert.NotContains(t, capturedBody, "max_tokens") + assert.NotContains(t, capturedBody, "max_token") + assert.Equal(t, "test-model", capturedBody["model"]) +} + +func TestProviderImplementsInterface(t *testing.T) { + provider, err := New() + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + + // Verify it implements the Provider interface + var _ fantasy.Provider = provider +} \ No newline at end of file From c7d6e6a6295e14e5d0998d6313eb0a0898103910 Mon Sep 17 00:00:00 2001 From: Aero Date: Sun, 28 Dec 2025 15:07:28 +0800 Subject: [PATCH 3/6] feat(xiaomi): add thinking mode option to Xiaomi provider and corresponding tests --- providers/xiaomi/xiaomi.go | 14 ++++++++++++++ providers/xiaomi/xiaomi_test.go | 7 +++++++ 2 files changed, 21 insertions(+) diff --git a/providers/xiaomi/xiaomi.go b/providers/xiaomi/xiaomi.go index 65fbb75e3..03a4be3c4 100644 --- a/providers/xiaomi/xiaomi.go +++ b/providers/xiaomi/xiaomi.go @@ -20,6 +20,7 @@ type options struct { headers map[string]string httpClient *http.Client extraBody map[string]any + thinking bool } // Option configures the Xiaomi provider. @@ -60,6 +61,13 @@ func WithExtraBody(extraBody map[string]any) Option { } } +// WithThinking enables or disables thinking mode for the Xiaomi provider. +func WithThinking(enabled bool) Option { + return func(o *options) { + o.thinking = enabled + } +} + // New creates a new Xiaomi provider. func New(opts ...Option) (fantasy.Provider, error) { o := options{ @@ -86,6 +94,12 @@ func New(opts ...Option) (fantasy.Provider, error) { } // Xiaomi thinking logic is handled via extraBody passed to WithSDKOptions + if o.thinking { + openaiOpts = append(openaiOpts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet("thinking", map[string]any{ + "type": "enabled", + }))) + } + for k, v := range o.extraBody { openaiOpts = append(openaiOpts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(k, v))) } diff --git a/providers/xiaomi/xiaomi_test.go b/providers/xiaomi/xiaomi_test.go index 08a084269..b9b192164 100644 --- a/providers/xiaomi/xiaomi_test.go +++ b/providers/xiaomi/xiaomi_test.go @@ -51,6 +51,13 @@ func TestNew(t *testing.T) { }, wantErr: false, }, + { + name: "with thinking enabled", + opts: []Option{ + WithThinking(true), + }, + wantErr: false, + }, { name: "with all options", opts: []Option{ From 5af021d42569db721931a2e2b6e5bf3bb9406a61 Mon Sep 17 00:00:00 2001 From: Aero Date: Mon, 29 Dec 2025 01:05:51 +0800 Subject: [PATCH 4/6] feat(openai): handle tool calls without type field in streaming response --- providers/openai/language_model.go | 2 +- providers/openai/openai_test.go | 73 ++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/providers/openai/language_model.go b/providers/openai/language_model.go index 55227bd92..58b7d7cc5 100644 --- a/providers/openai/language_model.go +++ b/providers/openai/language_model.go @@ -411,7 +411,7 @@ func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S } } else { var err error - if toolCallDelta.Type != "function" { + if toolCallDelta.Type != "" && toolCallDelta.Type != "function" { err = &fantasy.Error{Title: "invalid provider response", Message: "expected 'function' type."} } if toolCallDelta.ID == "" { diff --git a/providers/openai/openai_test.go b/providers/openai/openai_test.go index 97296fd97..e9f9368bc 100644 --- a/providers/openai/openai_test.go +++ b/providers/openai/openai_test.go @@ -2355,6 +2355,79 @@ func TestDoStream(t *testing.T) { require.Equal(t, `{"value":"Sparkle Day"}`, fullInput) }) + t.Run("should stream tool deltas without type field (devstral-style)", func(t *testing.T) { + t.Parallel() + + server := newStreamingMockServer() + defer server.close() + + // Simulate devstral-style response: tool_calls without type field, finish_reason in same chunk + chunks := []string{ + `data: {"id":"chatcmpl-abc123","object":"chat.completion.chunk","created":1711357598,"model":"devstral-2512","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}` + "\n\n", + `data: {"id":"chatcmpl-abc123","object":"chat.completion.chunk","created":1711357598,"model":"devstral-2512","choices":[{"index":0,"delta":{"tool_calls":[{"id":"tg7UYnLaz","function":{"name":"grep","arguments":"{\"pattern\": \"devstral\", \"literal_text\": true}"},"index":0}]},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":14797,"total_tokens":14815,"completion_tokens":18}}` + "\n\n", + "data: [DONE]\n\n", + } + server.chunks = chunks + + provider, err := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), + ) + require.NoError(t, err) + model, _ := provider.LanguageModel(t.Context(), "devstral-2512") + + stream, err := model.Stream(context.Background(), fantasy.Call{ + Prompt: testPrompt, + Tools: []fantasy.Tool{ + fantasy.FunctionTool{ + Name: "grep", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "pattern": map[string]any{ + "type": "string", + }, + "literal_text": map[string]any{ + "type": "boolean", + }, + }, + "required": []string{"pattern"}, + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#", + }, + }, + }, + }) + + require.NoError(t, err) + + parts, err := collectStreamParts(stream) + require.NoError(t, err) + + // Find tool-related parts + var toolCall *fantasy.StreamPart + var finishPart *fantasy.StreamPart + + for i := range parts { + if parts[i].Type == fantasy.StreamPartTypeToolCall { + toolCall = &parts[i] + } + if parts[i].Type == fantasy.StreamPartTypeFinish { + finishPart = &parts[i] + } + } + + // Verify tool call was processed + require.NotNil(t, toolCall, "tool call should be present") + require.Equal(t, "tg7UYnLaz", toolCall.ID) + require.Equal(t, "grep", toolCall.ToolCallName) + require.Equal(t, `{"pattern": "devstral", "literal_text": true}`, toolCall.ToolCallInput) + + // Verify finish reason was set correctly + require.NotNil(t, finishPart, "finish part should be present") + require.Equal(t, fantasy.FinishReasonToolCalls, finishPart.FinishReason) + }) + t.Run("should stream annotations/citations", func(t *testing.T) { t.Parallel() From 76edb2579f8124356faa02ec2b3af20e70e3a1b7 Mon Sep 17 00:00:00 2001 From: Aero Date: Mon, 29 Dec 2025 01:16:41 +0800 Subject: [PATCH 5/6] feat(openaicompat): implement tool call ID normalization and add corresponding tests --- .../openaicompat/language_model_hooks.go | 70 ++++++++++- providers/openaicompat/normalization_test.go | 114 ++++++++++++++++++ 2 files changed, 180 insertions(+), 4 deletions(-) create mode 100644 providers/openaicompat/normalization_test.go diff --git a/providers/openaicompat/language_model_hooks.go b/providers/openaicompat/language_model_hooks.go index 0aa721d5a..f79eb3400 100644 --- a/providers/openaicompat/language_model_hooks.go +++ b/providers/openaicompat/language_model_hooks.go @@ -4,6 +4,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "regexp" "strings" "charm.land/fantasy" @@ -15,6 +16,61 @@ import ( const reasoningStartedCtx = "reasoning_started" +// normalizeToolCallID normalizes tool call IDs to ensure compatibility with different providers. +// It converts UUID-style IDs and other formats to a 9-character alphanumeric format +// that is compatible with providers like Mistral that have specific ID requirements. +func normalizeToolCallID(id string) string { + // Remove all non-alphanumeric characters first + re := regexp.MustCompile(`[^a-zA-Z0-9]`) + cleaned := re.ReplaceAllString(id, "") + + // If we have at least 9 alphanumeric characters, use the first 9 + if len(cleaned) >= 9 { + return cleaned[:9] + } + + // If we have some alphanumeric characters but fewer than 9, pad them + if len(cleaned) > 0 { + return padToLength(cleaned, 9) + } + + // If we have no alphanumeric characters at all, generate a default ID + // Use a hash-like approach based on the original string to ensure consistency + return generateDefaultID(id) +} + +// padToLength pads a string to the specified length with deterministic alphanumeric characters +func padToLength(s string, length int) string { + if len(s) >= length { + return s[:length] + } + + const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + result := []byte(s) + for i := len(s); i < length; i++ { + // Use a simple hash of the input string to ensure deterministic padding + result = append(result, charset[(i+len(s))%len(charset)]) + } + return string(result) +} + +// generateDefaultID generates a default 9-character alphanumeric ID for edge cases +func generateDefaultID(input string) string { + const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + result := make([]byte, 9) + + // Use a simple hash of the input string to generate deterministic IDs + for i := 0; i < 9; i++ { + if i < len(input) { + result[i] = charset[int(input[i])%len(charset)] + } else { + // Pad with a pattern for very short inputs + result[i] = charset[(i*7)%len(charset)] + } + } + return string(result) +} + // PrepareCallFunc prepares the call for the language model. func PrepareCallFunc(_ fantasy.LanguageModel, params *openaisdk.ChatCompletionNewParams, call fantasy.Call) ([]fantasy.CallWarning, error) { providerOptions := &ProviderOptions{} @@ -346,10 +402,12 @@ func ToPromptFunc(prompt fantasy.Prompt, _, _ string) ([]openaisdk.ChatCompletio }) continue } - assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, + // Normalize tool call ID to ensure compatibility with different providers + normalizedToolCallID := normalizeToolCallID(toolCallPart.ToolCallID) + assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, openaisdk.ChatCompletionMessageToolCallUnionParam{ OfFunction: &openaisdk.ChatCompletionMessageFunctionToolCallParam{ - ID: toolCallPart.ToolCallID, + ID: normalizedToolCallID, Type: "function", Function: openaisdk.ChatCompletionMessageFunctionToolCallFunctionParam{ Name: toolCallPart.ToolName, @@ -404,7 +462,9 @@ func ToPromptFunc(prompt fantasy.Prompt, _, _ string) ([]openaisdk.ChatCompletio }) continue } - messages = append(messages, openaisdk.ToolMessage(output.Text, toolResultPart.ToolCallID)) + // Normalize tool call ID for tool results as well + normalizedToolCallID := normalizeToolCallID(toolResultPart.ToolCallID) + messages = append(messages, openaisdk.ToolMessage(output.Text, normalizedToolCallID)) case fantasy.ToolResultContentTypeError: output, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](toolResultPart.Output) if !ok { @@ -414,7 +474,9 @@ func ToPromptFunc(prompt fantasy.Prompt, _, _ string) ([]openaisdk.ChatCompletio }) continue } - messages = append(messages, openaisdk.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID)) + // Normalize tool call ID for error results as well + normalizedToolCallID := normalizeToolCallID(toolResultPart.ToolCallID) + messages = append(messages, openaisdk.ToolMessage(output.Error.Error(), normalizedToolCallID)) } } } diff --git a/providers/openaicompat/normalization_test.go b/providers/openaicompat/normalization_test.go new file mode 100644 index 000000000..e4a37e7dd --- /dev/null +++ b/providers/openaicompat/normalization_test.go @@ -0,0 +1,114 @@ +package openaicompat + +import ( + "testing" +) + +func TestNormalizeToolCallID(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "valid 9-char alphanumeric", + input: "abc123XYZ", + expected: "abc123XYZ", + }, + { + name: "UUID-style with underscore", + input: "call_1d9af98d68f24568a1aefd62", + expected: "call1d9af", + }, + { + name: "UUID-style with hyphens", + input: "c7f85f48-1ca1-4b3a-90aa-3580551a814f", + expected: "c7f85f481", + }, + { + name: "long string", + input: "thisisaverylongtoolcallidthatexceeds9characters", + expected: "thisisave", + }, + { + name: "short string", + input: "abc", + expected: "abcghijkl", + }, + { + name: "mixed special characters", + input: "call_123-456_ABC", + expected: "call12345", + }, + { + name: "exactly 9 chars with special", + input: "call_12_3", + expected: "call123op", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := normalizeToolCallID(tt.input) + if result != tt.expected { + t.Errorf("normalizeToolCallID(%q) = %q, want %q", tt.input, result, tt.expected) + } + // Verify the result is always 9 characters + if len(result) != 9 { + t.Errorf("normalizeToolCallID(%q) returned length %d, want 9", tt.input, len(result)) + } + // Verify the result only contains alphanumeric characters + for _, c := range result { + if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')) { + t.Errorf("normalizeToolCallID(%q) contains invalid character: %c", tt.input, c) + } + } + }) + } +} + +func TestNormalizeToolCallIDConsistency(t *testing.T) { + // Test that the same input always produces the same output + input := "call_1d9af98d68f24568a1aefd62" + result1 := normalizeToolCallID(input) + result2 := normalizeToolCallID(input) + if result1 != result2 { + t.Errorf("normalizeToolCallID(%q) is not consistent: %q vs %q", input, result1, result2) + } +} + +func TestNormalizeToolCallIDEdgeCases(t *testing.T) { + tests := []struct { + name string + input string + }{ + { + name: "empty string", + input: "", + }, + { + name: "only special characters", + input: "---___", + }, + { + name: "only underscores", + input: "_________", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := normalizeToolCallID(tt.input) + // Should still return a 9-character alphanumeric string + if len(result) != 9 { + t.Errorf("normalizeToolCallID(%q) returned length %d, want 9", tt.input, len(result)) + } + // Verify the result only contains alphanumeric characters + for _, c := range result { + if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')) { + t.Errorf("normalizeToolCallID(%q) contains invalid character: %c", tt.input, c) + } + } + }) + } +} \ No newline at end of file From df0c8eee5515a50ffd57d6ac4c477bdc21488db6 Mon Sep 17 00:00:00 2001 From: Aero Date: Mon, 29 Dec 2025 01:48:50 +0800 Subject: [PATCH 6/6] fix(xiaomi): extract actual tool name from wrapper function command parameter --- providers/openaicompat/openaicompat.go | 8 + providers/xiaomi/xiaomi.go | 373 ++++++++++++++++++++++++- providers/xiaomi/xiaomi_test.go | 239 +++++++++++++++- 3 files changed, 610 insertions(+), 10 deletions(-) diff --git a/providers/openaicompat/openaicompat.go b/providers/openaicompat/openaicompat.go index 236d43153..f05c3e6e4 100644 --- a/providers/openaicompat/openaicompat.go +++ b/providers/openaicompat/openaicompat.go @@ -107,3 +107,11 @@ func WithObjectMode(om fantasy.ObjectMode) Option { o.objectMode = om } } + +// WithLanguageModelOption adds a custom language model option to the OpenAI-compatible provider. +// This can be used to override default behaviors like StreamExtraFunc. +func WithLanguageModelOption(opt openai.LanguageModelOption) Option { + return func(o *options) { + o.languageModelOptions = append(o.languageModelOptions, opt) + } +} diff --git a/providers/xiaomi/xiaomi.go b/providers/xiaomi/xiaomi.go index 03a4be3c4..53567d53b 100644 --- a/providers/xiaomi/xiaomi.go +++ b/providers/xiaomi/xiaomi.go @@ -2,11 +2,17 @@ package xiaomi import ( + "context" + "encoding/json" + "fmt" "net/http" + "regexp" + "strings" "charm.land/fantasy" "charm.land/fantasy/providers/openaicompat" - openaisdk "github.com/openai/openai-go/v2/option" + "charm.land/fantasy/providers/openai" + openaisdk "github.com/openai/openai-go/v2" ) const ( @@ -14,6 +20,12 @@ const ( Name = "xiaomi" ) +type xiaomiProvider struct { + fantasy.Provider + extraBody map[string]any + thinking bool +} + type options struct { baseURL string apiKey string @@ -68,6 +80,264 @@ func WithThinking(enabled bool) Option { } } +const ( + xiaomiToolCallPrefix = "xiaomi_tool_calls" +) + +// xiaomiToolCall represents a parsed Xiaomi XML tool call +type xiaomiToolCall struct { + name string + arguments string +} + +// xiaomiPrepareCallFunc adds Xiaomi-specific parameters to the request +func xiaomiPrepareCallFunc(opts *options) openai.LanguageModelPrepareCallFunc { + return func(model fantasy.LanguageModel, params *openaisdk.ChatCompletionNewParams, call fantasy.Call) ([]fantasy.CallWarning, error) { + // First delegate to openaicompat's default prepare + warnings, err := openaicompat.PrepareCallFunc(model, params, call) + if err != nil { + return warnings, err + } + + // Create extra fields map + extraFields := make(map[string]any) + + // Add thinking parameter if enabled + if opts.thinking { + extraFields["thinking"] = map[string]any{ + "type": "enabled", + } + } + + // Add extra body parameters + if len(opts.extraBody) > 0 { + for k, v := range opts.extraBody { + extraFields[k] = v + } + } + + // Set extra fields if any + if len(extraFields) > 0 { + params.SetExtraFields(extraFields) + } + + return warnings, nil + } +} + +// xiaomiStreamExtraFunc handles Xiaomi-specific streaming responses, including XML tool call format +func xiaomiStreamExtraFunc(chunk openaisdk.ChatCompletionChunk, yield func(fantasy.StreamPart) bool, ctx map[string]any) (map[string]any, bool) { + if len(chunk.Choices) == 0 { + return ctx, true + } + + for inx, choice := range chunk.Choices { + // Check for Xiaomi XML tool calls in content + if choice.Delta.Content != "" && strings.Contains(choice.Delta.Content, " 0 { + return processXiaomiToolCalls(chunk, wrapYieldToExtractToolName(yield), ctx, inx) + } + } + + // Delegate to default openaicompat behavior for non-tool-call content + return openaicompat.StreamExtraFunc(chunk, yield, ctx) +} + +// wrapYieldToSuppressContent creates a yield wrapper that suppresses content events when in tool call mode +func wrapYieldToSuppressContent(yield func(fantasy.StreamPart) bool, ctx map[string]any) func(fantasy.StreamPart) bool { + return func(sp fantasy.StreamPart) bool { + // Suppress content events when we're processing tool calls + if sp.Type == fantasy.StreamPartTypeTextDelta { + // Check if we're in tool call mode by looking at accumulated content + accumulatedKey := xiaomiToolCallPrefix + "_content" + if accumulated, ok := ctx[accumulatedKey].(string); ok && accumulated != "" { + // We're in tool call mode, suppress this content + return true + } + } + return yield(sp) + } +} + +// wrapYieldToExtractToolName creates a yield wrapper that extracts the actual tool name from wrapper functions +func wrapYieldToExtractToolName(yield func(fantasy.StreamPart) bool) func(fantasy.StreamPart) bool { + return func(sp fantasy.StreamPart) bool { + // Only process tool call events + if sp.Type == fantasy.StreamPartTypeToolCall && sp.ToolCallName != "" { + // Check if this is a wrapper function + if sp.ToolCallName == "editor" || sp.ToolCallName == "bash" || sp.ToolCallName == "agent" { + // Parse the arguments to extract the command parameter + var argsMap map[string]string + if err := json.Unmarshal([]byte(sp.ToolCallInput), &argsMap); err == nil { + if command, ok := argsMap["command"]; ok { + // Use command parameter as tool name + sp.ToolCallName = command + // Remove command from arguments + delete(argsMap, "command") + if newArgs, err := json.Marshal(argsMap); err == nil { + sp.ToolCallInput = string(newArgs) + } + } + } + } + } + return yield(sp) + } +} + +// processXiaomiToolCalls processes tool calls from the standard ToolCalls field +func processXiaomiToolCalls(chunk openaisdk.ChatCompletionChunk, yield func(fantasy.StreamPart) bool, ctx map[string]any, inx int) (map[string]any, bool) { + // Delegate to default openaicompat behavior, but with tool name extraction + return openaicompat.StreamExtraFunc(chunk, yield, ctx) +} + +// parseXiaomiToolCalls parses Xiaomi's XML tool call format and emits standard tool call events +func parseXiaomiToolCalls(chunk openaisdk.ChatCompletionChunk, yield func(fantasy.StreamPart) bool, ctx map[string]any, inx int) (map[string]any, bool) { + content := chunk.Choices[0].Delta.Content + + // Accumulate content across chunks + accumulatedKey := xiaomiToolCallPrefix + "_content" + accumulated, _ := ctx[accumulatedKey].(string) + accumulated += content + ctx[accumulatedKey] = accumulated + + // Try to parse complete tool calls + toolCalls, remainingContent, err := extractXiaomiToolCalls(accumulated) + if err != nil { + yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, + Error: &fantasy.Error{Title: "parse error", Message: "error parsing Xiaomi tool calls", Cause: err}, + }) + return ctx, false + } + + // Only update context with remaining content if we found tool calls + // This ensures we keep accumulating when no complete tool calls are found yet + if len(toolCalls) > 0 { + ctx[accumulatedKey] = remainingContent + } + + // Emit tool call events for each parsed tool + for _, tc := range toolCalls { + // Xiaomi uses wrapper functions (e.g., "editor") where the actual tool name + // is in the "command" parameter + toolName := tc.name + toolArgs := tc.arguments + + // Parse arguments to extract command if present + var argsMap map[string]string + if err := json.Unmarshal([]byte(tc.arguments), &argsMap); err == nil { + if command, ok := argsMap["command"]; ok && (tc.name == "editor" || tc.name == "bash" || tc.name == "agent") { + // Use command parameter as tool name for wrapper functions + toolName = command + // Remove command from arguments + delete(argsMap, "command") + if newArgs, err := json.Marshal(argsMap); err == nil { + toolArgs = string(newArgs) + } + } + } + + toolCallID := fmt.Sprintf("xiaomi_%d_%s", inx, toolName) + + // Emit tool input start + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputStart, + ID: toolCallID, + }) { + return ctx, false + } + + // Emit tool input delta (the arguments) + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputDelta, + ID: toolCallID, + Delta: toolArgs, + }) { + return ctx, false + } + + // Emit tool input end + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputEnd, + ID: toolCallID, + }) { + return ctx, false + } + + // Emit tool call + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolCall, + ID: toolCallID, + ToolCallName: toolName, + ToolCallInput: toolArgs, + }) { + return ctx, false + } + } + + return ctx, true +} + +// extractXiaomiToolCalls extracts complete tool calls from accumulated content +func extractXiaomiToolCalls(content string) ([]xiaomiToolCall, string, error) { + var toolCalls []xiaomiToolCall + + // Pattern to match ... (with DOTALL flag to match newlines) + funcPattern := regexp.MustCompile(`(?s)]+)>(.*?)`) + paramsPattern := regexp.MustCompile(`]+)>([^<]*)`) + + // Find all complete tool calls + matches := funcPattern.FindAllStringSubmatchIndex(content, -1) + + // If no matches, return the original content as remaining + if len(matches) == 0 { + return toolCalls, content, nil + } + + // Process each complete tool call + for _, match := range matches { + if match[1] == -1 || match[2] == -1 { + continue + } + + funcName := content[match[2]:match[3]] + funcBody := content[match[4]:match[5]] + + // Parse parameters + params := make(map[string]string) + paramMatches := paramsPattern.FindAllStringSubmatch(funcBody, -1) + for _, pm := range paramMatches { + if len(pm) == 3 { + params[pm[1]] = pm[2] + } + } + + // Convert to JSON arguments + argsJSON, err := json.Marshal(params) + if err != nil { + return nil, "", fmt.Errorf("failed to marshal parameters: %w", err) + } + + toolCalls = append(toolCalls, xiaomiToolCall{ + name: funcName, + arguments: string(argsJSON), + }) + } + + // Remaining content is after the last complete tool call + lastMatch := matches[len(matches)-1] + remaining := content[lastMatch[1]:] + + return toolCalls, remaining, nil +} + // New creates a new Xiaomi provider. func New(opts ...Option) (fantasy.Provider, error) { o := options{ @@ -93,16 +363,101 @@ func New(opts ...Option) (fantasy.Provider, error) { openaiOpts = append(openaiOpts, openaicompat.WithHTTPClient(o.httpClient)) } - // Xiaomi thinking logic is handled via extraBody passed to WithSDKOptions - if o.thinking { - openaiOpts = append(openaiOpts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet("thinking", map[string]any{ - "type": "enabled", - }))) + // Override PrepareCallFunc to add Xiaomi-specific parameters (thinking, extra body) + openaiOpts = append(openaiOpts, openaicompat.WithLanguageModelOption(openai.WithLanguageModelPrepareCallFunc(xiaomiPrepareCallFunc(&o)))) + + // Override StreamExtraFunc to handle Xiaomi's XML tool call format + openaiOpts = append(openaiOpts, openaicompat.WithLanguageModelOption(openai.WithLanguageModelStreamExtraFunc(xiaomiStreamExtraFunc))) + + provider, err := openaicompat.New(openaiOpts...) + if err != nil { + return nil, err + } + + // Wrap provider to filter content before streaming + return &xiaomiProviderWrapper{ + Provider: provider, + xiaomiOpts: &o, + }, nil +} + +// xiaomiProviderWrapper wraps the OpenAI-compatible provider to filter XML tool call content +type xiaomiProviderWrapper struct { + fantasy.Provider + xiaomiOpts *options +} + +// LanguageModel implements fantasy.Provider by wrapping the language model with content filtering +func (w *xiaomiProviderWrapper) LanguageModel(ctx context.Context, modelID string) (fantasy.LanguageModel, error) { + lm, err := w.Provider.LanguageModel(ctx, modelID) + if err != nil { + return nil, err } - for k, v := range o.extraBody { - openaiOpts = append(openaiOpts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(k, v))) + return &xiaomiLanguageModel{ + LanguageModel: lm, + xiaomiOpts: w.xiaomiOpts, + }, nil +} + +// xiaomiLanguageModel wraps the language model to filter streaming content +type xiaomiLanguageModel struct { + fantasy.LanguageModel + xiaomiOpts *options +} + +// Stream implements fantasy.LanguageModel by filtering XML tool call content +func (m *xiaomiLanguageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { + stream, err := m.LanguageModel.Stream(ctx, call) + if err != nil { + return nil, err } + return xiaomiFilteredStream(stream), nil +} + +// xiaomiFilteredStream creates a filtered stream that suppresses tool call content +func xiaomiFilteredStream(original fantasy.StreamResponse) fantasy.StreamResponse { + return func(yield func(fantasy.StreamPart) bool) { + // Track accumulated content to detect tool calls + var accumulatedContent strings.Builder + var inToolCallMode bool + + // Create a wrapper yield that filters content + filteredYield := func(sp fantasy.StreamPart) bool { + // Handle content deltas + if sp.Type == fantasy.StreamPartTypeTextDelta && sp.Delta != "" { + // Check if this content contains tool call markers + if strings.Contains(sp.Delta, "view/path/to/file`, + expectedCalls: []xiaomiToolCall{ + { + name: "editor", + arguments: `{"command":"view","file_path":"/path/to/file"}`, + }, + }, + expectedRemain: "", + wantErr: false, + }, + { + name: "multiple tool calls", + content: `viewhello`, + expectedCalls: []xiaomiToolCall{ + { + name: "editor", + arguments: `{"command":"view"}`, + }, + { + name: "write", + arguments: `{"content":"hello"}`, + }, + }, + expectedRemain: "", + wantErr: false, + }, + { + name: "incomplete tool call", + content: `view`, + expectedCalls: []xiaomiToolCall{}, + expectedRemain: `view`, + wantErr: false, + }, + { + name: "tool call with special characters", + content: `/path/with spaces`, + expectedCalls: []xiaomiToolCall{ + { + name: "test", + arguments: `{"path":"/path/with spaces"}`, + }, + }, + expectedRemain: "", + wantErr: false, + }, + { + name: "tool call with newlines (actual Xiaomi format)", + content: "\n\nview\n/Users/aero/Documents/charm/catwalk/internal/providers/providers.go\n\n", + expectedCalls: []xiaomiToolCall{ + { + name: "editor", + arguments: `{"command":"view","file_path":"/Users/aero/Documents/charm/catwalk/internal/providers/providers.go"}`, + }, + }, + expectedRemain: "\n", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + calls, remain, err := extractXiaomiToolCalls(tt.content) + if (err != nil) != tt.wantErr { + t.Errorf("extractXiaomiToolCalls() error = %v, wantErr %v", err, tt.wantErr) + return + } + if len(calls) != len(tt.expectedCalls) { + t.Errorf("extractXiaomiToolCalls() got %d calls, want %d", len(calls), len(tt.expectedCalls)) + return + } + for i, call := range calls { + if call.name != tt.expectedCalls[i].name { + t.Errorf("extractXiaomiToolCalls() call[%d].name = %v, want %v", i, call.name, tt.expectedCalls[i].name) + } + if call.arguments != tt.expectedCalls[i].arguments { + t.Errorf("extractXiaomiToolCalls() call[%d].arguments = %v, want %v", i, call.arguments, tt.expectedCalls[i].arguments) + } + } + if remain != tt.expectedRemain { + t.Errorf("extractXiaomiToolCalls() remain = %v, want %v", remain, tt.expectedRemain) + } + }) + } +} + func TestNew(t *testing.T) { tests := []struct { name string @@ -103,7 +199,148 @@ func TestProviderImplementsInterface(t *testing.T) { if err != nil { t.Fatalf("Failed to create provider: %v", err) } - + // Verify it implements the Provider interface var _ fantasy.Provider = provider +} + +func TestToolNameExtraction(t *testing.T) { + // Test that wrapper functions have their command parameter extracted + testCases := []struct { + name string + xml string + expectedTool string + }{ + { + name: "editor wrapper with view command", + xml: `view/path/to/file`, + expectedTool: "view", + }, + { + name: "editor wrapper with write command", + xml: `write/path/to/filehello`, + expectedTool: "write", + }, + { + name: "bash wrapper with ls command", + xml: `ls/tmp`, + expectedTool: "ls", + }, + { + name: "non-wrapper function", + xml: `/path/to/filehello`, + expectedTool: "write", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Extract tool calls + toolCalls, _, err := extractXiaomiToolCalls(tc.xml) + if err != nil { + t.Fatalf("extractXiaomiToolCalls() error = %v", err) + } + if len(toolCalls) != 1 { + t.Fatalf("extractXiaomiToolCalls() got %d calls, want 1", len(toolCalls)) + } + + // Simulate the tool name extraction logic from parseXiaomiToolCalls + parsedTC := toolCalls[0] + toolName := parsedTC.name + + // Parse arguments to extract command if present + var argsMap map[string]string + if err := json.Unmarshal([]byte(parsedTC.arguments), &argsMap); err == nil { + if command, ok := argsMap["command"]; ok && (parsedTC.name == "editor" || parsedTC.name == "bash" || parsedTC.name == "agent") { + // Use command parameter as tool name for wrapper functions + toolName = command + // Remove command from arguments + delete(argsMap, "command") + } + } + + // Verify the tool name was extracted correctly + if toolName != tc.expectedTool { + t.Errorf("Expected tool name '%s', got '%s'", tc.expectedTool, toolName) + } + }) + } +} + +func TestWrapYieldToExtractToolName(t *testing.T) { + // Test that the yield wrapper correctly extracts tool names from wrapper functions + testCases := []struct { + name string + inputToolCall fantasy.StreamPart + expectedName string + expectedInput string + }{ + { + name: "editor wrapper with view command", + inputToolCall: fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolCall, + ID: "test_id", + ToolCallName: "editor", + ToolCallInput: `{"command":"view","file_path":"/path/to/file"}`, + }, + expectedName: "view", + expectedInput: `{"file_path":"/path/to/file"}`, + }, + { + name: "bash wrapper with ls command", + inputToolCall: fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolCall, + ID: "test_id", + ToolCallName: "bash", + ToolCallInput: `{"command":"ls","path":"/tmp"}`, + }, + expectedName: "ls", + expectedInput: `{"path":"/tmp"}`, + }, + { + name: "non-wrapper function", + inputToolCall: fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolCall, + ID: "test_id", + ToolCallName: "write", + ToolCallInput: `{"file_path":"/path/to/file","content":"hello"}`, + }, + expectedName: "write", + expectedInput: `{"file_path":"/path/to/file","content":"hello"}`, + }, + { + name: "non-tool-call event", + inputToolCall: fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextDelta, + Delta: "hello", + }, + expectedName: "", + expectedInput: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create a yield function that captures the modified stream part + var captured fantasy.StreamPart + yield := func(sp fantasy.StreamPart) bool { + captured = sp + return true + } + + // Apply the wrapper + wrapped := wrapYieldToExtractToolName(yield) + wrapped(tc.inputToolCall) + + // Verify the tool name was extracted correctly + if tc.expectedName != "" { + if captured.ToolCallName != tc.expectedName { + t.Errorf("Expected tool name '%s', got '%s'", tc.expectedName, captured.ToolCallName) + } + if captured.ToolCallInput != tc.expectedInput { + t.Errorf("Expected tool input '%s', got '%s'", tc.expectedInput, captured.ToolCallInput) + } + } + }) + } } \ No newline at end of file