From db0ff8bec4aa15a761d6498e174385774a4b2934 Mon Sep 17 00:00:00 2001 From: Jean-Laurent de Morlhon Date: Wed, 21 Jan 2026 14:35:34 +0100 Subject: [PATCH] Attach file instead of base64 them into the main message Signed-off-by: Jean-Laurent de Morlhon --- pkg/chat/chat.go | 35 ++ pkg/cli/runner.go | 85 ++--- pkg/cli/runner_attachment_test.go | 214 +++++++++++ pkg/model/provider/anthropic/beta_client.go | 2 +- .../provider/anthropic/beta_converter.go | 55 +-- pkg/model/provider/anthropic/client.go | 43 +-- pkg/model/provider/anthropic/client_test.go | 22 +- .../provider/anthropic/image_converter.go | 351 ++++++++++++++++++ .../anthropic/image_converter_test.go | 273 ++++++++++++++ pkg/model/provider/bedrock/convert.go | 54 ++- pkg/model/provider/bedrock/convert_test.go | 241 ++++++++++++ pkg/model/provider/gemini/client.go | 39 +- pkg/model/provider/gemini/image_converter.go | 168 +++++++++ .../provider/gemini/image_converter_test.go | 189 ++++++++++ .../provider/oaistream/image_converter.go | 115 ++++++ .../oaistream/image_converter_test.go | 283 ++++++++++++++ pkg/model/provider/oaistream/messages.go | 17 +- pkg/model/provider/oaistream/messages_test.go | 4 +- pkg/model/provider/openai/client.go | 41 +- 19 files changed, 2031 insertions(+), 200 deletions(-) create mode 100644 pkg/cli/runner_attachment_test.go create mode 100644 pkg/model/provider/anthropic/image_converter.go create mode 100644 pkg/model/provider/anthropic/image_converter_test.go create mode 100644 pkg/model/provider/bedrock/convert_test.go create mode 100644 pkg/model/provider/gemini/image_converter.go create mode 100644 pkg/model/provider/gemini/image_converter_test.go create mode 100644 pkg/model/provider/oaistream/image_converter.go create mode 100644 pkg/model/provider/oaistream/image_converter_test.go diff --git a/pkg/chat/chat.go b/pkg/chat/chat.go index bd7c25966..19aa3988a 100644 --- a/pkg/chat/chat.go +++ b/pkg/chat/chat.go @@ -26,9 +26,44 @@ const ( ImageURLDetailAuto ImageURLDetail = "auto" ) +// FileSourceType indicates how the file should be referenced in API calls +type FileSourceType string + +const ( + // FileSourceTypeNone means no file reference, use URL or base64 + FileSourceTypeNone FileSourceType = "" + // FileSourceTypeFileID means the file was uploaded and should be referenced by ID + FileSourceTypeFileID FileSourceType = "file_id" + // FileSourceTypeFileURI means the file was uploaded and should be referenced by URI (Gemini) + FileSourceTypeFileURI FileSourceType = "file_uri" + // FileSourceTypeLocalPath means the file is a local path that needs to be uploaded/converted + FileSourceTypeLocalPath FileSourceType = "local_path" +) + +// FileReference contains information about a file attachment +type FileReference struct { + // SourceType indicates how this file should be referenced + SourceType FileSourceType `json:"source_type,omitempty"` + // FileID is the provider-specific file identifier (for FileSourceTypeFileID) + FileID string `json:"file_id,omitempty"` + // FileURI is the file URI (for FileSourceTypeFileURI, used by Gemini) + FileURI string `json:"file_uri,omitempty"` + // LocalPath is the path to a local file (for FileSourceTypeLocalPath) + LocalPath string `json:"local_path,omitempty"` + // MimeType is the MIME type of the file + MimeType string `json:"mime_type,omitempty"` + // Provider identifies which provider this reference is for (when uploaded) + Provider string `json:"provider,omitempty"` +} + type MessageImageURL struct { + // URL contains a data URL (base64) or a public HTTP(S) URL URL string `json:"url,omitempty"` Detail ImageURLDetail `json:"detail,omitempty"` + + // FileRef contains file reference info when the image was uploaded via Files API + // or references a local file path that needs to be processed + FileRef *FileReference `json:"file_ref,omitempty"` } type Message struct { diff --git a/pkg/cli/runner.go b/pkg/cli/runner.go index f6942c41b..ae9d86e02 100644 --- a/pkg/cli/runner.go +++ b/pkg/cli/runner.go @@ -3,7 +3,6 @@ package cli import ( "cmp" "context" - "encoding/base64" "encoding/json" "fmt" "io" @@ -307,55 +306,29 @@ func ParseAttachCommand(userInput string) (messageText, attachPath string) { return messageText, attachPath } -// CreateUserMessageWithAttachment creates a user message with optional image attachment +// CreateUserMessageWithAttachment creates a user message with optional image attachment. +// Instead of converting to base64, this stores the file path for later processing +// by the provider (which may use Files API or base64 as appropriate). func CreateUserMessageWithAttachment(userContent, attachmentPath string) *session.Message { if attachmentPath == "" { return session.UserMessage(userContent) } - // Convert file to data URL - dataURL, err := fileToDataURL(attachmentPath) + // Resolve to absolute path + absPath, err := filepath.Abs(attachmentPath) if err != nil { - slog.Warn("Failed to attach file", "path", attachmentPath, "error", err) + slog.Warn("Failed to resolve attachment path", "path", attachmentPath, "error", err) return session.UserMessage(userContent) } - // Ensure we have some text content when attaching a file - textContent := cmp.Or(strings.TrimSpace(userContent), "Please analyze this attached file.") - - // Create message with multi-content including text and image - multiContent := []chat.MessagePart{ - { - Type: chat.MessagePartTypeText, - Text: textContent, - }, - { - Type: chat.MessagePartTypeImageURL, - ImageURL: &chat.MessageImageURL{ - URL: dataURL, - Detail: chat.ImageURLDetailAuto, - }, - }, - } - - return session.UserMessage("", multiContent...) -} - -// fileToDataURL converts a file to a data URL -func fileToDataURL(filePath string) (string, error) { // Check if file exists - if _, err := os.Stat(filePath); os.IsNotExist(err) { - return "", fmt.Errorf("file does not exist: %s", filePath) - } - - // Read file content - fileBytes, err := os.ReadFile(filePath) - if err != nil { - return "", fmt.Errorf("failed to read file: %w", err) + if _, err := os.Stat(absPath); os.IsNotExist(err) { + slog.Warn("Attachment file does not exist", "path", absPath) + return session.UserMessage(userContent) } - // Determine MIME type based on file extension - ext := strings.ToLower(filepath.Ext(filePath)) + // Determine MIME type from extension + ext := strings.ToLower(filepath.Ext(absPath)) var mimeType string switch ext { case ".jpg", ".jpeg": @@ -370,15 +343,39 @@ func fileToDataURL(filePath string) (string, error) { mimeType = "image/bmp" case ".svg": mimeType = "image/svg+xml" + case ".pdf": + mimeType = "application/pdf" default: - return "", fmt.Errorf("unsupported image format: %s", ext) + slog.Warn("Unsupported file format for attachment", "path", absPath, "ext", ext) + return session.UserMessage(userContent) } - // Encode to base64 - encoded := base64.StdEncoding.EncodeToString(fileBytes) + slog.Debug("Creating message with file attachment", + "path", absPath, + "mime_type", mimeType) + + // Ensure we have some text content when attaching a file + textContent := cmp.Or(strings.TrimSpace(userContent), "Please analyze this attached file.") - // Create data URL - dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, encoded) + // Create message with file reference (not base64) + // The provider will handle uploading via Files API or converting to base64 + multiContent := []chat.MessagePart{ + { + Type: chat.MessagePartTypeText, + Text: textContent, + }, + { + Type: chat.MessagePartTypeImageURL, + ImageURL: &chat.MessageImageURL{ + Detail: chat.ImageURLDetailAuto, + FileRef: &chat.FileReference{ + SourceType: chat.FileSourceTypeLocalPath, + LocalPath: absPath, + MimeType: mimeType, + }, + }, + }, + } - return dataURL, nil + return session.UserMessage("", multiContent...) } diff --git a/pkg/cli/runner_attachment_test.go b/pkg/cli/runner_attachment_test.go new file mode 100644 index 000000000..5ee747286 --- /dev/null +++ b/pkg/cli/runner_attachment_test.go @@ -0,0 +1,214 @@ +package cli + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/chat" +) + +func TestCreateUserMessageWithAttachment(t *testing.T) { + t.Parallel() + + // Create a temporary test image file + tmpDir := t.TempDir() + jpegPath := filepath.Join(tmpDir, "test.jpg") + pngPath := filepath.Join(tmpDir, "test.png") + gifPath := filepath.Join(tmpDir, "test.gif") + webpPath := filepath.Join(tmpDir, "test.webp") + pdfPath := filepath.Join(tmpDir, "test.pdf") + unsupportedPath := filepath.Join(tmpDir, "test.xyz") + + // Create test files + for _, path := range []string{jpegPath, pngPath, gifPath, webpPath, pdfPath, unsupportedPath} { + err := os.WriteFile(path, []byte("test data"), 0o644) + require.NoError(t, err) + } + + tests := []struct { + name string + userContent string + attachmentPath string + wantMultiContent bool + wantFileRef bool + wantMimeType string + wantDefaultPrompt bool + }{ + { + name: "no attachment", + userContent: "Hello world", + attachmentPath: "", + wantMultiContent: false, + }, + { + name: "jpeg attachment", + userContent: "Check this image", + attachmentPath: jpegPath, + wantMultiContent: true, + wantFileRef: true, + wantMimeType: "image/jpeg", + }, + { + name: "png attachment", + userContent: "Analyze this", + attachmentPath: pngPath, + wantMultiContent: true, + wantFileRef: true, + wantMimeType: "image/png", + }, + { + name: "gif attachment", + userContent: "What's in this gif?", + attachmentPath: gifPath, + wantMultiContent: true, + wantFileRef: true, + wantMimeType: "image/gif", + }, + { + name: "webp attachment", + userContent: "Describe this", + attachmentPath: webpPath, + wantMultiContent: true, + wantFileRef: true, + wantMimeType: "image/webp", + }, + { + name: "pdf attachment", + userContent: "Summarize this PDF", + attachmentPath: pdfPath, + wantMultiContent: true, + wantFileRef: true, + wantMimeType: "application/pdf", + }, + { + name: "attachment with empty content gets default prompt", + userContent: "", + attachmentPath: jpegPath, + wantMultiContent: true, + wantFileRef: true, + wantMimeType: "image/jpeg", + wantDefaultPrompt: true, + }, + { + name: "attachment with whitespace content gets default prompt", + userContent: " ", + attachmentPath: jpegPath, + wantMultiContent: true, + wantFileRef: true, + wantMimeType: "image/jpeg", + wantDefaultPrompt: true, + }, + { + name: "non-existent file falls back to text only", + userContent: "Hello", + attachmentPath: "/non/existent/file.jpg", + wantMultiContent: false, + }, + { + name: "unsupported format falls back to text only", + userContent: "Hello", + attachmentPath: unsupportedPath, + wantMultiContent: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + msg := CreateUserMessageWithAttachment(tt.userContent, tt.attachmentPath) + + require.NotNil(t, msg) + assert.Equal(t, chat.MessageRoleUser, msg.Message.Role) + + if tt.wantMultiContent { + assert.NotEmpty(t, msg.Message.MultiContent) + assert.Len(t, msg.Message.MultiContent, 2) // text + image + + // Check text part + textPart := msg.Message.MultiContent[0] + assert.Equal(t, chat.MessagePartTypeText, textPart.Type) + if tt.wantDefaultPrompt { + assert.Equal(t, "Please analyze this attached file.", textPart.Text) + } else { + assert.Equal(t, tt.userContent, textPart.Text) + } + + // Check image part + imagePart := msg.Message.MultiContent[1] + assert.Equal(t, chat.MessagePartTypeImageURL, imagePart.Type) + assert.NotNil(t, imagePart.ImageURL) + + if tt.wantFileRef { + assert.NotNil(t, imagePart.ImageURL.FileRef) + assert.Equal(t, chat.FileSourceTypeLocalPath, imagePart.ImageURL.FileRef.SourceType) + assert.NotEmpty(t, imagePart.ImageURL.FileRef.LocalPath) + assert.Equal(t, tt.wantMimeType, imagePart.ImageURL.FileRef.MimeType) + } + } else { + assert.Empty(t, msg.Message.MultiContent) + assert.Equal(t, tt.userContent, msg.Message.Content) + } + }) + } +} + +func TestParseAttachCommand(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + wantText string + wantAttachPath string + }{ + { + name: "no attach command", + input: "Hello world", + wantText: "Hello world", + wantAttachPath: "", + }, + { + name: "attach at start", + input: "/attach image.png describe this", + wantText: "describe this", + wantAttachPath: "image.png", + }, + { + name: "attach in middle", + input: "please /attach photo.jpg analyze it", + wantText: "please analyze it", + wantAttachPath: "photo.jpg", + }, + { + name: "attach only", + input: "/attach test.gif", + wantText: "", + wantAttachPath: "test.gif", + }, + { + name: "attach with path containing spaces handled", + input: "/attach my_image.png what is this?", + wantText: "what is this?", + wantAttachPath: "my_image.png", + }, + { + name: "multiline with attach", + input: "First line\n/attach image.jpg second part\nThird line", + wantText: "First line\nsecond part\nThird line", + wantAttachPath: "image.jpg", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + text, path := ParseAttachCommand(tt.input) + assert.Equal(t, tt.wantText, text) + assert.Equal(t, tt.wantAttachPath, path) + }) + } +} diff --git a/pkg/model/provider/anthropic/beta_client.go b/pkg/model/provider/anthropic/beta_client.go index b49b4a463..eb53df0fb 100644 --- a/pkg/model/provider/anthropic/beta_client.go +++ b/pkg/model/provider/anthropic/beta_client.go @@ -38,7 +38,7 @@ func (c *Client) createBetaStream( return nil, err } - converted := convertBetaMessages(messages) + converted := convertBetaMessagesWithClient(ctx, &client, messages) if err := validateAnthropicSequencingBeta(converted); err != nil { slog.Warn("Invalid message sequencing for Anthropic Beta API detected, attempting self-repair", "error", err) converted = repairAnthropicSequencingBeta(converted) diff --git a/pkg/model/provider/anthropic/beta_converter.go b/pkg/model/provider/anthropic/beta_converter.go index f70212a4c..62eb58dd7 100644 --- a/pkg/model/provider/anthropic/beta_converter.go +++ b/pkg/model/provider/anthropic/beta_converter.go @@ -1,6 +1,7 @@ package anthropic import ( + "context" "encoding/json" "strings" @@ -10,7 +11,13 @@ import ( "github.com/docker/cagent/pkg/tools" ) -// convertBetaMessages converts chat messages to Anthropic Beta API format +// convertBetaMessages is a backward-compatible wrapper that calls convertBetaMessagesWithClient +// with a nil client (falls back to base64 for local files) +func convertBetaMessages(messages []chat.Message) []anthropic.BetaMessageParam { + return convertBetaMessagesWithClient(context.Background(), nil, messages) +} + +// convertBetaMessagesWithClient converts chat messages to Anthropic Beta API format // Following Anthropic's extended thinking documentation with interleaved thinking enabled: // - Thinking blocks can appear anywhere in the conversation (not required to be first) // - Always include the complete, unmodified thinking block from previous assistant turns @@ -18,7 +25,7 @@ import ( // // Important: Anthropic API requires that all tool_result blocks corresponding to tool_use // blocks from the same assistant message MUST be grouped into a single user message. -func convertBetaMessages(messages []chat.Message) []anthropic.BetaMessageParam { +func convertBetaMessagesWithClient(ctx context.Context, client *anthropic.Client, messages []chat.Message) []anthropic.BetaMessageParam { var betaMessages []anthropic.BetaMessageParam for i := 0; i < len(messages); i++ { @@ -39,47 +46,9 @@ func convertBetaMessages(messages []chat.Message) []anthropic.BetaMessageParam { }) } } else if part.Type == chat.MessagePartTypeImageURL && part.ImageURL != nil { - if strings.HasPrefix(part.ImageURL.URL, "data:") { - parts := strings.SplitN(part.ImageURL.URL, ",", 2) - if len(parts) == 2 { - mediaTypePart := parts[0] - base64Data := parts[1] - var mediaType string - switch { - case strings.Contains(mediaTypePart, "image/jpeg"): - mediaType = "image/jpeg" - case strings.Contains(mediaTypePart, "image/png"): - mediaType = "image/png" - case strings.Contains(mediaTypePart, "image/gif"): - mediaType = "image/gif" - case strings.Contains(mediaTypePart, "image/webp"): - mediaType = "image/webp" - default: - mediaType = "image/jpeg" - } - // Use SDK types directly for better performance (avoids JSON round trip) - contentBlocks = append(contentBlocks, anthropic.BetaContentBlockParamUnion{ - OfImage: &anthropic.BetaImageBlockParam{ - Source: anthropic.BetaImageBlockParamSourceUnion{ - OfBase64: &anthropic.BetaBase64ImageSourceParam{ - Data: base64Data, - MediaType: anthropic.BetaBase64ImageSourceMediaType(mediaType), - }, - }, - }, - }) - } - } else if strings.HasPrefix(part.ImageURL.URL, "http://") || strings.HasPrefix(part.ImageURL.URL, "https://") { - // Support URL-based images - Anthropic can fetch images directly from URLs - contentBlocks = append(contentBlocks, anthropic.BetaContentBlockParamUnion{ - OfImage: &anthropic.BetaImageBlockParam{ - Source: anthropic.BetaImageBlockParamSourceUnion{ - OfURL: &anthropic.BetaURLImageSourceParam{ - URL: part.ImageURL.URL, - }, - }, - }, - }) + // Use the image converter which handles file refs, data URLs, and HTTP URLs + if imgBlock := convertBetaImagePart(ctx, client, part.ImageURL); imgBlock != nil { + contentBlocks = append(contentBlocks, *imgBlock) } } } diff --git a/pkg/model/provider/anthropic/client.go b/pkg/model/provider/anthropic/client.go index 68d9f4b02..944e5e15f 100644 --- a/pkg/model/provider/anthropic/client.go +++ b/pkg/model/provider/anthropic/client.go @@ -232,7 +232,7 @@ func (c *Client) CreateChatCompletionStream( return nil, err } - converted := convertMessages(messages) + converted := convertMessages(ctx, messages) // Preflight validation to ensure tool_use/tool_result sequencing is valid if err := validateAnthropicSequencing(converted); err != nil { slog.Warn("Invalid message sequencing for Anthropic detected, attempting self-repair", "error", err) @@ -328,7 +328,7 @@ func (c *Client) CreateChatCompletionStream( return ad, nil } -func convertMessages(messages []chat.Message) []anthropic.MessageParam { +func convertMessages(ctx context.Context, messages []chat.Message) []anthropic.MessageParam { var anthropicMessages []anthropic.MessageParam // Track whether the last appended assistant message included tool_use blocks // so we can ensure the immediate next message is the grouped tool_result user message. @@ -350,42 +350,9 @@ func convertMessages(messages []chat.Message) []anthropic.MessageParam { contentBlocks = append(contentBlocks, anthropic.NewTextBlock(txt)) } } else if part.Type == chat.MessagePartTypeImageURL && part.ImageURL != nil { - // Anthropic expects base64 image data - // Extract base64 data from data URL - if strings.HasPrefix(part.ImageURL.URL, "data:") { - parts := strings.SplitN(part.ImageURL.URL, ",", 2) - if len(parts) == 2 { - // Extract media type from data URL - mediaTypePart := parts[0] - base64Data := parts[1] - - var mediaType string - switch { - case strings.Contains(mediaTypePart, "image/jpeg"): - mediaType = "image/jpeg" - case strings.Contains(mediaTypePart, "image/png"): - mediaType = "image/png" - case strings.Contains(mediaTypePart, "image/gif"): - mediaType = "image/gif" - case strings.Contains(mediaTypePart, "image/webp"): - mediaType = "image/webp" - default: - // Default to jpeg if not recognized - mediaType = "image/jpeg" - } - - // Use SDK helper with proper typed source for better performance - // (avoids JSON marshal/unmarshal round trip) - contentBlocks = append(contentBlocks, anthropic.NewImageBlock(anthropic.Base64ImageSourceParam{ - Data: base64Data, - MediaType: anthropic.Base64ImageSourceMediaType(mediaType), - })) - } - } else if strings.HasPrefix(part.ImageURL.URL, "http://") || strings.HasPrefix(part.ImageURL.URL, "https://") { - // Support URL-based images - Anthropic can fetch images directly from URLs - contentBlocks = append(contentBlocks, anthropic.NewImageBlock(anthropic.URLImageSourceParam{ - URL: part.ImageURL.URL, - })) + // Use the image converter which handles file refs, data URLs, and HTTP URLs + if imgBlock := convertImagePart(ctx, nil, part.ImageURL); imgBlock != nil { + contentBlocks = append(contentBlocks, *imgBlock) } } } diff --git a/pkg/model/provider/anthropic/client_test.go b/pkg/model/provider/anthropic/client_test.go index 03dcb8098..558bf3cfb 100644 --- a/pkg/model/provider/anthropic/client_test.go +++ b/pkg/model/provider/anthropic/client_test.go @@ -22,7 +22,7 @@ func TestConvertMessages_SkipEmptySystemText(t *testing.T) { Content: " \n\t ", }} - out := convertMessages(msgs) + out := convertMessages(t.Context(), msgs) assert.Empty(t, out) } @@ -32,7 +32,7 @@ func TestConvertMessages_SkipEmptyUserText_NoMultiContent(t *testing.T) { Content: " \n\t ", }} - out := convertMessages(msgs) + out := convertMessages(t.Context(), msgs) assert.Empty(t, out) } @@ -45,7 +45,7 @@ func TestConvertMessages_UserMultiContent_SkipEmptyText_KeepImage(t *testing.T) }, }} - out := convertMessages(msgs) + out := convertMessages(t.Context(), msgs) require.Len(t, out, 1) b, err := json.Marshal(out[0]) @@ -71,7 +71,7 @@ func TestConvertMessages_SkipEmptyAssistantText_NoToolCalls(t *testing.T) { Content: " \t\n ", }} - out := convertMessages(msgs) + out := convertMessages(t.Context(), msgs) assert.Empty(t, out) } @@ -84,7 +84,7 @@ func TestConvertMessages_AssistantToolCalls_NoText_IncludesToolUse(t *testing.T) }, }} - out := convertMessages(msgs) + out := convertMessages(t.Context(), msgs) require.Len(t, out, 1) b, err := json.Marshal(out[0]) @@ -112,7 +112,7 @@ func TestSystemMessages_AreExtractedAndNotInMessageList(t *testing.T) { assert.Equal(t, "system rules here", strings.TrimSpace(sys[0].Text)) // System role messages must not appear in the anthropic messages list - out := convertMessages(msgs) + out := convertMessages(t.Context(), msgs) assert.Len(t, out, 1) } @@ -128,7 +128,7 @@ func TestSystemMessages_MultipleExtractedAndExcludedFromMessageList(t *testing.T assert.Equal(t, "sys A", strings.TrimSpace(sys[0].Text)) assert.Equal(t, "sys B", strings.TrimSpace(sys[1].Text)) - out := convertMessages(msgs) + out := convertMessages(t.Context(), msgs) assert.Len(t, out, 1) } @@ -148,7 +148,7 @@ func TestSystemMessages_InterspersedExtractedAndExcluded(t *testing.T) { assert.Equal(t, "S2", strings.TrimSpace(sys[1].Text)) // Converted messages must exclude system roles and preserve order of others - out := convertMessages(msgs) + out := convertMessages(t.Context(), msgs) require.Len(t, out, 3) expectedRoles := []string{"user", "assistant", "user"} for i, expected := range expectedRoles { @@ -173,7 +173,7 @@ func TestSequencingRepair_Standard(t *testing.T) { {Role: chat.MessageRoleUser, Content: "continue"}, } - converted := convertMessages(msgs) + converted := convertMessages(t.Context(), msgs) err := validateAnthropicSequencing(converted) require.Error(t, err) @@ -212,7 +212,7 @@ func TestConvertMessages_DropOrphanToolResults_NoPrecedingToolUse(t *testing.T) {Role: chat.MessageRoleUser, Content: "continue"}, } - converted := convertMessages(msgs) + converted := convertMessages(t.Context(), msgs) // Expect only the two user text messages to appear require.Len(t, converted, 2) @@ -246,7 +246,7 @@ func TestConvertMessages_GroupToolResults_AfterAssistantToolUse(t *testing.T) { {Role: chat.MessageRoleUser, Content: "ok"}, } - converted := convertMessages(msgs) + converted := convertMessages(t.Context(), msgs) // Expect: user(start), assistant(tool_use), user(grouped tool_result), user(ok) require.Len(t, converted, 4) diff --git a/pkg/model/provider/anthropic/image_converter.go b/pkg/model/provider/anthropic/image_converter.go new file mode 100644 index 000000000..9e3d7b3a2 --- /dev/null +++ b/pkg/model/provider/anthropic/image_converter.go @@ -0,0 +1,351 @@ +package anthropic + +import ( + "context" + "encoding/base64" + "fmt" + "log/slog" + "os" + "strings" + "sync" + + "github.com/anthropics/anthropic-sdk-go" + + "github.com/docker/cagent/pkg/chat" +) + +// fileUploadCache caches file IDs to avoid re-uploading the same file +var fileUploadCache = struct { + sync.RWMutex + cache map[string]string // localPath -> fileID +}{cache: make(map[string]string)} + +// convertImagePart converts a MessageImageURL to an Anthropic image block. +// It handles file references (uploading via Files API if possible), +// base64 data URLs, and HTTP(S) URLs. +func convertImagePart(ctx context.Context, client *anthropic.Client, imageURL *chat.MessageImageURL) *anthropic.ContentBlockParamUnion { + if imageURL == nil { + return nil + } + + // Handle file reference (from /attach command) + if imageURL.FileRef != nil { + return convertFileRefToImageBlock(ctx, client, imageURL.FileRef) + } + + // Handle data URL (base64) + if strings.HasPrefix(imageURL.URL, "data:") { + return convertDataURLToImageBlock(imageURL.URL) + } + + // Handle HTTP(S) URL + if strings.HasPrefix(imageURL.URL, "http://") || strings.HasPrefix(imageURL.URL, "https://") { + return &anthropic.ContentBlockParamUnion{ + OfImage: &anthropic.ImageBlockParam{ + Source: anthropic.ImageBlockParamSourceUnion{ + OfURL: &anthropic.URLImageSourceParam{ + URL: imageURL.URL, + }, + }, + }, + } + } + + return nil +} + +// convertFileRefToImageBlock handles file references, uploading via Files API if available +func convertFileRefToImageBlock(ctx context.Context, client *anthropic.Client, fileRef *chat.FileReference) *anthropic.ContentBlockParamUnion { + if fileRef == nil { + return nil + } + + switch fileRef.SourceType { + case chat.FileSourceTypeFileID: + // Already uploaded to Anthropic, use file ID directly + // Note: File ID support is only available in the Beta API + // The standard API will fall back to base64 + slog.Debug("Using existing file ID", "file_id", fileRef.FileID) + // Standard API doesn't support file IDs, so we need to fall back + // For now, log a warning and return nil (the beta converter handles this) + slog.Warn("File ID references not supported in standard Anthropic API, skipping") + return nil + + case chat.FileSourceTypeLocalPath: + // Try to upload via Files API, fall back to base64 + return uploadOrConvertLocalFile(ctx, client, fileRef.LocalPath, fileRef.MimeType) + + default: + slog.Warn("Unknown file source type", "type", fileRef.SourceType) + return nil + } +} + +// uploadOrConvertLocalFile attempts to upload a local file via Files API. +// If that fails or no client is provided, falls back to base64 encoding. +func uploadOrConvertLocalFile(_ context.Context, _ *anthropic.Client, localPath, mimeType string) *anthropic.ContentBlockParamUnion { + // For standard API, we always use base64 since it doesn't support file IDs + // The Files API upload would be wasted since we can't reference it + return convertLocalFileToBase64Block(localPath, mimeType) +} + +// convertLocalFileToBase64Block reads a local file and converts it to a base64 image block +func convertLocalFileToBase64Block(localPath, mimeType string) *anthropic.ContentBlockParamUnion { + data, err := os.ReadFile(localPath) + if err != nil { + slog.Warn("Failed to read local file", "path", localPath, "error", err) + return nil + } + + encoded := base64.StdEncoding.EncodeToString(data) + + if mimeType == "" { + mimeType = "image/jpeg" // Default + } + + slog.Debug("Converted local file to base64", "path", localPath, "size", len(data)) + + return &anthropic.ContentBlockParamUnion{ + OfImage: &anthropic.ImageBlockParam{ + Source: anthropic.ImageBlockParamSourceUnion{ + OfBase64: &anthropic.Base64ImageSourceParam{ + Data: encoded, + MediaType: anthropic.Base64ImageSourceMediaType(mimeType), + }, + }, + }, + } +} + +// convertDataURLToImageBlock parses a data URL and converts it to an image block +func convertDataURLToImageBlock(dataURL string) *anthropic.ContentBlockParamUnion { + parts := strings.SplitN(dataURL, ",", 2) + if len(parts) != 2 { + return nil + } + + mediaTypePart := parts[0] + base64Data := parts[1] + + var mediaType string + switch { + case strings.Contains(mediaTypePart, "image/jpeg"): + mediaType = "image/jpeg" + case strings.Contains(mediaTypePart, "image/png"): + mediaType = "image/png" + case strings.Contains(mediaTypePart, "image/gif"): + mediaType = "image/gif" + case strings.Contains(mediaTypePart, "image/webp"): + mediaType = "image/webp" + default: + mediaType = "image/jpeg" + } + + return &anthropic.ContentBlockParamUnion{ + OfImage: &anthropic.ImageBlockParam{ + Source: anthropic.ImageBlockParamSourceUnion{ + OfBase64: &anthropic.Base64ImageSourceParam{ + Data: base64Data, + MediaType: anthropic.Base64ImageSourceMediaType(mediaType), + }, + }, + }, + } +} + +// Beta API versions that support file references + +// convertBetaImagePart converts a MessageImageURL to a Beta API image block. +// It handles file references (uploading via Files API), base64 data URLs, and HTTP(S) URLs. +func convertBetaImagePart(ctx context.Context, client *anthropic.Client, imageURL *chat.MessageImageURL) *anthropic.BetaContentBlockParamUnion { + if imageURL == nil { + return nil + } + + // Handle file reference (from /attach command) + if imageURL.FileRef != nil { + return convertBetaFileRefToImageBlock(ctx, client, imageURL.FileRef) + } + + // Handle data URL (base64) + if strings.HasPrefix(imageURL.URL, "data:") { + return convertBetaDataURLToImageBlock(imageURL.URL) + } + + // Handle HTTP(S) URL + if strings.HasPrefix(imageURL.URL, "http://") || strings.HasPrefix(imageURL.URL, "https://") { + return &anthropic.BetaContentBlockParamUnion{ + OfImage: &anthropic.BetaImageBlockParam{ + Source: anthropic.BetaImageBlockParamSourceUnion{ + OfURL: &anthropic.BetaURLImageSourceParam{ + URL: imageURL.URL, + }, + }, + }, + } + } + + return nil +} + +// convertBetaFileRefToImageBlock handles file references for the Beta API +func convertBetaFileRefToImageBlock(ctx context.Context, client *anthropic.Client, fileRef *chat.FileReference) *anthropic.BetaContentBlockParamUnion { + if fileRef == nil { + return nil + } + + switch fileRef.SourceType { + case chat.FileSourceTypeFileID: + // Already uploaded, use file ID directly + slog.Debug("Using existing file ID for Beta API", "file_id", fileRef.FileID) + return &anthropic.BetaContentBlockParamUnion{ + OfImage: &anthropic.BetaImageBlockParam{ + Source: anthropic.BetaImageBlockParamSourceUnion{ + OfFile: &anthropic.BetaFileImageSourceParam{ + FileID: fileRef.FileID, + }, + }, + }, + } + + case chat.FileSourceTypeLocalPath: + // Try to upload via Files API, fall back to base64 + return uploadOrConvertBetaLocalFile(ctx, client, fileRef.LocalPath, fileRef.MimeType) + + default: + slog.Warn("Unknown file source type", "type", fileRef.SourceType) + return nil + } +} + +// uploadOrConvertBetaLocalFile attempts to upload a local file via Files API for Beta API. +// If that fails, falls back to base64 encoding. +func uploadOrConvertBetaLocalFile(ctx context.Context, client *anthropic.Client, localPath, mimeType string) *anthropic.BetaContentBlockParamUnion { + // Check cache first + fileUploadCache.RLock() + if fileID, ok := fileUploadCache.cache[localPath]; ok { + fileUploadCache.RUnlock() + slog.Debug("Using cached file ID", "path", localPath, "file_id", fileID) + return &anthropic.BetaContentBlockParamUnion{ + OfImage: &anthropic.BetaImageBlockParam{ + Source: anthropic.BetaImageBlockParamSourceUnion{ + OfFile: &anthropic.BetaFileImageSourceParam{ + FileID: fileID, + }, + }, + }, + } + } + fileUploadCache.RUnlock() + + // Try to upload via Files API + if client != nil { + fileID, err := uploadFileToAnthropic(ctx, client, localPath) + if err == nil { + // Cache the file ID + fileUploadCache.Lock() + fileUploadCache.cache[localPath] = fileID + fileUploadCache.Unlock() + + slog.Debug("Uploaded file to Anthropic Files API", "path", localPath, "file_id", fileID) + return &anthropic.BetaContentBlockParamUnion{ + OfImage: &anthropic.BetaImageBlockParam{ + Source: anthropic.BetaImageBlockParamSourceUnion{ + OfFile: &anthropic.BetaFileImageSourceParam{ + FileID: fileID, + }, + }, + }, + } + } + slog.Warn("Failed to upload file to Anthropic, falling back to base64", "path", localPath, "error", err) + } + + // Fall back to base64 + return convertBetaLocalFileToBase64Block(localPath, mimeType) +} + +// uploadFileToAnthropic uploads a file to Anthropic's Files API and returns the file ID +func uploadFileToAnthropic(ctx context.Context, client *anthropic.Client, localPath string) (string, error) { + file, err := os.Open(localPath) + if err != nil { + return "", fmt.Errorf("failed to open file: %w", err) + } + defer file.Close() + + params := anthropic.BetaFileUploadParams{ + File: file, + Betas: []anthropic.AnthropicBeta{anthropic.AnthropicBetaFilesAPI2025_04_14}, + } + + result, err := client.Beta.Files.Upload(ctx, params) + if err != nil { + return "", fmt.Errorf("failed to upload file: %w", err) + } + + return result.ID, nil +} + +// convertBetaLocalFileToBase64Block reads a local file and converts it to a Beta API base64 image block +func convertBetaLocalFileToBase64Block(localPath, mimeType string) *anthropic.BetaContentBlockParamUnion { + data, err := os.ReadFile(localPath) + if err != nil { + slog.Warn("Failed to read local file", "path", localPath, "error", err) + return nil + } + + encoded := base64.StdEncoding.EncodeToString(data) + + if mimeType == "" { + mimeType = "image/jpeg" // Default + } + + slog.Debug("Converted local file to base64 (Beta API)", "path", localPath, "size", len(data)) + + return &anthropic.BetaContentBlockParamUnion{ + OfImage: &anthropic.BetaImageBlockParam{ + Source: anthropic.BetaImageBlockParamSourceUnion{ + OfBase64: &anthropic.BetaBase64ImageSourceParam{ + Data: encoded, + MediaType: anthropic.BetaBase64ImageSourceMediaType(mimeType), + }, + }, + }, + } +} + +// convertBetaDataURLToImageBlock parses a data URL and converts it to a Beta API image block +func convertBetaDataURLToImageBlock(dataURL string) *anthropic.BetaContentBlockParamUnion { + parts := strings.SplitN(dataURL, ",", 2) + if len(parts) != 2 { + return nil + } + + mediaTypePart := parts[0] + base64Data := parts[1] + + var mediaType string + switch { + case strings.Contains(mediaTypePart, "image/jpeg"): + mediaType = "image/jpeg" + case strings.Contains(mediaTypePart, "image/png"): + mediaType = "image/png" + case strings.Contains(mediaTypePart, "image/gif"): + mediaType = "image/gif" + case strings.Contains(mediaTypePart, "image/webp"): + mediaType = "image/webp" + default: + mediaType = "image/jpeg" + } + + return &anthropic.BetaContentBlockParamUnion{ + OfImage: &anthropic.BetaImageBlockParam{ + Source: anthropic.BetaImageBlockParamSourceUnion{ + OfBase64: &anthropic.BetaBase64ImageSourceParam{ + Data: base64Data, + MediaType: anthropic.BetaBase64ImageSourceMediaType(mediaType), + }, + }, + }, + } +} diff --git a/pkg/model/provider/anthropic/image_converter_test.go b/pkg/model/provider/anthropic/image_converter_test.go new file mode 100644 index 000000000..246992653 --- /dev/null +++ b/pkg/model/provider/anthropic/image_converter_test.go @@ -0,0 +1,273 @@ +package anthropic + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/chat" +) + +func TestConvertImagePart(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + imageURL *chat.MessageImageURL + wantNil bool + }{ + { + name: "nil imageURL", + imageURL: nil, + wantNil: true, + }, + { + name: "data URL jpeg", + imageURL: &chat.MessageImageURL{ + URL: "data:image/jpeg;base64,/9j/4AAQSkZJRg==", + }, + wantNil: false, + }, + { + name: "data URL png", + imageURL: &chat.MessageImageURL{ + URL: "data:image/png;base64,iVBORw0KGgo=", + }, + wantNil: false, + }, + { + name: "http URL", + imageURL: &chat.MessageImageURL{ + URL: "http://example.com/image.png", + }, + wantNil: false, + }, + { + name: "https URL", + imageURL: &chat.MessageImageURL{ + URL: "https://example.com/image.jpg", + }, + wantNil: false, + }, + { + name: "invalid data URL format", + imageURL: &chat.MessageImageURL{ + URL: "data:image/jpeg", // missing comma and data + }, + wantNil: true, + }, + { + name: "empty URL", + imageURL: &chat.MessageImageURL{ + URL: "", + }, + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := convertImagePart(t.Context(), nil, tt.imageURL) + if tt.wantNil { + assert.Nil(t, result) + } else { + assert.NotNil(t, result) + } + }) + } +} + +func TestConvertImagePartWithFileRef(t *testing.T) { + t.Parallel() + + // Create a temporary test image file + tmpDir := t.TempDir() + testImagePath := filepath.Join(tmpDir, "test.jpg") + testImageData := []byte{0xFF, 0xD8, 0xFF, 0xE0} // Minimal JPEG header + err := os.WriteFile(testImagePath, testImageData, 0o644) + require.NoError(t, err) + + tests := []struct { + name string + imageURL *chat.MessageImageURL + wantNil bool + }{ + { + name: "local file path", + imageURL: &chat.MessageImageURL{ + FileRef: &chat.FileReference{ + SourceType: chat.FileSourceTypeLocalPath, + LocalPath: testImagePath, + MimeType: "image/jpeg", + }, + }, + wantNil: false, + }, + { + name: "non-existent local file", + imageURL: &chat.MessageImageURL{ + FileRef: &chat.FileReference{ + SourceType: chat.FileSourceTypeLocalPath, + LocalPath: "/non/existent/path.jpg", + MimeType: "image/jpeg", + }, + }, + wantNil: true, + }, + { + name: "file ID (standard API doesn't support)", + imageURL: &chat.MessageImageURL{ + FileRef: &chat.FileReference{ + SourceType: chat.FileSourceTypeFileID, + FileID: "file-abc123", + MimeType: "image/jpeg", + }, + }, + wantNil: true, // Standard API doesn't support file IDs + }, + { + name: "unknown source type", + imageURL: &chat.MessageImageURL{ + FileRef: &chat.FileReference{ + SourceType: "unknown", + }, + }, + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := convertImagePart(t.Context(), nil, tt.imageURL) + if tt.wantNil { + assert.Nil(t, result) + } else { + assert.NotNil(t, result) + } + }) + } +} + +func TestConvertBetaImagePartWithFileRef(t *testing.T) { + t.Parallel() + + // Create a temporary test image file + tmpDir := t.TempDir() + testImagePath := filepath.Join(tmpDir, "test.png") + testImageData := []byte{0x89, 0x50, 0x4E, 0x47} // PNG header + err := os.WriteFile(testImagePath, testImageData, 0o644) + require.NoError(t, err) + + tests := []struct { + name string + imageURL *chat.MessageImageURL + wantNil bool + }{ + { + name: "local file path", + imageURL: &chat.MessageImageURL{ + FileRef: &chat.FileReference{ + SourceType: chat.FileSourceTypeLocalPath, + LocalPath: testImagePath, + MimeType: "image/png", + }, + }, + wantNil: false, + }, + { + name: "file ID reference", + imageURL: &chat.MessageImageURL{ + FileRef: &chat.FileReference{ + SourceType: chat.FileSourceTypeFileID, + FileID: "file-abc123", + MimeType: "image/jpeg", + }, + }, + wantNil: false, // Beta API supports file IDs + }, + { + name: "data URL", + imageURL: &chat.MessageImageURL{ + URL: "data:image/png;base64,iVBORw0KGgo=", + }, + wantNil: false, + }, + { + name: "https URL", + imageURL: &chat.MessageImageURL{ + URL: "https://example.com/image.png", + }, + wantNil: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := convertBetaImagePart(t.Context(), nil, tt.imageURL) + if tt.wantNil { + assert.Nil(t, result) + } else { + assert.NotNil(t, result) + } + }) + } +} + +func TestConvertDataURLToImageBlock(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + dataURL string + wantNil bool + }{ + { + name: "valid jpeg", + dataURL: "data:image/jpeg;base64,/9j/4AAQSkZJRg==", + wantNil: false, + }, + { + name: "valid png", + dataURL: "data:image/png;base64,iVBORw0KGgo=", + wantNil: false, + }, + { + name: "valid gif", + dataURL: "data:image/gif;base64,R0lGODlh", + wantNil: false, + }, + { + name: "valid webp", + dataURL: "data:image/webp;base64,UklGR", + wantNil: false, + }, + { + name: "missing comma", + dataURL: "data:image/jpeg;base64", + wantNil: true, + }, + { + name: "empty data", + dataURL: "data:image/jpeg;base64,", + wantNil: false, // Empty base64 is technically valid + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := convertDataURLToImageBlock(tt.dataURL) + if tt.wantNil { + assert.Nil(t, result) + } else { + assert.NotNil(t, result) + } + }) + } +} diff --git a/pkg/model/provider/bedrock/convert.go b/pkg/model/provider/bedrock/convert.go index 77bebf175..745e189c1 100644 --- a/pkg/model/provider/bedrock/convert.go +++ b/pkg/model/provider/bedrock/convert.go @@ -4,6 +4,7 @@ import ( "encoding/base64" "encoding/json" "log/slog" + "os" "strings" "github.com/aws/aws-sdk-go-v2/aws" @@ -149,31 +150,58 @@ func convertUserContent(msg *chat.Message) []types.ContentBlock { } func convertImageURL(imageURL *chat.MessageImageURL) types.ContentBlock { - if !strings.HasPrefix(imageURL.URL, "data:") { + if imageURL == nil { return nil } - parts := strings.SplitN(imageURL.URL, ",", 2) - if len(parts) != 2 { - return nil - } + var imageData []byte + var mimeType string - // Decode base64 data - imageData, err := base64.StdEncoding.DecodeString(parts[1]) - if err != nil { + switch { + case imageURL.FileRef != nil: + // Handle file reference (from /attach command) + switch imageURL.FileRef.SourceType { + case chat.FileSourceTypeLocalPath: + data, err := os.ReadFile(imageURL.FileRef.LocalPath) + if err != nil { + slog.Warn("Failed to read local file for Bedrock", "path", imageURL.FileRef.LocalPath, "error", err) + return nil + } + imageData = data + mimeType = imageURL.FileRef.MimeType + slog.Debug("Converted local file to bytes for Bedrock", "path", imageURL.FileRef.LocalPath, "size", len(data)) + default: + slog.Warn("Unsupported file source type for Bedrock", "type", imageURL.FileRef.SourceType) + return nil + } + case strings.HasPrefix(imageURL.URL, "data:"): + // Handle data URL (base64) + parts := strings.SplitN(imageURL.URL, ",", 2) + if len(parts) != 2 { + return nil + } + + data, err := base64.StdEncoding.DecodeString(parts[1]) + if err != nil { + return nil + } + imageData = data + mimeType = parts[0] + default: + // Bedrock doesn't support URL-based images return nil } - // Determine format from media type + // Determine format from MIME type var format types.ImageFormat switch { - case strings.Contains(parts[0], "image/jpeg"): + case strings.Contains(mimeType, "image/jpeg"): format = types.ImageFormatJpeg - case strings.Contains(parts[0], "image/png"): + case strings.Contains(mimeType, "image/png"): format = types.ImageFormatPng - case strings.Contains(parts[0], "image/gif"): + case strings.Contains(mimeType, "image/gif"): format = types.ImageFormatGif - case strings.Contains(parts[0], "image/webp"): + case strings.Contains(mimeType, "image/webp"): format = types.ImageFormatWebp default: format = types.ImageFormatJpeg diff --git a/pkg/model/provider/bedrock/convert_test.go b/pkg/model/provider/bedrock/convert_test.go new file mode 100644 index 000000000..ad4db0169 --- /dev/null +++ b/pkg/model/provider/bedrock/convert_test.go @@ -0,0 +1,241 @@ +package bedrock + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/chat" +) + +func TestConvertImageURL(t *testing.T) { + t.Parallel() + + // Create a temporary test image file + tmpDir := t.TempDir() + testImagePath := filepath.Join(tmpDir, "test.jpg") + testImageData := []byte{0xFF, 0xD8, 0xFF, 0xE0} // Minimal JPEG header + err := os.WriteFile(testImagePath, testImageData, 0o644) + require.NoError(t, err) + + tests := []struct { + name string + imageURL *chat.MessageImageURL + wantNil bool + }{ + { + name: "nil imageURL", + imageURL: nil, + wantNil: true, + }, + { + name: "data URL jpeg", + imageURL: &chat.MessageImageURL{ + URL: "data:image/jpeg;base64,/9j/4AAQSkZJRg==", + }, + wantNil: false, + }, + { + name: "data URL png", + imageURL: &chat.MessageImageURL{ + URL: "data:image/png;base64,iVBORw0KGgo=", + }, + wantNil: false, + }, + { + name: "data URL gif", + imageURL: &chat.MessageImageURL{ + URL: "data:image/gif;base64,R0lGODlh", + }, + wantNil: false, + }, + { + name: "data URL webp", + imageURL: &chat.MessageImageURL{ + URL: "data:image/webp;base64,UklGRlYAAABXRUJQ", + }, + wantNil: false, + }, + { + name: "http URL not supported", + imageURL: &chat.MessageImageURL{ + URL: "http://example.com/image.png", + }, + wantNil: true, // Bedrock doesn't support URL-based images + }, + { + name: "https URL not supported", + imageURL: &chat.MessageImageURL{ + URL: "https://example.com/image.jpg", + }, + wantNil: true, // Bedrock doesn't support URL-based images + }, + { + name: "local file path", + imageURL: &chat.MessageImageURL{ + FileRef: &chat.FileReference{ + SourceType: chat.FileSourceTypeLocalPath, + LocalPath: testImagePath, + MimeType: "image/jpeg", + }, + }, + wantNil: false, + }, + { + name: "non-existent local file", + imageURL: &chat.MessageImageURL{ + FileRef: &chat.FileReference{ + SourceType: chat.FileSourceTypeLocalPath, + LocalPath: "/non/existent/path.jpg", + MimeType: "image/jpeg", + }, + }, + wantNil: true, + }, + { + name: "file ID not supported", + imageURL: &chat.MessageImageURL{ + FileRef: &chat.FileReference{ + SourceType: chat.FileSourceTypeFileID, + FileID: "file-abc123", + MimeType: "image/jpeg", + }, + }, + wantNil: true, + }, + { + name: "file URI not supported", + imageURL: &chat.MessageImageURL{ + FileRef: &chat.FileReference{ + SourceType: chat.FileSourceTypeFileURI, + FileURI: "https://example.com/file", + MimeType: "image/jpeg", + }, + }, + wantNil: true, + }, + { + name: "invalid data URL format", + imageURL: &chat.MessageImageURL{ + URL: "data:image/jpeg", // missing comma and data + }, + wantNil: true, + }, + { + name: "empty URL", + imageURL: &chat.MessageImageURL{ + URL: "", + }, + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := convertImageURL(tt.imageURL) + if tt.wantNil { + assert.Nil(t, result) + } else { + assert.NotNil(t, result) + } + }) + } +} + +func TestConvertUserContent(t *testing.T) { + t.Parallel() + + // Create a temporary test image file + tmpDir := t.TempDir() + testImagePath := filepath.Join(tmpDir, "test.png") + testImageData := []byte{0x89, 0x50, 0x4E, 0x47} // PNG header + err := os.WriteFile(testImagePath, testImageData, 0o644) + require.NoError(t, err) + + tests := []struct { + name string + msg *chat.Message + wantCount int + }{ + { + name: "text only", + msg: &chat.Message{ + Content: "Hello world", + }, + wantCount: 1, + }, + { + name: "empty content", + msg: &chat.Message{ + Content: " ", + }, + wantCount: 0, + }, + { + name: "multi-content text only", + msg: &chat.Message{ + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeText, Text: "Hello"}, + {Type: chat.MessagePartTypeText, Text: "World"}, + }, + }, + wantCount: 2, + }, + { + name: "multi-content with image", + msg: &chat.Message{ + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeText, Text: "Check this"}, + { + Type: chat.MessagePartTypeImageURL, + ImageURL: &chat.MessageImageURL{ + URL: "data:image/jpeg;base64,/9j/4AAQSkZJRg==", + }, + }, + }, + }, + wantCount: 2, + }, + { + name: "multi-content with local file", + msg: &chat.Message{ + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeText, Text: "Check this"}, + { + Type: chat.MessagePartTypeImageURL, + ImageURL: &chat.MessageImageURL{ + FileRef: &chat.FileReference{ + SourceType: chat.FileSourceTypeLocalPath, + LocalPath: testImagePath, + MimeType: "image/png", + }, + }, + }, + }, + }, + wantCount: 2, + }, + { + name: "skip nil imageURL", + msg: &chat.Message{ + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeText, Text: "Hello"}, + {Type: chat.MessagePartTypeImageURL, ImageURL: nil}, + }, + }, + wantCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := convertUserContent(tt.msg) + assert.Len(t, result, tt.wantCount) + }) + } +} diff --git a/pkg/model/provider/gemini/client.go b/pkg/model/provider/gemini/client.go index 3642d6c0d..82488eb36 100644 --- a/pkg/model/provider/gemini/client.go +++ b/pkg/model/provider/gemini/client.go @@ -3,7 +3,6 @@ package gemini import ( "cmp" "context" - "encoding/base64" "encoding/json" "errors" "fmt" @@ -164,6 +163,12 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro // convertMessagesToGemini converts chat.Messages into Gemini Contents func convertMessagesToGemini(messages []chat.Message) []*genai.Content { + return convertMessagesToGeminiWithClient(context.Background(), nil, messages) +} + +// convertMessagesToGeminiWithClient converts chat.Messages into Gemini Contents +// using the provided client for Files API uploads +func convertMessagesToGeminiWithClient(ctx context.Context, client *genai.Client, messages []chat.Message) []*genai.Content { contents := make([]*genai.Content, 0, len(messages)) for i := range messages { msg := &messages[i] @@ -212,7 +217,7 @@ func convertMessagesToGemini(messages []chat.Message) []*genai.Content { // Handle regular messages if len(msg.MultiContent) > 0 { - parts := convertMultiContent(msg.MultiContent, msg.ThoughtSignature) + parts := convertMultiContentWithClient(ctx, client, msg.MultiContent, msg.ThoughtSignature) if len(parts) > 0 { contents = append(contents, genai.NewContentFromParts(parts, role)) } @@ -241,15 +246,17 @@ func newTextPartWithSignature(text string, signature []byte) *genai.Part { return part } -// convertMultiContent converts multi-part content to Gemini parts -func convertMultiContent(multiContent []chat.MessagePart, thoughtSignature []byte) []*genai.Part { +// convertMultiContentWithClient converts multi-part content to Gemini parts +// using the provided client for Files API uploads +func convertMultiContentWithClient(ctx context.Context, client *genai.Client, multiContent []chat.MessagePart, thoughtSignature []byte) []*genai.Part { parts := make([]*genai.Part, 0, len(multiContent)) for _, part := range multiContent { switch part.Type { case chat.MessagePartTypeText: parts = append(parts, newTextPartWithSignature(part.Text, thoughtSignature)) case chat.MessagePartTypeImageURL: - if imgPart := convertImageURLToPart(part.ImageURL); imgPart != nil { + // Use the image converter which handles file refs, data URLs, and HTTP URLs + if imgPart := convertImageURLToPartWithClient(ctx, client, part.ImageURL); imgPart != nil { parts = append(parts, imgPart) } } @@ -257,28 +264,6 @@ func convertMultiContent(multiContent []chat.MessagePart, thoughtSignature []byt return parts } -// convertImageURLToPart converts an image URL to a Gemini Part -// Supports data URLs with base64-encoded image data -func convertImageURLToPart(imageURL *chat.MessageImageURL) *genai.Part { - if imageURL == nil || !strings.HasPrefix(imageURL.URL, "data:") { - return nil - } - - // Parse data URL format: data:[][;base64], - urlParts := strings.SplitN(imageURL.URL, ",", 2) - if len(urlParts) != 2 { - return nil - } - - imageData, err := base64.StdEncoding.DecodeString(urlParts[1]) - if err != nil { - return nil - } - - mimeType := extractMimeType(urlParts[0]) - return genai.NewPartFromBytes(imageData, mimeType) -} - // extractMimeType extracts the MIME type from a data URL prefix func extractMimeType(dataURLPrefix string) string { for _, mimeType := range []string{"image/jpeg", "image/png", "image/gif", "image/webp"} { diff --git a/pkg/model/provider/gemini/image_converter.go b/pkg/model/provider/gemini/image_converter.go new file mode 100644 index 000000000..6f0dbd843 --- /dev/null +++ b/pkg/model/provider/gemini/image_converter.go @@ -0,0 +1,168 @@ +package gemini + +import ( + "context" + "encoding/base64" + "log/slog" + "os" + "strings" + "sync" + + "google.golang.org/genai" + + "github.com/docker/cagent/pkg/chat" +) + +// fileUploadCache caches file URIs to avoid re-uploading the same file +var fileUploadCache = struct { + sync.RWMutex + cache map[string]string // localPath -> fileURI +}{cache: make(map[string]string)} + +// convertImageURLToPart converts an image URL to a Gemini Part. +// It handles file references (uploading via Files API if possible), +// base64 data URLs, and local files. +func convertImageURLToPartWithClient(ctx context.Context, client *genai.Client, imageURL *chat.MessageImageURL) *genai.Part { + if imageURL == nil { + return nil + } + + // Handle file reference (from /attach command) + if imageURL.FileRef != nil { + return convertFileRefToPart(ctx, client, imageURL.FileRef) + } + + // Handle data URL (base64) + if strings.HasPrefix(imageURL.URL, "data:") { + return convertDataURLToPart(imageURL.URL) + } + + // Handle HTTP(S) URL - Gemini can fetch from URLs + if strings.HasPrefix(imageURL.URL, "http://") || strings.HasPrefix(imageURL.URL, "https://") { + return genai.NewPartFromURI(imageURL.URL, extractMimeTypeFromURL(imageURL.URL)) + } + + return nil +} + +// convertFileRefToPart handles file references, uploading via Files API if available +func convertFileRefToPart(ctx context.Context, client *genai.Client, fileRef *chat.FileReference) *genai.Part { + if fileRef == nil { + return nil + } + + switch fileRef.SourceType { + case chat.FileSourceTypeFileURI: + // Already uploaded to Gemini, use URI directly + slog.Debug("Using existing file URI", "uri", fileRef.FileURI) + return genai.NewPartFromURI(fileRef.FileURI, fileRef.MimeType) + + case chat.FileSourceTypeFileID: + // File ID from another provider - need to upload to Gemini + slog.Warn("File ID from another provider not supported for Gemini, skipping", "file_id", fileRef.FileID) + return nil + + case chat.FileSourceTypeLocalPath: + // Try to upload via Files API, fall back to base64 + return uploadOrConvertLocalFile(ctx, client, fileRef.LocalPath, fileRef.MimeType) + + default: + slog.Warn("Unknown file source type", "type", fileRef.SourceType) + return nil + } +} + +// uploadOrConvertLocalFile attempts to upload a local file via Gemini Files API. +// If that fails, falls back to reading the file and sending as bytes. +func uploadOrConvertLocalFile(ctx context.Context, client *genai.Client, localPath, mimeType string) *genai.Part { + // Check cache first + fileUploadCache.RLock() + if fileURI, ok := fileUploadCache.cache[localPath]; ok { + fileUploadCache.RUnlock() + slog.Debug("Using cached file URI", "path", localPath, "uri", fileURI) + return genai.NewPartFromURI(fileURI, mimeType) + } + fileUploadCache.RUnlock() + + // Try to upload via Files API + if client != nil { + fileURI, err := uploadFileToGemini(ctx, client, localPath, mimeType) + if err == nil { + // Cache the file URI + fileUploadCache.Lock() + fileUploadCache.cache[localPath] = fileURI + fileUploadCache.Unlock() + + slog.Debug("Uploaded file to Gemini Files API", "path", localPath, "uri", fileURI) + return genai.NewPartFromURI(fileURI, mimeType) + } + slog.Warn("Failed to upload file to Gemini, falling back to bytes", "path", localPath, "error", err) + } + + // Fall back to reading file and sending as bytes + return convertLocalFileToBytesPart(localPath, mimeType) +} + +// uploadFileToGemini uploads a file to Gemini's Files API and returns the file URI +func uploadFileToGemini(ctx context.Context, client *genai.Client, localPath, mimeType string) (string, error) { + config := &genai.UploadFileConfig{ + MIMEType: mimeType, + } + + file, err := client.Files.UploadFromPath(ctx, localPath, config) + if err != nil { + return "", err + } + + return file.URI, nil +} + +// convertLocalFileToBytesPart reads a local file and converts it to a bytes Part +func convertLocalFileToBytesPart(localPath, mimeType string) *genai.Part { + data, err := os.ReadFile(localPath) + if err != nil { + slog.Warn("Failed to read local file", "path", localPath, "error", err) + return nil + } + + if mimeType == "" { + mimeType = "image/jpeg" // Default + } + + slog.Debug("Converted local file to bytes", "path", localPath, "size", len(data)) + return genai.NewPartFromBytes(data, mimeType) +} + +// convertDataURLToPart parses a data URL and converts it to a Gemini Part +func convertDataURLToPart(dataURL string) *genai.Part { + // Parse data URL format: data:[][;base64], + urlParts := strings.SplitN(dataURL, ",", 2) + if len(urlParts) != 2 { + return nil + } + + imageData, err := base64.StdEncoding.DecodeString(urlParts[1]) + if err != nil { + return nil + } + + mimeType := extractMimeType(urlParts[0]) + return genai.NewPartFromBytes(imageData, mimeType) +} + +// extractMimeTypeFromURL tries to determine MIME type from a URL +func extractMimeTypeFromURL(url string) string { + lowerURL := strings.ToLower(url) + switch { + case strings.HasSuffix(lowerURL, ".jpg"), strings.HasSuffix(lowerURL, ".jpeg"): + return "image/jpeg" + case strings.HasSuffix(lowerURL, ".png"): + return "image/png" + case strings.HasSuffix(lowerURL, ".gif"): + return "image/gif" + case strings.HasSuffix(lowerURL, ".webp"): + return "image/webp" + default: + return "image/jpeg" // Default + } +} diff --git a/pkg/model/provider/gemini/image_converter_test.go b/pkg/model/provider/gemini/image_converter_test.go new file mode 100644 index 000000000..52774d4c9 --- /dev/null +++ b/pkg/model/provider/gemini/image_converter_test.go @@ -0,0 +1,189 @@ +package gemini + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/chat" +) + +func TestConvertImageURLToPartWithClient(t *testing.T) { + t.Parallel() + + // Create a temporary test image file + tmpDir := t.TempDir() + testImagePath := filepath.Join(tmpDir, "test.jpg") + testImageData := []byte{0xFF, 0xD8, 0xFF, 0xE0} // Minimal JPEG header + err := os.WriteFile(testImagePath, testImageData, 0o644) + require.NoError(t, err) + + tests := []struct { + name string + imageURL *chat.MessageImageURL + wantNil bool + }{ + { + name: "nil imageURL", + imageURL: nil, + wantNil: true, + }, + { + name: "data URL", + imageURL: &chat.MessageImageURL{ + URL: "data:image/jpeg;base64,/9j/4AAQSkZJRg==", + }, + wantNil: false, + }, + { + name: "http URL", + imageURL: &chat.MessageImageURL{ + URL: "http://example.com/image.png", + }, + wantNil: false, + }, + { + name: "https URL", + imageURL: &chat.MessageImageURL{ + URL: "https://example.com/image.jpg", + }, + wantNil: false, + }, + { + name: "local file path", + imageURL: &chat.MessageImageURL{ + FileRef: &chat.FileReference{ + SourceType: chat.FileSourceTypeLocalPath, + LocalPath: testImagePath, + MimeType: "image/jpeg", + }, + }, + wantNil: false, + }, + { + name: "non-existent local file", + imageURL: &chat.MessageImageURL{ + FileRef: &chat.FileReference{ + SourceType: chat.FileSourceTypeLocalPath, + LocalPath: "/non/existent/path.jpg", + MimeType: "image/jpeg", + }, + }, + wantNil: true, + }, + { + name: "file URI reference", + imageURL: &chat.MessageImageURL{ + FileRef: &chat.FileReference{ + SourceType: chat.FileSourceTypeFileURI, + FileURI: "https://generativelanguage.googleapis.com/v1/files/abc123", + MimeType: "image/jpeg", + }, + }, + wantNil: false, + }, + { + name: "file ID from other provider", + imageURL: &chat.MessageImageURL{ + FileRef: &chat.FileReference{ + SourceType: chat.FileSourceTypeFileID, + FileID: "file-abc123", + MimeType: "image/jpeg", + }, + }, + wantNil: true, // Gemini doesn't support file IDs from other providers + }, + { + name: "invalid data URL", + imageURL: &chat.MessageImageURL{ + URL: "data:image/jpeg", // missing comma and data + }, + wantNil: true, + }, + { + name: "empty URL", + imageURL: &chat.MessageImageURL{ + URL: "", + }, + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := convertImageURLToPartWithClient(t.Context(), nil, tt.imageURL) + if tt.wantNil { + assert.Nil(t, result) + } else { + assert.NotNil(t, result) + } + }) + } +} + +func TestConvertDataURLToPart(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + dataURL string + wantNil bool + }{ + { + name: "valid jpeg", + dataURL: "data:image/jpeg;base64,/9j/4AAQSkZJRg==", + wantNil: false, + }, + { + name: "valid png", + dataURL: "data:image/png;base64,iVBORw0KGgo=", + wantNil: false, + }, + { + name: "missing comma", + dataURL: "data:image/jpeg;base64", + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := convertDataURLToPart(tt.dataURL) + if tt.wantNil { + assert.Nil(t, result) + } else { + assert.NotNil(t, result) + } + }) + } +} + +func TestExtractMimeTypeFromURL(t *testing.T) { + t.Parallel() + + tests := []struct { + url string + expected string + }{ + {"https://example.com/image.jpg", "image/jpeg"}, + {"https://example.com/image.jpeg", "image/jpeg"}, + {"https://example.com/image.png", "image/png"}, + {"https://example.com/image.gif", "image/gif"}, + {"https://example.com/image.webp", "image/webp"}, + {"https://example.com/image.unknown", "image/jpeg"}, // default + {"https://example.com/IMAGE.PNG", "image/png"}, // case insensitive + } + + for _, tt := range tests { + t.Run(tt.url, func(t *testing.T) { + t.Parallel() + result := extractMimeTypeFromURL(tt.url) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/pkg/model/provider/oaistream/image_converter.go b/pkg/model/provider/oaistream/image_converter.go new file mode 100644 index 000000000..9cc3f8ead --- /dev/null +++ b/pkg/model/provider/oaistream/image_converter.go @@ -0,0 +1,115 @@ +package oaistream + +import ( + "encoding/base64" + "log/slog" + "os" + + "github.com/openai/openai-go/v3" + + "github.com/docker/cagent/pkg/chat" +) + +// convertImageURLToOpenAI converts a MessageImageURL to an OpenAI image content part. +// It handles file references (converting to base64), base64 data URLs, and HTTP(S) URLs. +func convertImageURLToOpenAI(imageURL *chat.MessageImageURL) *openai.ChatCompletionContentPartUnionParam { + if imageURL == nil { + return nil + } + + var url string + detail := string(imageURL.Detail) + + // Handle file reference (from /attach command) + if imageURL.FileRef != nil { + url = convertFileRefToDataURL(imageURL.FileRef) + if url == "" { + return nil + } + } else { + url = imageURL.URL + } + + // Empty URL means we couldn't convert + if url == "" { + return nil + } + + result := openai.ImageContentPart(openai.ChatCompletionContentPartImageImageURLParam{ + URL: url, + Detail: detail, + }) + return &result +} + +// convertFileRefToDataURL handles file references, converting to base64 data URL +func convertFileRefToDataURL(fileRef *chat.FileReference) string { + if fileRef == nil { + return "" + } + + switch fileRef.SourceType { + case chat.FileSourceTypeFileID, chat.FileSourceTypeFileURI: + // File IDs from other providers need to be re-read from disk + // This shouldn't happen in normal flow since we store local paths + slog.Warn("File ID/URI from another provider not supported for OpenAI, skipping", + "file_id", fileRef.FileID, + "source_type", fileRef.SourceType) + return "" + + case chat.FileSourceTypeLocalPath: + return convertLocalFileToDataURL(fileRef.LocalPath, fileRef.MimeType) + + default: + slog.Warn("Unknown file source type", "type", fileRef.SourceType) + return "" + } +} + +// convertLocalFileToDataURL reads a local file and converts it to a base64 data URL +func convertLocalFileToDataURL(localPath, mimeType string) string { + data, err := os.ReadFile(localPath) + if err != nil { + slog.Warn("Failed to read local file", "path", localPath, "error", err) + return "" + } + + if mimeType == "" { + mimeType = "image/jpeg" // Default + } + + encoded := base64.StdEncoding.EncodeToString(data) + slog.Debug("Converted local file to base64 data URL", "path", localPath, "size", len(data)) + + return "data:" + mimeType + ";base64," + encoded +} + +// HasFileRef checks if any message part has a file reference that needs processing +func HasFileRef(multiContent []chat.MessagePart) bool { + for _, part := range multiContent { + if part.Type == chat.MessagePartTypeImageURL && part.ImageURL != nil && part.ImageURL.FileRef != nil { + return true + } + } + return false +} + +// ConvertMultiContentWithFileSupport converts chat.MessagePart slices to OpenAI content parts, +// handling file references properly. +func ConvertMultiContentWithFileSupport(multiContent []chat.MessagePart) []openai.ChatCompletionContentPartUnionParam { + parts := make([]openai.ChatCompletionContentPartUnionParam, 0, len(multiContent)) + for _, part := range multiContent { + switch part.Type { + case chat.MessagePartTypeText: + parts = append(parts, openai.TextContentPart(part.Text)) + case chat.MessagePartTypeImageURL: + if part.ImageURL != nil { + // Use the file-aware converter + if imgPart := convertImageURLToOpenAI(part.ImageURL); imgPart != nil { + parts = append(parts, *imgPart) + } + } + } + } + return parts +} diff --git a/pkg/model/provider/oaistream/image_converter_test.go b/pkg/model/provider/oaistream/image_converter_test.go new file mode 100644 index 000000000..41569e10c --- /dev/null +++ b/pkg/model/provider/oaistream/image_converter_test.go @@ -0,0 +1,283 @@ +package oaistream + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/chat" +) + +func TestConvertImageURLToOpenAI(t *testing.T) { + t.Parallel() + + // Create a temporary test image file + tmpDir := t.TempDir() + testImagePath := filepath.Join(tmpDir, "test.jpg") + testImageData := []byte{0xFF, 0xD8, 0xFF, 0xE0} // Minimal JPEG header + err := os.WriteFile(testImagePath, testImageData, 0o644) + require.NoError(t, err) + + tests := []struct { + name string + imageURL *chat.MessageImageURL + wantNil bool + }{ + { + name: "nil imageURL", + imageURL: nil, + wantNil: true, + }, + { + name: "data URL", + imageURL: &chat.MessageImageURL{ + URL: "data:image/jpeg;base64,/9j/4AAQSkZJRg==", + Detail: chat.ImageURLDetailAuto, + }, + wantNil: false, + }, + { + name: "http URL", + imageURL: &chat.MessageImageURL{ + URL: "http://example.com/image.png", + Detail: chat.ImageURLDetailHigh, + }, + wantNil: false, + }, + { + name: "https URL", + imageURL: &chat.MessageImageURL{ + URL: "https://example.com/image.jpg", + Detail: chat.ImageURLDetailLow, + }, + wantNil: false, + }, + { + name: "local file path", + imageURL: &chat.MessageImageURL{ + Detail: chat.ImageURLDetailAuto, + FileRef: &chat.FileReference{ + SourceType: chat.FileSourceTypeLocalPath, + LocalPath: testImagePath, + MimeType: "image/jpeg", + }, + }, + wantNil: false, + }, + { + name: "non-existent local file", + imageURL: &chat.MessageImageURL{ + FileRef: &chat.FileReference{ + SourceType: chat.FileSourceTypeLocalPath, + LocalPath: "/non/existent/path.jpg", + MimeType: "image/jpeg", + }, + }, + wantNil: true, + }, + { + name: "file ID from other provider", + imageURL: &chat.MessageImageURL{ + FileRef: &chat.FileReference{ + SourceType: chat.FileSourceTypeFileID, + FileID: "file-abc123", + MimeType: "image/jpeg", + }, + }, + wantNil: true, // OpenAI doesn't support file IDs for chat completions + }, + { + name: "empty URL no file ref", + imageURL: &chat.MessageImageURL{ + URL: "", + }, + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := convertImageURLToOpenAI(tt.imageURL) + if tt.wantNil { + assert.Nil(t, result) + } else { + assert.NotNil(t, result) + } + }) + } +} + +func TestConvertFileRefToDataURL(t *testing.T) { + t.Parallel() + + // Create a temporary test image file + tmpDir := t.TempDir() + testImagePath := filepath.Join(tmpDir, "test.png") + testImageData := []byte{0x89, 0x50, 0x4E, 0x47} // PNG header + err := os.WriteFile(testImagePath, testImageData, 0o644) + require.NoError(t, err) + + tests := []struct { + name string + fileRef *chat.FileReference + wantData bool + }{ + { + name: "nil fileRef", + fileRef: nil, + wantData: false, + }, + { + name: "local file path", + fileRef: &chat.FileReference{ + SourceType: chat.FileSourceTypeLocalPath, + LocalPath: testImagePath, + MimeType: "image/png", + }, + wantData: true, + }, + { + name: "non-existent file", + fileRef: &chat.FileReference{ + SourceType: chat.FileSourceTypeLocalPath, + LocalPath: "/non/existent/file.png", + MimeType: "image/png", + }, + wantData: false, + }, + { + name: "file ID not supported", + fileRef: &chat.FileReference{ + SourceType: chat.FileSourceTypeFileID, + FileID: "file-123", + }, + wantData: false, + }, + { + name: "file URI not supported", + fileRef: &chat.FileReference{ + SourceType: chat.FileSourceTypeFileURI, + FileURI: "https://example.com/file", + }, + wantData: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := convertFileRefToDataURL(tt.fileRef) + if tt.wantData { + assert.NotEmpty(t, result) + assert.True(t, strings.HasPrefix(result, "data:")) + } else { + assert.Empty(t, result) + } + }) + } +} + +func TestConvertLocalFileToDataURL(t *testing.T) { + t.Parallel() + + // Create a temporary test image file + tmpDir := t.TempDir() + testImagePath := filepath.Join(tmpDir, "test.gif") + testImageData := []byte{0x47, 0x49, 0x46, 0x38} // GIF header + err := os.WriteFile(testImagePath, testImageData, 0o644) + require.NoError(t, err) + + t.Run("valid file", func(t *testing.T) { + t.Parallel() + result := convertLocalFileToDataURL(testImagePath, "image/gif") + assert.True(t, strings.HasPrefix(result, "data:image/gif;base64,")) + }) + + t.Run("default mime type", func(t *testing.T) { + t.Parallel() + result := convertLocalFileToDataURL(testImagePath, "") + assert.True(t, strings.HasPrefix(result, "data:image/jpeg;base64,")) // default + }) + + t.Run("non-existent file", func(t *testing.T) { + t.Parallel() + result := convertLocalFileToDataURL("/non/existent/file.jpg", "image/jpeg") + assert.Empty(t, result) + }) +} + +func TestConvertMultiContentWithFileSupport(t *testing.T) { + t.Parallel() + + // Create a temporary test image file + tmpDir := t.TempDir() + testImagePath := filepath.Join(tmpDir, "test.jpg") + testImageData := []byte{0xFF, 0xD8, 0xFF, 0xE0} + err := os.WriteFile(testImagePath, testImageData, 0o644) + require.NoError(t, err) + + tests := []struct { + name string + multiContent []chat.MessagePart + wantCount int + }{ + { + name: "empty", + multiContent: []chat.MessagePart{}, + wantCount: 0, + }, + { + name: "text only", + multiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeText, Text: "Hello"}, + }, + wantCount: 1, + }, + { + name: "text and URL image", + multiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeText, Text: "Check this"}, + {Type: chat.MessagePartTypeImageURL, ImageURL: &chat.MessageImageURL{URL: "https://example.com/img.png"}}, + }, + wantCount: 2, + }, + { + name: "text and local file image", + multiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeText, Text: "Check this"}, + { + Type: chat.MessagePartTypeImageURL, + ImageURL: &chat.MessageImageURL{ + FileRef: &chat.FileReference{ + SourceType: chat.FileSourceTypeLocalPath, + LocalPath: testImagePath, + MimeType: "image/jpeg", + }, + }, + }, + }, + wantCount: 2, + }, + { + name: "skip nil imageURL", + multiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeText, Text: "Hello"}, + {Type: chat.MessagePartTypeImageURL, ImageURL: nil}, + }, + wantCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := ConvertMultiContentWithFileSupport(tt.multiContent) + assert.Len(t, result, tt.wantCount) + }) + } +} diff --git a/pkg/model/provider/oaistream/messages.go b/pkg/model/provider/oaistream/messages.go index 8732a26d7..e2bfcaf00 100644 --- a/pkg/model/provider/oaistream/messages.go +++ b/pkg/model/provider/oaistream/messages.go @@ -24,22 +24,9 @@ func (j JSONSchema) MarshalJSON() ([]byte, error) { } // ConvertMultiContent converts chat.MessagePart slices to OpenAI content parts. +// This now handles file references properly by using the file-aware converter. func ConvertMultiContent(multiContent []chat.MessagePart) []openai.ChatCompletionContentPartUnionParam { - parts := make([]openai.ChatCompletionContentPartUnionParam, len(multiContent)) - for i, part := range multiContent { - switch part.Type { - case chat.MessagePartTypeText: - parts[i] = openai.TextContentPart(part.Text) - case chat.MessagePartTypeImageURL: - if part.ImageURL != nil { - parts[i] = openai.ImageContentPart(openai.ChatCompletionContentPartImageImageURLParam{ - URL: part.ImageURL.URL, - Detail: string(part.ImageURL.Detail), - }) - } - } - } - return parts + return ConvertMultiContentWithFileSupport(multiContent) } // ConvertMessages converts chat.Message slices to OpenAI message params. diff --git a/pkg/model/provider/oaistream/messages_test.go b/pkg/model/provider/oaistream/messages_test.go index 1930de94d..5f0d0bf9e 100644 --- a/pkg/model/provider/oaistream/messages_test.go +++ b/pkg/model/provider/oaistream/messages_test.go @@ -41,11 +41,11 @@ func TestConvertMultiContent(t *testing.T) { wantCount: 2, }, { - name: "image without URL", + name: "image without URL skipped", multiContent: []chat.MessagePart{ {Type: chat.MessagePartTypeImageURL, ImageURL: nil}, }, - wantCount: 1, + wantCount: 0, // nil ImageURL is now properly skipped }, } diff --git a/pkg/model/provider/openai/client.go b/pkg/model/provider/openai/client.go index 309b163c1..77db1e044 100644 --- a/pkg/model/provider/openai/client.go +++ b/pkg/model/provider/openai/client.go @@ -3,11 +3,13 @@ package openai import ( "cmp" "context" + "encoding/base64" "encoding/json" "errors" "fmt" "log/slog" "net/url" + "os" "strings" "github.com/openai/openai-go/v3" @@ -441,12 +443,21 @@ func convertMessagesToResponseInput(messages []chat.Message) []responses.Respons case chat.ImageURLDetailLow: detail = responses.ResponseInputImageContentDetailLow } - contentParts = append(contentParts, responses.ResponseInputContentUnionParam{ - OfInputImage: &responses.ResponseInputImageParam{ - ImageURL: param.NewOpt(part.ImageURL.URL), - Detail: responses.ResponseInputImageDetail(detail), - }, - }) + // Handle file references by converting to data URL + var imageURL string + if part.ImageURL.FileRef != nil && part.ImageURL.FileRef.SourceType == chat.FileSourceTypeLocalPath { + imageURL = convertLocalPathToDataURL(part.ImageURL.FileRef.LocalPath, part.ImageURL.FileRef.MimeType) + } else { + imageURL = part.ImageURL.URL + } + if imageURL != "" { + contentParts = append(contentParts, responses.ResponseInputContentUnionParam{ + OfInputImage: &responses.ResponseInputImageParam{ + ImageURL: param.NewOpt(imageURL), + Detail: responses.ResponseInputImageDetail(detail), + }, + }) + } } } } @@ -884,3 +895,21 @@ type jsonSchema map[string]any func (j jsonSchema) MarshalJSON() ([]byte, error) { return json.Marshal(map[string]any(j)) } + +// convertLocalPathToDataURL reads a local file and converts it to a base64 data URL +func convertLocalPathToDataURL(localPath, mimeType string) string { + data, err := os.ReadFile(localPath) + if err != nil { + slog.Warn("Failed to read local file", "path", localPath, "error", err) + return "" + } + + if mimeType == "" { + mimeType = "image/jpeg" // Default + } + + encoded := base64.StdEncoding.EncodeToString(data) + slog.Debug("Converted local file to base64 data URL", "path", localPath, "size", len(data)) + + return "data:" + mimeType + ";base64," + encoded +}