diff --git a/internal/llm/openai/openai.go b/internal/llm/openai/openai.go
index 974cffb..6e337f2 100644
--- a/internal/llm/openai/openai.go
+++ b/internal/llm/openai/openai.go
@@ -5,6 +5,7 @@ import (
 	"errors"
 	"io"
 
+	"github.com/aavshr/panda/internal/db"
 	client "github.com/sashabaranov/go-openai"
 )
 
@@ -65,20 +66,28 @@ func (o *OpenAI) SetAPIKey(apiKey string) error {
 	return nil
 }
 
-func (o *OpenAI) CreateChatCompletion(ctx context.Context, model, input string) (string, error) {
+func (o *OpenAI) dbMessagesToClientMessage(messages []*db.Message) []client.ChatCompletionMessage {
+	var clientMessages []client.ChatCompletionMessage
+	for _, m := range messages {
+		m := m
+		clientMessages = append(clientMessages, client.ChatCompletionMessage{
+			Role:    m.Role,
+			Content: m.Content,
+		})
+	}
+	return clientMessages
+}
+
+// TODO: fix coupling with db message
+func (o *OpenAI) CreateChatCompletion(ctx context.Context, model string, messages []*db.Message) (string, error) {
 	if o.apiKey == "" {
 		return "", ErrAPIKeyNotSet
 	}
 	resp, err := o.client.CreateChatCompletion(
 		ctx,
 		client.ChatCompletionRequest{
-			Model: model,
-			Messages: []client.ChatCompletionMessage{
-				{
-					Role:    client.ChatMessageRoleUser,
-					Content: input,
-				},
-			},
+			Model:    model,
+			Messages: o.dbMessagesToClientMessage(messages),
 		},
 	)
 	if err != nil {
@@ -90,19 +99,15 @@ func (o *OpenAI) CreateChatCompletion(ctx context.Context, model, input string)
 	return resp.Choices[0].Message.Content, nil
 }
 
-func (o *OpenAI) CreateChatCompletionStream(ctx context.Context, model, input string) (io.ReadCloser, error) {
+// TODO: coupling with db.Message
+func (o *OpenAI) CreateChatCompletionStream(ctx context.Context, model string, messages []*db.Message) (io.ReadCloser, error) {
 	if o.apiKey == "" {
 		return nil, ErrAPIKeyNotSet
 	}
 	req := client.ChatCompletionRequest{
-		Model: model,
-		Messages: []client.ChatCompletionMessage{
-			{
-				Role:    client.ChatMessageRoleUser,
-				Content: input,
-			},
-		},
-		Stream: true,
+		Model:    model,
+		Messages: o.dbMessagesToClientMessage(messages),
+		Stream:   true,
 	}
 	stream, err := o.client.CreateChatCompletionStream(ctx, req)
 	if err != nil {
diff --git a/internal/ui/components/chat.go b/internal/ui/components/chat.go
index dfd96f7..a6e0c03 100644
--- a/internal/ui/components/chat.go
+++ b/internal/ui/components/chat.go
@@ -10,6 +10,7 @@ import (
 	"github.com/charmbracelet/lipgloss"
 )
 
+// TODO: use same Message model throughout the app
 type Message struct {
 	Content   string
 	CreatedAt string
diff --git a/internal/ui/handlers.go b/internal/ui/handlers.go
index b49d766..bbfaaec 100644
--- a/internal/ui/handlers.go
+++ b/internal/ui/handlers.go
@@ -146,10 +146,10 @@ func (m *Model) handleChatInputReturnMsg(msg components.ChatInputReturnMsg) tea.
 	if err := m.store.CreateMessage(userMessage); err != nil {
 		return m.cmdError(fmt.Errorf("store.CreateMessage: %w", err))
 	}
-	m.setMessages(append(m.messages, userMessage))
-	// TODO: history
+	messages := append(m.messages, userMessage)
+	m.setMessages(messages)
 	reader, err := m.llm.CreateChatCompletionStream(context.Background(),
-		m.userConfig.LLMModel, msg.Value)
+		m.userConfig.LLMModel, messages)
 	if err != nil {
 		return m.cmdError(fmt.Errorf("llm.CreateChatCompletionStream: %w", err))
 	}
diff --git a/internal/ui/llm/llm.go b/internal/ui/llm/llm.go
index 5556559..bbefd6c 100644
--- a/internal/ui/llm/llm.go
+++ b/internal/ui/llm/llm.go
@@ -2,13 +2,15 @@ package llm
 
 import (
 	"context"
+	"github.com/aavshr/panda/internal/db"
 	"io"
 	"strings"
 )
 
+// TODO: fix coupling with db message
 type LLM interface {
-	CreateChatCompletion(context.Context, string, string) (string, error)
-	CreateChatCompletionStream(context.Context, string, string) (io.ReadCloser, error)
+	CreateChatCompletion(context.Context, string, []*db.Message) (string, error)
+	CreateChatCompletionStream(context.Context, string, []*db.Message) (io.ReadCloser, error)
 	SetAPIKey(string) error
 }